From 551c460c71360c9b458be6351d2e5fd3fcf13181 Mon Sep 17 00:00:00 2001 From: Pallab Pain Date: Wed, 17 Jan 2024 15:23:39 +0530 Subject: [PATCH] refactor(db): makes db options user configurable --- config-example.yaml | 6 ++ hscontrol/app.go | 111 ++++++++++++++++----------------- hscontrol/db/db.go | 118 ++++++++++++++++++++++++++++-------- hscontrol/db/routes_test.go | 26 ++++---- hscontrol/db/suite_test.go | 22 +++---- hscontrol/types/config.go | 45 ++++++++------ 6 files changed, 208 insertions(+), 120 deletions(-) diff --git a/config-example.yaml b/config-example.yaml index 5105dcd8293..dcb0b49fc13 100644 --- a/config-example.yaml +++ b/config-example.yaml @@ -142,11 +142,17 @@ db_path: /var/lib/headscale/db.sqlite # db_name: headscale # db_user: foo # db_pass: bar +# db_max_idle_conns: 5 +# db_max_open_conns: 5 +# db_conn_max_idle_time_secs: 3600 # If other 'sslmode' is required instead of 'require(true)' and 'disabled(false)', set the 'sslmode' you need # in the 'db_ssl' field. Refers to https://www.postgresql.org/docs/current/libpq-ssl.html Table 34.1. # db_ssl: false +# Enables debug mode logging of DB operations +# db_debug: false + ### TLS configuration # ## Let's encrypt / ACME diff --git a/hscontrol/app.go b/hscontrol/app.go index 5327d6f8257..e35b20ecccd 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -8,11 +8,10 @@ import ( "io" "net" "net/http" - _ "net/http/pprof" //nolint + _ "net/http/pprof" // nolint "os" "os/signal" "runtime" - "strconv" "strings" "sync" "syscall" @@ -22,15 +21,6 @@ import ( "github.com/gorilla/mux" grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware" grpcRuntime "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" - "github.com/juanfont/headscale" - v1 "github.com/juanfont/headscale/gen/go/headscale/v1" - "github.com/juanfont/headscale/hscontrol/db" - "github.com/juanfont/headscale/hscontrol/derp" - derpServer "github.com/juanfont/headscale/hscontrol/derp/server" - "github.com/juanfont/headscale/hscontrol/notifier" - "github.com/juanfont/headscale/hscontrol/policy" - "github.com/juanfont/headscale/hscontrol/types" - "github.com/juanfont/headscale/hscontrol/util" "github.com/patrickmn/go-cache" zerolog "github.com/philip-bui/grpc-zerolog" "github.com/prometheus/client_golang/prometheus/promhttp" @@ -52,6 +42,16 @@ import ( "tailscale.com/tailcfg" "tailscale.com/types/dnstype" "tailscale.com/types/key" + + "github.com/juanfont/headscale" + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/db" + "github.com/juanfont/headscale/hscontrol/derp" + derpServer "github.com/juanfont/headscale/hscontrol/derp/server" + "github.com/juanfont/headscale/hscontrol/notifier" + "github.com/juanfont/headscale/hscontrol/policy" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" ) var ( @@ -78,9 +78,6 @@ const ( type Headscale struct { cfg *types.Config db *db.HSDatabase - dbString string - dbType string - dbDebug bool noisePrivateKey *key.MachinePrivate DERPMap *tailcfg.DERPMap @@ -116,37 +113,6 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { return nil, fmt.Errorf("failed to read or create Noise protocol private key: %w", err) } - var dbString string - switch cfg.DBtype { - case db.Postgres: - dbString = fmt.Sprintf( - "host=%s dbname=%s user=%s", - cfg.DBhost, - cfg.DBname, - cfg.DBuser, - ) - - if sslEnabled, err := strconv.ParseBool(cfg.DBssl); err == nil { - if !sslEnabled { - dbString += " sslmode=disable" - } - } else { - dbString += fmt.Sprintf(" sslmode=%s", cfg.DBssl) - } - - if cfg.DBport != 0 { - dbString += fmt.Sprintf(" port=%d", cfg.DBport) - } - - if cfg.DBpass != "" { - dbString += fmt.Sprintf(" password=%s", cfg.DBpass) - } - case db.Sqlite: - dbString = cfg.DBpath - default: - return nil, errUnsupportedDatabase - } - registrationCache := cache.New( registerCacheExpiration, registerCacheCleanup, @@ -154,21 +120,20 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { app := Headscale{ cfg: cfg, - dbType: cfg.DBtype, - dbString: dbString, noisePrivateKey: noisePrivateKey, registrationCache: registrationCache, pollNetMapStreamWG: sync.WaitGroup{}, nodeNotifier: notifier.NewNotifier(), } - database, err := db.NewHeadscaleDatabase( - cfg.DBtype, - dbString, - app.dbDebug, - app.nodeNotifier, - cfg.IPPrefixes, - cfg.BaseDomain) + dbConfig, err := buildHeadscaleDBConfig(cfg) + if err != nil { + return nil, err + } + + dbConfig.NodeNotifier = app.nodeNotifier + + database, err := db.NewHeadscaleDatabase(dbConfig) if err != nil { return nil, err } @@ -944,3 +909,41 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) { return &machineKey, nil } + +// buildHeadscaleDBConfig is helper function for creating the +// db config for headscale database. +func buildHeadscaleDBConfig(cfg *types.Config) (db.Config, error) { + dbConfig := db.NewDefaultConfig() + + switch cfg.DBtype { + case db.Postgres: + dbConfig.DBType = db.Postgres + dbConfig.DBHost = cfg.DBhost + dbConfig.DBPort = cfg.DBport + dbConfig.DBName = cfg.DBname + dbConfig.DBSsl = cfg.DBssl + case db.Sqlite: + dbConfig.DBType = db.Sqlite + dbConfig.DBPath = cfg.DBpath + default: + return db.Config{}, errUnsupportedDatabase + } + + dbConfig.BaseDomain = cfg.BaseDomain + dbConfig.IPPrefixes = cfg.IPPrefixes + dbConfig.DebugMode = cfg.DBdebug + + if cfg.DBconnMaxIdleTimeSecs != 0 { + dbConfig.ConnectionMaxIdleTime = time.Duration(cfg.DBconnMaxIdleTimeSecs) * time.Second + } + + if cfg.DBmaxOpenConns != 0 { + dbConfig.MaxOpenConnections = cfg.DBmaxOpenConns + } + + if cfg.DBmaxIdleConns != 0 { + dbConfig.MaxIdleConnections = cfg.DBmaxIdleConns + } + + return dbConfig, nil +} diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index 828484d5188..7c5f1f3a068 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "net/netip" + "strconv" "strings" "sync" "time" @@ -48,16 +49,81 @@ type HSDatabase struct { baseDomain string } -// TODO(kradalby): assemble this struct from toptions or something typed -// rather than arguments. -func NewHeadscaleDatabase( - dbType, connectionAddr string, - debug bool, - notifier *notifier.Notifier, - ipPrefixes []netip.Prefix, - baseDomain string, -) (*HSDatabase, error) { - dbConn, err := openDB(dbType, connectionAddr, debug) +// cfg.DBtype, +// dbString, +// app.dbDebug, +// app.nodeNotifier, +// cfg.IPPrefixes, +// cfg.BaseDomain, + +type Config struct { + DBType string + DBHost string + DBName string + DBUser string + DBPass string + DBPort int + DBPath string + DBSsl string + + // Additional DB Settings + DebugMode bool + MaxOpenConnections int + MaxIdleConnections int + ConnectionMaxIdleTime time.Duration + + IPPrefixes []netip.Prefix + BaseDomain string + NodeNotifier *notifier.Notifier +} + +func (c Config) GetConnectionString() string { + var connStr string + + switch c.DBType { + case Postgres: + connStr = fmt.Sprintf( + "host=%s dbname=%s user=%s", + c.DBHost, + c.DBName, + c.DBUser, + ) + if sslMode, err := strconv.ParseBool(c.DBSsl); err == nil { + if !sslMode { + connStr += " sslmode=disable" + } + } else { + connStr += fmt.Sprintf(" sslmode=%s", c.DBSsl) + } + + if c.DBPort != 0 { + connStr += fmt.Sprintf(" port=%d", c.DBPort) + } + + if c.DBPass != "" { + connStr += fmt.Sprintf(" password=%s", c.DBPass) + } + case Sqlite: + connStr = c.DBPath + } + + return connStr +} + +// NewDefaultConfig returns default DB config. +func NewDefaultConfig() Config { + return Config{ + DebugMode: false, + MaxOpenConnections: 5, + MaxIdleConnections: 5, + ConnectionMaxIdleTime: time.Hour, + } +} + +// NewHeadscaleDatabase accepts a db.Config and returns an +// instance of HSDatabase. +func NewHeadscaleDatabase(config Config) (*HSDatabase, error) { + dbConn, err := openDB(config) if err != nil { return nil, err } @@ -71,7 +137,7 @@ func NewHeadscaleDatabase( { ID: "202312101416", Migrate: func(tx *gorm.DB) error { - if dbType == Postgres { + if config.DBType == Postgres { tx.Exec(`create extension if not exists "uuid-ossp";`) } @@ -188,7 +254,8 @@ func NewHeadscaleDatabase( EnabledRoutes types.IPPrefixes } - nodesAux := []NodeAux{} + var nodesAux []NodeAux + err := tx.Table("nodes").Select("id, enabled_routes").Scan(&nodesAux).Error if err != nil { log.Fatal().Err(err).Msg("Error accessing db") @@ -319,26 +386,28 @@ func NewHeadscaleDatabase( db := HSDatabase{ db: dbConn, - notifier: notifier, + notifier: config.NodeNotifier, - ipPrefixes: ipPrefixes, - baseDomain: baseDomain, + ipPrefixes: config.IPPrefixes, + baseDomain: config.BaseDomain, } return &db, err } -func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) { - log.Debug().Str("type", dbType).Str("connection", connectionAddr).Msg("opening database") +func openDB(config Config) (*gorm.DB, error) { + connectionAddr := config.GetConnectionString() + + log.Debug().Str("type", config.DBType).Str("connection", connectionAddr).Msg("opening database") var dbLogger logger.Interface - if debug { + if config.DebugMode { dbLogger = logger.Default } else { dbLogger = logger.Default.LogMode(logger.Silent) } - switch dbType { + switch config.DBType { case Sqlite: db, err := gorm.Open( sqlite.Open(connectionAddr+"?_synchronous=1&_journal_mode=WAL"), @@ -352,7 +421,8 @@ func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) { // The pure Go SQLite library does not handle locking in // the same way as the C based one and we cant use the gorm - // connection pool as of 2022/02/23. + // connection pool as of 2022/02/23. Also, not setting these + // via the values in the config. sqlDB, _ := db.DB() sqlDB.SetMaxIdleConns(1) sqlDB.SetMaxOpenConns(1) @@ -367,16 +437,16 @@ func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) { }) sqlDB, _ := db.DB() - sqlDB.SetMaxOpenConns(10) - sqlDB.SetMaxIdleConns(10) - sqlDB.SetConnMaxIdleTime(time.Hour) + sqlDB.SetMaxOpenConns(config.MaxOpenConnections) + sqlDB.SetMaxIdleConns(config.MaxIdleConnections) + sqlDB.SetConnMaxIdleTime(config.ConnectionMaxIdleTime) return db, err } return nil, fmt.Errorf( "database of type %s is not supported: %w", - dbType, + config.DBType, errDatabaseNotSupported, ) } diff --git a/hscontrol/db/routes_test.go b/hscontrol/db/routes_test.go index d491b6a306f..7cbc04ab5c0 100644 --- a/hscontrol/db/routes_test.go +++ b/hscontrol/db/routes_test.go @@ -7,14 +7,15 @@ import ( "time" "github.com/google/go-cmp/cmp" - "github.com/juanfont/headscale/hscontrol/notifier" - "github.com/juanfont/headscale/hscontrol/types" - "github.com/juanfont/headscale/hscontrol/util" "github.com/stretchr/testify/assert" "gopkg.in/check.v1" "gorm.io/gorm" "tailscale.com/tailcfg" "tailscale.com/types/key" + + "github.com/juanfont/headscale/hscontrol/notifier" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" ) func (s *Suite) TestGetRoutes(c *check.C) { @@ -585,16 +586,15 @@ func TestFailoverRoute(t *testing.T) { notif := notifier.NewNotifier() - db, err = NewHeadscaleDatabase( - "sqlite3", - tmpDir+"/headscale_test.db", - false, - notif, - []netip.Prefix{ - netip.MustParsePrefix("10.27.0.0/23"), - }, - "", - ) + config := NewDefaultConfig() + config.NodeNotifier = notif + config.DBType = Sqlite + config.DBPath = tmpDir + "/headscale_test.db" + config.IPPrefixes = []netip.Prefix{ + netip.MustParsePrefix("10.27.0.0/23"), + } + + db, err = NewHeadscaleDatabase(config) assert.NoError(t, err) // Pretend that all the nodes are connected to control diff --git a/hscontrol/db/suite_test.go b/hscontrol/db/suite_test.go index 1c384918f7d..87a737d37ed 100644 --- a/hscontrol/db/suite_test.go +++ b/hscontrol/db/suite_test.go @@ -6,8 +6,9 @@ import ( "os" "testing" - "github.com/juanfont/headscale/hscontrol/notifier" "gopkg.in/check.v1" + + "github.com/juanfont/headscale/hscontrol/notifier" ) func Test(t *testing.T) { @@ -44,16 +45,15 @@ func (s *Suite) ResetDB(c *check.C) { log.Printf("database path: %s", tmpDir+"/headscale_test.db") - db, err = NewHeadscaleDatabase( - "sqlite3", - tmpDir+"/headscale_test.db", - false, - notifier.NewNotifier(), - []netip.Prefix{ - netip.MustParsePrefix("10.27.0.0/23"), - }, - "", - ) + config := NewDefaultConfig() + config.DBType = Sqlite + config.DBPath = tmpDir + "/headscale_test.db" + config.NodeNotifier = notifier.NewNotifier() + config.IPPrefixes = []netip.Prefix{ + netip.MustParsePrefix("10.27.0.0/23"), + } + + db, err = NewHeadscaleDatabase(config) if err != nil { c.Fatal(err) } diff --git a/hscontrol/types/config.go b/hscontrol/types/config.go index 4b29c4b7007..6d348e4d8ae 100644 --- a/hscontrol/types/config.go +++ b/hscontrol/types/config.go @@ -11,7 +11,6 @@ import ( "time" "github.com/coreos/go-oidc/v3/oidc" - "github.com/juanfont/headscale/hscontrol/util" "github.com/prometheus/common/model" "github.com/rs/zerolog" "github.com/rs/zerolog/log" @@ -20,6 +19,8 @@ import ( "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" "tailscale.com/types/dnstype" + + "github.com/juanfont/headscale/hscontrol/util" ) const ( @@ -48,14 +49,18 @@ type Config struct { DERP DERPConfig - DBtype string - DBpath string - DBhost string - DBport int - DBname string - DBuser string - DBpass string - DBssl string + DBtype string + DBpath string + DBhost string + DBport int + DBname string + DBuser string + DBpass string + DBssl string + DBdebug bool + DBmaxIdleConns int + DBmaxOpenConns int + DBconnMaxIdleTimeSecs int TLS TLSConfig @@ -254,7 +259,7 @@ func LoadConfig(path string, isFile bool) error { } if errorText != "" { - //nolint + // nolint return errors.New(strings.TrimSuffix(errorText, "\n")) } else { return nil @@ -599,14 +604,18 @@ func GetHeadscaleConfig() (*Config, error) { "node_update_check_interval", ), - DBtype: viper.GetString("db_type"), - DBpath: util.AbsolutePathFromConfigPath(viper.GetString("db_path")), - DBhost: viper.GetString("db_host"), - DBport: viper.GetInt("db_port"), - DBname: viper.GetString("db_name"), - DBuser: viper.GetString("db_user"), - DBpass: viper.GetString("db_pass"), - DBssl: viper.GetString("db_ssl"), + DBtype: viper.GetString("db_type"), + DBpath: util.AbsolutePathFromConfigPath(viper.GetString("db_path")), + DBhost: viper.GetString("db_host"), + DBport: viper.GetInt("db_port"), + DBname: viper.GetString("db_name"), + DBuser: viper.GetString("db_user"), + DBpass: viper.GetString("db_pass"), + DBssl: viper.GetString("db_ssl"), + DBdebug: viper.GetBool("db_debug"), + DBmaxIdleConns: viper.GetInt("db_max_idle_conns"), + DBmaxOpenConns: viper.GetInt("db_max_open_conns"), + DBconnMaxIdleTimeSecs: viper.GetInt("db_conn_max_idle_time_secs"), TLS: GetTLSConfig(),