Skip to content

Commit

Permalink
DB: Add env to skip DB creationˆ
Browse files Browse the repository at this point in the history
  • Loading branch information
lkaybob committed Dec 3, 2023
1 parent f4c8861 commit 817318a
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 10 deletions.
2 changes: 2 additions & 0 deletions pkg/db/v1beta1/common/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,6 @@ const (
DefaultPostgreSQLDatabase = "katib"
DefaultPostgreSQLHost = "katib-postgres"
DefaultPostgreSQLPort = "5432"

SkipDbMigrationEnvName = "SKIP_DB_MIGRATION"
)
22 changes: 17 additions & 5 deletions pkg/db/v1beta1/mysql/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,34 @@ package mysql

import (
"fmt"

"github.com/kubeflow/katib/pkg/db/v1beta1/common"
"github.com/kubeflow/katib/pkg/util/v1beta1/env"
"k8s.io/klog"
)

func (d *dbConn) DBInit() {
db := d.db
klog.Info("Initializing v1beta1 DB schema")
skipDbMigration := env.GetBoolEnvOrDefault(common.SkipDbMigrationEnvName, false)

if !skipDbMigration {
klog.Info("Initializing v1beta1 DB schema")

_, err := db.Exec(`CREATE TABLE IF NOT EXISTS observation_logs
_, err := db.Exec(`CREATE TABLE IF NOT EXISTS observation_logs
(trial_name VARCHAR(255) NOT NULL,
id INT AUTO_INCREMENT PRIMARY KEY,
time DATETIME(6),
metric_name VARCHAR(255) NOT NULL,
value TEXT NOT NULL)`)
if err != nil {
klog.Fatalf("Error creating observation_logs table: %v", err)
if err != nil {
klog.Fatalf("Error creating observation_logs table: %v", err)
}
} else {
klog.Info("Skipping v1beta1 DB schema initialization.")

_, err := db.Query(`SELECT trial_name, id, time, metric_name, value FROM observation_logs LIMIT 1`)
if err != nil {
klog.Fatalf("Error validating observation_logs table: %v", err)
}
}
}

Expand Down
21 changes: 17 additions & 4 deletions pkg/db/v1beta1/postgres/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,35 @@ package postgres

import (
"fmt"
"github.com/kubeflow/katib/pkg/db/v1beta1/common"
"github.com/kubeflow/katib/pkg/util/v1beta1/env"

"k8s.io/klog"
)

func (d *dbConn) DBInit() {
db := d.db
klog.Info("Initializing v1beta1 DB schema")
skipDbMigration := env.GetBoolEnvOrDefault(common.SkipDbMigrationEnvName, false)

_, err := db.Exec(`CREATE TABLE IF NOT EXISTS observation_logs
if !skipDbMigration {
klog.Info("Initializing v1beta1 DB schema")

_, err := db.Exec(`CREATE TABLE IF NOT EXISTS observation_logs
(trial_name VARCHAR(255) NOT NULL,
id serial PRIMARY KEY,
time TIMESTAMP(6),
metric_name VARCHAR(255) NOT NULL,
value TEXT NOT NULL)`)
if err != nil {
klog.Fatalf("Error creating observation_logs table: %v", err)
if err != nil {
klog.Fatalf("Error creating observation_logs table: %v", err)
}
} else {
klog.Info("Skipping v1beta1 DB schema initialization.")

_, err := db.Query(`SELECT trial_name, id, time, metric_name, value FROM observation_logs LIMIT 1`)
if err != nil {
klog.Fatalf("Error validating observation_logs table: %v", err)
}
}
}

Expand Down
17 changes: 16 additions & 1 deletion pkg/util/v1beta1/env/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,26 @@ limitations under the License.

package env

import "os"
import (
"k8s.io/klog"
"os"
"strconv"
)

func GetEnvOrDefault(key string, fallback string) string {
if value, ok := os.LookupEnv(key); ok {
return value
}
return fallback
}

func GetBoolEnvOrDefault(key string, fallback bool) bool {
if value, ok := os.LookupEnv(key); ok {
parsedValue, err := strconv.ParseBool(value)
if err != nil {
klog.Fatalf("Failed converting %s env to bool", key)
}
return parsedValue
}
return fallback
}
18 changes: 18 additions & 0 deletions pkg/util/v1beta1/env/env_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package env

import (
"fmt"
"os"
"testing"
)
Expand All @@ -35,3 +36,20 @@ func TestGetEnvWithDefault(t *testing.T) {
t.Errorf("Expected %s, got %s", expected, v)
}
}

func TestGetBoolEnvWithDefault(t *testing.T) {
expected := false
key := "TEST"
v := GetBoolEnvOrDefault(key, expected)
if v != expected {
t.Errorf("Expected %t, got %t", expected, v)
}

expected = true
envValue := fmt.Sprintf("%t", expected)
os.Setenv(key, envValue)
v = GetBoolEnvOrDefault(key, false)
if v != expected {
t.Errorf("Expected %t, got %t", expected, v)
}
}

0 comments on commit 817318a

Please sign in to comment.