From 4461499de451ddd72c5996fcafe698e84bf3a518 Mon Sep 17 00:00:00 2001 From: kevin Date: Mon, 27 Jan 2025 20:28:02 +0800 Subject: [PATCH] feat: auto validate config --- core/conf/config.go | 13 +++++- core/conf/config_test.go | 77 +++++++++++++++++++++++++++++------- core/conf/validate.go | 12 ++++++ core/conf/validate_test.go | 81 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 167 insertions(+), 16 deletions(-) create mode 100644 core/conf/validate.go create mode 100644 core/conf/validate_test.go diff --git a/core/conf/config.go b/core/conf/config.go index 743f8c6144e1..6428abd7da95 100644 --- a/core/conf/config.go +++ b/core/conf/config.go @@ -62,7 +62,11 @@ func Load(file string, v any, opts ...Option) error { return loader([]byte(os.ExpandEnv(string(content))), v) } - return loader(content, v) + if err = loader(content, v); err != nil { + return err + } + + return validate(v) } // LoadConfig loads config into v from file, .json, .yaml and .yml are acceptable. @@ -85,7 +89,12 @@ func LoadFromJsonBytes(content []byte, v any) error { lowerCaseKeyMap := toLowerCaseKeyMap(m, info) - return mapping.UnmarshalJsonMap(lowerCaseKeyMap, v, mapping.WithCanonicalKeyFunc(toLowerCase)) + if err = mapping.UnmarshalJsonMap(lowerCaseKeyMap, v, + mapping.WithCanonicalKeyFunc(toLowerCase)); err != nil { + return err + } + + return validate(v) } // LoadConfigFromJsonBytes loads config into v from content json bytes. diff --git a/core/conf/config_test.go b/core/conf/config_test.go index e3025fba89eb..11b05294c492 100644 --- a/core/conf/config_test.go +++ b/core/conf/config_test.go @@ -1,6 +1,7 @@ package conf import ( + "errors" "os" "reflect" "testing" @@ -40,9 +41,8 @@ func TestConfigJson(t *testing.T) { for _, test := range tests { test := test t.Run(test, func(t *testing.T) { - tmpfile, err := createTempFile(test, text) + tmpfile, err := createTempFile(t, test, text) assert.Nil(t, err) - defer os.Remove(tmpfile) var val struct { A string `json:"a"` @@ -82,9 +82,8 @@ c = "${FOO}" d = "abcd!@#$112" ` t.Setenv("FOO", "2") - tmpfile, err := createTempFile(".toml", text) + tmpfile, err := createTempFile(t, ".toml", text) assert.Nil(t, err) - defer os.Remove(tmpfile) var val struct { A string `json:"a"` @@ -105,9 +104,8 @@ b = 1 c = "FOO" d = "abcd" ` - tmpfile, err := createTempFile(".toml", text) + tmpfile, err := createTempFile(t, ".toml", text) assert.Nil(t, err) - defer os.Remove(tmpfile) var val struct { A string `json:"a"` @@ -127,9 +125,8 @@ func TestConfigWithLower(t *testing.T) { text := `a = "foo" b = 1 ` - tmpfile, err := createTempFile(".toml", text) + tmpfile, err := createTempFile(t, ".toml", text) assert.Nil(t, err) - defer os.Remove(tmpfile) var val struct { A string `json:"a"` @@ -207,9 +204,8 @@ c = "${FOO}" d = "abcd!@#112" ` t.Setenv("FOO", "2") - tmpfile, err := createTempFile(".toml", text) + tmpfile, err := createTempFile(t, ".toml", text) assert.Nil(t, err) - defer os.Remove(tmpfile) var val struct { A string `json:"a"` @@ -241,9 +237,8 @@ func TestConfigJsonEnv(t *testing.T) { for _, test := range tests { test := test t.Run(test, func(t *testing.T) { - tmpfile, err := createTempFile(test, text) + tmpfile, err := createTempFile(t, test, text) assert.Nil(t, err) - defer os.Remove(tmpfile) var val struct { A string `json:"a"` @@ -1217,11 +1212,44 @@ Name = "bar" }) } +func Test_LoadBadConfig(t *testing.T) { + type Config struct { + Name string `json:"name,options=foo|bar"` + } + + file, err := createTempFile(t, ".json", `{"name": "baz"}`) + assert.NoError(t, err) + + var c Config + err = Load(file, &c) + assert.Error(t, err) +} + func Test_getFullName(t *testing.T) { assert.Equal(t, "a.b", getFullName("a", "b")) assert.Equal(t, "a", getFullName("", "a")) } +func TestValidate(t *testing.T) { + t.Run("normal config", func(t *testing.T) { + var c mockConfig + err := LoadFromJsonBytes([]byte(`{"val": "hello", "number": 8}`), &c) + assert.NoError(t, err) + }) + + t.Run("error no int", func(t *testing.T) { + var c mockConfig + err := LoadFromJsonBytes([]byte(`{"val": "hello"}`), &c) + assert.Error(t, err) + }) + + t.Run("error no string", func(t *testing.T) { + var c mockConfig + err := LoadFromJsonBytes([]byte(`{"number": 8}`), &c) + assert.Error(t, err) + }) +} + func Test_buildFieldsInfo(t *testing.T) { type ParentSt struct { Name string @@ -1311,13 +1339,13 @@ func Test_buildFieldsInfo(t *testing.T) { } } -func createTempFile(ext, text string) (string, error) { +func createTempFile(t *testing.T, ext, text string) (string, error) { tmpFile, err := os.CreateTemp(os.TempDir(), hash.Md5Hex([]byte(text))+"*"+ext) if err != nil { return "", err } - if err := os.WriteFile(tmpFile.Name(), []byte(text), os.ModeTemporary); err != nil { + if err = os.WriteFile(tmpFile.Name(), []byte(text), os.ModeTemporary); err != nil { return "", err } @@ -1326,5 +1354,26 @@ func createTempFile(ext, text string) (string, error) { return "", err } + t.Cleanup(func() { + _ = os.Remove(filename) + }) + return filename, nil } + +type mockConfig struct { + Val string + Number int +} + +func (m mockConfig) Validate() error { + if len(m.Val) == 0 { + return errors.New("val is empty") + } + + if m.Number == 0 { + return errors.New("number is zero") + } + + return nil +} diff --git a/core/conf/validate.go b/core/conf/validate.go new file mode 100644 index 000000000000..2724aa6a2bd9 --- /dev/null +++ b/core/conf/validate.go @@ -0,0 +1,12 @@ +package conf + +import "github.com/zeromicro/go-zero/core/validation" + +// validate validates the value if it implements the Validator interface. +func validate(v any) error { + if val, ok := v.(validation.Validator); ok { + return val.Validate() + } + + return nil +} diff --git a/core/conf/validate_test.go b/core/conf/validate_test.go new file mode 100644 index 000000000000..6a445e6f7782 --- /dev/null +++ b/core/conf/validate_test.go @@ -0,0 +1,81 @@ +package conf + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +type mockType int + +func (m mockType) Validate() error { + if m < 10 { + return errors.New("invalid value") + } + + return nil +} + +type anotherMockType int + +func Test_validate(t *testing.T) { + tests := []struct { + name string + v any + wantErr bool + }{ + { + name: "invalid", + v: mockType(5), + wantErr: true, + }, + { + name: "valid", + v: mockType(10), + wantErr: false, + }, + { + name: "not validator", + v: anotherMockType(5), + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validate(tt.v) + assert.Equal(t, tt.wantErr, err != nil) + }) + } +} + +type mockVal struct { +} + +func (m mockVal) Validate() error { + return errors.New("invalid value") +} + +func Test_validateValPtr(t *testing.T) { + tests := []struct { + name string + v any + wantErr bool + }{ + { + name: "invalid", + v: mockVal{}, + }, + { + name: "invalid value", + v: &mockVal{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Error(t, validate(tt.v)) + }) + } +}