Skip to content

Commit

Permalink
refactor(db): makes db options user configurable
Browse files Browse the repository at this point in the history
  • Loading branch information
pallabpain committed Jan 17, 2024
1 parent 9274087 commit 447eb2f
Show file tree
Hide file tree
Showing 6 changed files with 208 additions and 120 deletions.
6 changes: 6 additions & 0 deletions config-example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,17 @@ db_path: /var/lib/headscale/db.sqlite
# db_name: headscale
# db_user: foo
# db_pass: bar
# db_max_idle_conns: 5
# db_max_open_conns: 5
# db_conn_max_idle_time_secs

# If other 'sslmode' is required instead of 'require(true)' and 'disabled(false)', set the 'sslmode' you need
# in the 'db_ssl' field. Refers to https://www.postgresql.org/docs/current/libpq-ssl.html Table 34.1.
# db_ssl: false

# Enables debug mode logging of DB operations
# db_debug: false

### TLS configuration
#
## Let's encrypt / ACME
Expand Down
111 changes: 57 additions & 54 deletions hscontrol/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@ import (
"io"
"net"
"net/http"
_ "net/http/pprof" //nolint
_ "net/http/pprof" // nolint
"os"
"os/signal"
"runtime"
"strconv"
"strings"
"sync"
"syscall"
Expand All @@ -22,15 +21,6 @@ import (
"github.com/gorilla/mux"
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpcRuntime "github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"github.com/juanfont/headscale"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/derp"
derpServer "github.com/juanfont/headscale/hscontrol/derp/server"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/patrickmn/go-cache"
zerolog "github.com/philip-bui/grpc-zerolog"
"github.com/prometheus/client_golang/prometheus/promhttp"
Expand All @@ -52,6 +42,16 @@ import (
"tailscale.com/tailcfg"
"tailscale.com/types/dnstype"
"tailscale.com/types/key"

"github.com/juanfont/headscale"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/derp"
derpServer "github.com/juanfont/headscale/hscontrol/derp/server"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
)

var (
Expand All @@ -78,9 +78,6 @@ const (
type Headscale struct {
cfg *types.Config
db *db.HSDatabase
dbString string
dbType string
dbDebug bool
noisePrivateKey *key.MachinePrivate

DERPMap *tailcfg.DERPMap
Expand Down Expand Up @@ -116,59 +113,27 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
return nil, fmt.Errorf("failed to read or create Noise protocol private key: %w", err)
}

var dbString string
switch cfg.DBtype {
case db.Postgres:
dbString = fmt.Sprintf(
"host=%s dbname=%s user=%s",
cfg.DBhost,
cfg.DBname,
cfg.DBuser,
)

if sslEnabled, err := strconv.ParseBool(cfg.DBssl); err == nil {
if !sslEnabled {
dbString += " sslmode=disable"
}
} else {
dbString += fmt.Sprintf(" sslmode=%s", cfg.DBssl)
}

if cfg.DBport != 0 {
dbString += fmt.Sprintf(" port=%d", cfg.DBport)
}

if cfg.DBpass != "" {
dbString += fmt.Sprintf(" password=%s", cfg.DBpass)
}
case db.Sqlite:
dbString = cfg.DBpath
default:
return nil, errUnsupportedDatabase
}

registrationCache := cache.New(
registerCacheExpiration,
registerCacheCleanup,
)

app := Headscale{
cfg: cfg,
dbType: cfg.DBtype,
dbString: dbString,
noisePrivateKey: noisePrivateKey,
registrationCache: registrationCache,
pollNetMapStreamWG: sync.WaitGroup{},
nodeNotifier: notifier.NewNotifier(),
}

database, err := db.NewHeadscaleDatabase(
cfg.DBtype,
dbString,
app.dbDebug,
app.nodeNotifier,
cfg.IPPrefixes,
cfg.BaseDomain)
dbConfig, err := buildHeadscaleDBConfig(cfg)
if err != nil {
return nil, err
}

dbConfig.NodeNotifier = app.nodeNotifier

database, err := db.NewHeadscaleDatabase(dbConfig)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -944,3 +909,41 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {

return &machineKey, nil
}

// buildHeadscaleDBConfig is helper function for creating the
// db config for headscale database.
func buildHeadscaleDBConfig(cfg *types.Config) (db.Config, error) {
dbConfig := db.NewDefaultConfig()

switch cfg.DBtype {
case db.Postgres:
dbConfig.DBType = db.Postgres
dbConfig.DBHost = cfg.DBhost
dbConfig.DBPort = cfg.DBport
dbConfig.DBName = cfg.DBname
dbConfig.DBSsl = cfg.DBssl
case db.Sqlite:
dbConfig.DBType = db.Sqlite
dbConfig.DBPath = cfg.DBpath
default:
return db.Config{}, errUnsupportedDatabase
}

dbConfig.BaseDomain = cfg.BaseDomain
dbConfig.IPPrefixes = cfg.IPPrefixes
dbConfig.DebugMode = cfg.DBdebug

if cfg.DBconnMaxIdleTimeSecs != 0 {
dbConfig.ConnectionMaxIdleTime = time.Duration(cfg.DBconnMaxIdleTimeSecs) * time.Second
}

if cfg.DBmaxOpenConns != 0 {
dbConfig.MaxOpenConnections = cfg.DBmaxOpenConns
}

if cfg.DBmaxIdleConns != 0 {
dbConfig.MaxIdleConnections = cfg.DBmaxIdleConns
}

return dbConfig, nil
}
118 changes: 94 additions & 24 deletions hscontrol/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"net/netip"
"strconv"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -48,16 +49,81 @@ type HSDatabase struct {
baseDomain string
}

// TODO(kradalby): assemble this struct from toptions or something typed
// rather than arguments.
func NewHeadscaleDatabase(
dbType, connectionAddr string,
debug bool,
notifier *notifier.Notifier,
ipPrefixes []netip.Prefix,
baseDomain string,
) (*HSDatabase, error) {
dbConn, err := openDB(dbType, connectionAddr, debug)
// cfg.DBtype,
// dbString,
// app.dbDebug,
// app.nodeNotifier,
// cfg.IPPrefixes,
// cfg.BaseDomain,

type Config struct {
DBType string
DBHost string
DBName string
DBUser string
DBPass string
DBPort int
DBPath string
DBSsl string

// Additional DB Settings
DebugMode bool
MaxOpenConnections int
MaxIdleConnections int
ConnectionMaxIdleTime time.Duration

IPPrefixes []netip.Prefix
BaseDomain string
NodeNotifier *notifier.Notifier
}

func (c Config) GetConnectionString() string {
var connStr string

switch c.DBType {
case Postgres:
connStr = fmt.Sprintf(
"host=%s dbname=%s user=%s",
c.DBHost,
c.DBName,
c.DBUser,
)
if sslMode, err := strconv.ParseBool(c.DBSsl); err == nil {
if !sslMode {
connStr += " sslmode=disable"
}
} else {
connStr += fmt.Sprintf(" sslmode=%s", c.DBSsl)
}

if c.DBPort != 0 {
connStr += fmt.Sprintf(" port=%d", c.DBPort)
}

if c.DBPass != "" {
connStr += fmt.Sprintf(" password=%s", c.DBPass)
}
case Sqlite:
connStr = c.DBPath
}

return connStr
}

// NewDefaultConfig returns default DB config.
func NewDefaultConfig() Config {
return Config{
DebugMode: false,
MaxOpenConnections: 5,
MaxIdleConnections: 5,
ConnectionMaxIdleTime: time.Hour,
}
}

// NewHeadscaleDatabase accepts a db.Config and returns an
// instance of HSDatabase.
func NewHeadscaleDatabase(config Config) (*HSDatabase, error) {
dbConn, err := openDB(config)
if err != nil {
return nil, err
}
Expand All @@ -71,7 +137,7 @@ func NewHeadscaleDatabase(
{
ID: "202312101416",
Migrate: func(tx *gorm.DB) error {
if dbType == Postgres {
if config.DBType == Postgres {
tx.Exec(`create extension if not exists "uuid-ossp";`)
}

Expand Down Expand Up @@ -188,7 +254,8 @@ func NewHeadscaleDatabase(
EnabledRoutes types.IPPrefixes
}

nodesAux := []NodeAux{}
var nodesAux []NodeAux

err := tx.Table("nodes").Select("id, enabled_routes").Scan(&nodesAux).Error
if err != nil {
log.Fatal().Err(err).Msg("Error accessing db")
Expand Down Expand Up @@ -319,26 +386,28 @@ func NewHeadscaleDatabase(

db := HSDatabase{
db: dbConn,
notifier: notifier,
notifier: config.NodeNotifier,

ipPrefixes: ipPrefixes,
baseDomain: baseDomain,
ipPrefixes: config.IPPrefixes,
baseDomain: config.BaseDomain,
}

return &db, err
}

func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) {
log.Debug().Str("type", dbType).Str("connection", connectionAddr).Msg("opening database")
func openDB(config Config) (*gorm.DB, error) {
connectionAddr := config.GetConnectionString()

log.Debug().Str("type", config.DBType).Str("connection", connectionAddr).Msg("opening database")

var dbLogger logger.Interface
if debug {
if config.DebugMode {
dbLogger = logger.Default
} else {
dbLogger = logger.Default.LogMode(logger.Silent)
}

switch dbType {
switch config.DBType {
case Sqlite:
db, err := gorm.Open(
sqlite.Open(connectionAddr+"?_synchronous=1&_journal_mode=WAL"),
Expand All @@ -352,7 +421,8 @@ func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) {

// The pure Go SQLite library does not handle locking in
// the same way as the C based one and we cant use the gorm
// connection pool as of 2022/02/23.
// connection pool as of 2022/02/23. Also, not setting these
// via the values in the config.
sqlDB, _ := db.DB()
sqlDB.SetMaxIdleConns(1)
sqlDB.SetMaxOpenConns(1)
Expand All @@ -367,16 +437,16 @@ func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) {
})

sqlDB, _ := db.DB()
sqlDB.SetMaxOpenConns(10)
sqlDB.SetMaxIdleConns(10)
sqlDB.SetConnMaxIdleTime(time.Hour)
sqlDB.SetMaxOpenConns(config.MaxOpenConnections)
sqlDB.SetMaxIdleConns(config.MaxIdleConnections)
sqlDB.SetConnMaxIdleTime(config.ConnectionMaxIdleTime)

return db, err
}

return nil, fmt.Errorf(
"database of type %s is not supported: %w",
dbType,
config.DBType,
errDatabaseNotSupported,
)
}
Expand Down
Loading

0 comments on commit 447eb2f

Please sign in to comment.