diff --git a/hscontrol/app.go b/hscontrol/app.go index 5327d6f8257..e35b20ecccd 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -8,11 +8,10 @@ import ( "io" "net" "net/http" - _ "net/http/pprof" //nolint + _ "net/http/pprof" // nolint "os" "os/signal" "runtime" - "strconv" "strings" "sync" "syscall" @@ -22,15 +21,6 @@ import ( "github.com/gorilla/mux" grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware" grpcRuntime "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" - "github.com/juanfont/headscale" - v1 "github.com/juanfont/headscale/gen/go/headscale/v1" - "github.com/juanfont/headscale/hscontrol/db" - "github.com/juanfont/headscale/hscontrol/derp" - derpServer "github.com/juanfont/headscale/hscontrol/derp/server" - "github.com/juanfont/headscale/hscontrol/notifier" - "github.com/juanfont/headscale/hscontrol/policy" - "github.com/juanfont/headscale/hscontrol/types" - "github.com/juanfont/headscale/hscontrol/util" "github.com/patrickmn/go-cache" zerolog "github.com/philip-bui/grpc-zerolog" "github.com/prometheus/client_golang/prometheus/promhttp" @@ -52,6 +42,16 @@ import ( "tailscale.com/tailcfg" "tailscale.com/types/dnstype" "tailscale.com/types/key" + + "github.com/juanfont/headscale" + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/db" + "github.com/juanfont/headscale/hscontrol/derp" + derpServer "github.com/juanfont/headscale/hscontrol/derp/server" + "github.com/juanfont/headscale/hscontrol/notifier" + "github.com/juanfont/headscale/hscontrol/policy" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" ) var ( @@ -78,9 +78,6 @@ const ( type Headscale struct { cfg *types.Config db *db.HSDatabase - dbString string - dbType string - dbDebug bool noisePrivateKey *key.MachinePrivate DERPMap *tailcfg.DERPMap @@ -116,37 +113,6 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { return nil, fmt.Errorf("failed to read or create Noise protocol private key: %w", err) } - var dbString string - switch cfg.DBtype { - case db.Postgres: - dbString = fmt.Sprintf( - "host=%s dbname=%s user=%s", - cfg.DBhost, - cfg.DBname, - cfg.DBuser, - ) - - if sslEnabled, err := strconv.ParseBool(cfg.DBssl); err == nil { - if !sslEnabled { - dbString += " sslmode=disable" - } - } else { - dbString += fmt.Sprintf(" sslmode=%s", cfg.DBssl) - } - - if cfg.DBport != 0 { - dbString += fmt.Sprintf(" port=%d", cfg.DBport) - } - - if cfg.DBpass != "" { - dbString += fmt.Sprintf(" password=%s", cfg.DBpass) - } - case db.Sqlite: - dbString = cfg.DBpath - default: - return nil, errUnsupportedDatabase - } - registrationCache := cache.New( registerCacheExpiration, registerCacheCleanup, @@ -154,21 +120,20 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { app := Headscale{ cfg: cfg, - dbType: cfg.DBtype, - dbString: dbString, noisePrivateKey: noisePrivateKey, registrationCache: registrationCache, pollNetMapStreamWG: sync.WaitGroup{}, nodeNotifier: notifier.NewNotifier(), } - database, err := db.NewHeadscaleDatabase( - cfg.DBtype, - dbString, - app.dbDebug, - app.nodeNotifier, - cfg.IPPrefixes, - cfg.BaseDomain) + dbConfig, err := buildHeadscaleDBConfig(cfg) + if err != nil { + return nil, err + } + + dbConfig.NodeNotifier = app.nodeNotifier + + database, err := db.NewHeadscaleDatabase(dbConfig) if err != nil { return nil, err } @@ -944,3 +909,41 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) { return &machineKey, nil } + +// buildHeadscaleDBConfig is helper function for creating the +// db config for headscale database. +func buildHeadscaleDBConfig(cfg *types.Config) (db.Config, error) { + dbConfig := db.NewDefaultConfig() + + switch cfg.DBtype { + case db.Postgres: + dbConfig.DBType = db.Postgres + dbConfig.DBHost = cfg.DBhost + dbConfig.DBPort = cfg.DBport + dbConfig.DBName = cfg.DBname + dbConfig.DBSsl = cfg.DBssl + case db.Sqlite: + dbConfig.DBType = db.Sqlite + dbConfig.DBPath = cfg.DBpath + default: + return db.Config{}, errUnsupportedDatabase + } + + dbConfig.BaseDomain = cfg.BaseDomain + dbConfig.IPPrefixes = cfg.IPPrefixes + dbConfig.DebugMode = cfg.DBdebug + + if cfg.DBconnMaxIdleTimeSecs != 0 { + dbConfig.ConnectionMaxIdleTime = time.Duration(cfg.DBconnMaxIdleTimeSecs) * time.Second + } + + if cfg.DBmaxOpenConns != 0 { + dbConfig.MaxOpenConnections = cfg.DBmaxOpenConns + } + + if cfg.DBmaxIdleConns != 0 { + dbConfig.MaxIdleConnections = cfg.DBmaxIdleConns + } + + return dbConfig, nil +} diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index 828484d5188..dc05e101389 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "net/netip" + "strconv" "strings" "sync" "time" @@ -48,16 +49,81 @@ type HSDatabase struct { baseDomain string } -// TODO(kradalby): assemble this struct from toptions or something typed -// rather than arguments. -func NewHeadscaleDatabase( - dbType, connectionAddr string, - debug bool, - notifier *notifier.Notifier, - ipPrefixes []netip.Prefix, - baseDomain string, -) (*HSDatabase, error) { - dbConn, err := openDB(dbType, connectionAddr, debug) +// cfg.DBtype, +// dbString, +// app.dbDebug, +// app.nodeNotifier, +// cfg.IPPrefixes, +// cfg.BaseDomain, + +type Config struct { + DBType string + DBHost string + DBName string + DBUser string + DBPass string + DBPort int + DBPath string + DBSsl string + + // Additional DB Settings + DebugMode bool + MaxOpenConnections int + MaxIdleConnections int + ConnectionMaxIdleTime time.Duration + + IPPrefixes []netip.Prefix + BaseDomain string + NodeNotifier *notifier.Notifier +} + +func (c Config) GetConnectionString() string { + var connStr string + + switch c.DBType { + case Postgres: + connStr = fmt.Sprintf( + "host=%s dbname=%s user=%s", + c.DBHost, + c.DBName, + c.DBUser, + ) + if sslMode, err := strconv.ParseBool(c.DBSsl); err == nil { + if !sslMode { + connStr += " sslmode=disable" + } + } else { + connStr += fmt.Sprintf(" sslmode=%s", c.DBSsl) + } + + if c.DBPort != 0 { + connStr += fmt.Sprintf(" port=%d", c.DBPort) + } + + if c.DBPass != "" { + connStr += fmt.Sprintf(" password=%s", c.DBPass) + } + case Sqlite: + connStr = c.DBPath + } + + return connStr +} + +// NewDefaultConfig returns default DB config. +func NewDefaultConfig() Config { + return Config{ + DebugMode: false, + MaxOpenConnections: 1, + MaxIdleConnections: 1, + ConnectionMaxIdleTime: time.Hour, + } +} + +// NewHeadscaleDatabase accepts a db.Config and returns an +// instance of HSDatabase. +func NewHeadscaleDatabase(config Config) (*HSDatabase, error) { + dbConn, err := openDB(config) if err != nil { return nil, err } @@ -71,7 +137,7 @@ func NewHeadscaleDatabase( { ID: "202312101416", Migrate: func(tx *gorm.DB) error { - if dbType == Postgres { + if config.DBType == Postgres { tx.Exec(`create extension if not exists "uuid-ossp";`) } @@ -188,7 +254,8 @@ func NewHeadscaleDatabase( EnabledRoutes types.IPPrefixes } - nodesAux := []NodeAux{} + var nodesAux []NodeAux + err := tx.Table("nodes").Select("id, enabled_routes").Scan(&nodesAux).Error if err != nil { log.Fatal().Err(err).Msg("Error accessing db") @@ -319,26 +386,28 @@ func NewHeadscaleDatabase( db := HSDatabase{ db: dbConn, - notifier: notifier, + notifier: config.NodeNotifier, - ipPrefixes: ipPrefixes, - baseDomain: baseDomain, + ipPrefixes: config.IPPrefixes, + baseDomain: config.BaseDomain, } return &db, err } -func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) { - log.Debug().Str("type", dbType).Str("connection", connectionAddr).Msg("opening database") +func openDB(config Config) (*gorm.DB, error) { + connectionAddr := config.GetConnectionString() + + log.Debug().Str("type", config.DBType).Str("connection", connectionAddr).Msg("opening database") var dbLogger logger.Interface - if debug { + if config.DebugMode { dbLogger = logger.Default } else { dbLogger = logger.Default.LogMode(logger.Silent) } - switch dbType { + switch config.DBType { case Sqlite: db, err := gorm.Open( sqlite.Open(connectionAddr+"?_synchronous=1&_journal_mode=WAL"), @@ -367,16 +436,16 @@ func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) { }) sqlDB, _ := db.DB() - sqlDB.SetMaxOpenConns(10) - sqlDB.SetMaxIdleConns(10) - sqlDB.SetConnMaxIdleTime(time.Hour) + sqlDB.SetMaxOpenConns(config.MaxOpenConnections) + sqlDB.SetMaxIdleConns(config.MaxIdleConnections) + sqlDB.SetConnMaxIdleTime(config.ConnectionMaxIdleTime) return db, err } return nil, fmt.Errorf( "database of type %s is not supported: %w", - dbType, + config.DBType, errDatabaseNotSupported, ) } diff --git a/hscontrol/db/routes_test.go b/hscontrol/db/routes_test.go index d491b6a306f..7cbc04ab5c0 100644 --- a/hscontrol/db/routes_test.go +++ b/hscontrol/db/routes_test.go @@ -7,14 +7,15 @@ import ( "time" "github.com/google/go-cmp/cmp" - "github.com/juanfont/headscale/hscontrol/notifier" - "github.com/juanfont/headscale/hscontrol/types" - "github.com/juanfont/headscale/hscontrol/util" "github.com/stretchr/testify/assert" "gopkg.in/check.v1" "gorm.io/gorm" "tailscale.com/tailcfg" "tailscale.com/types/key" + + "github.com/juanfont/headscale/hscontrol/notifier" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" ) func (s *Suite) TestGetRoutes(c *check.C) { @@ -585,16 +586,15 @@ func TestFailoverRoute(t *testing.T) { notif := notifier.NewNotifier() - db, err = NewHeadscaleDatabase( - "sqlite3", - tmpDir+"/headscale_test.db", - false, - notif, - []netip.Prefix{ - netip.MustParsePrefix("10.27.0.0/23"), - }, - "", - ) + config := NewDefaultConfig() + config.NodeNotifier = notif + config.DBType = Sqlite + config.DBPath = tmpDir + "/headscale_test.db" + config.IPPrefixes = []netip.Prefix{ + netip.MustParsePrefix("10.27.0.0/23"), + } + + db, err = NewHeadscaleDatabase(config) assert.NoError(t, err) // Pretend that all the nodes are connected to control diff --git a/hscontrol/db/suite_test.go b/hscontrol/db/suite_test.go index 1c384918f7d..87a737d37ed 100644 --- a/hscontrol/db/suite_test.go +++ b/hscontrol/db/suite_test.go @@ -6,8 +6,9 @@ import ( "os" "testing" - "github.com/juanfont/headscale/hscontrol/notifier" "gopkg.in/check.v1" + + "github.com/juanfont/headscale/hscontrol/notifier" ) func Test(t *testing.T) { @@ -44,16 +45,15 @@ func (s *Suite) ResetDB(c *check.C) { log.Printf("database path: %s", tmpDir+"/headscale_test.db") - db, err = NewHeadscaleDatabase( - "sqlite3", - tmpDir+"/headscale_test.db", - false, - notifier.NewNotifier(), - []netip.Prefix{ - netip.MustParsePrefix("10.27.0.0/23"), - }, - "", - ) + config := NewDefaultConfig() + config.DBType = Sqlite + config.DBPath = tmpDir + "/headscale_test.db" + config.NodeNotifier = notifier.NewNotifier() + config.IPPrefixes = []netip.Prefix{ + netip.MustParsePrefix("10.27.0.0/23"), + } + + db, err = NewHeadscaleDatabase(config) if err != nil { c.Fatal(err) } diff --git a/hscontrol/types/config.go b/hscontrol/types/config.go index 4b29c4b7007..482fe48855e 100644 --- a/hscontrol/types/config.go +++ b/hscontrol/types/config.go @@ -11,7 +11,6 @@ import ( "time" "github.com/coreos/go-oidc/v3/oidc" - "github.com/juanfont/headscale/hscontrol/util" "github.com/prometheus/common/model" "github.com/rs/zerolog" "github.com/rs/zerolog/log" @@ -20,6 +19,8 @@ import ( "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" "tailscale.com/types/dnstype" + + "github.com/juanfont/headscale/hscontrol/util" ) const ( @@ -48,14 +49,18 @@ type Config struct { DERP DERPConfig - DBtype string - DBpath string - DBhost string - DBport int - DBname string - DBuser string - DBpass string - DBssl string + DBtype string + DBpath string + DBhost string + DBport int + DBname string + DBuser string + DBpass string + DBssl string + DBdebug bool + DBmaxIdleConns int + DBmaxOpenConns int + DBconnMaxIdleTimeSecs int TLS TLSConfig @@ -254,7 +259,7 @@ func LoadConfig(path string, isFile bool) error { } if errorText != "" { - //nolint + // nolint return errors.New(strings.TrimSuffix(errorText, "\n")) } else { return nil