From d0abfb0e5374c809ae6f62b10fb4ee5d229ebe96 Mon Sep 17 00:00:00 2001 From: Patrik Date: Tue, 25 Oct 2022 14:37:10 +0200 Subject: [PATCH 1/3] feat: add sqlfields package --- .github/conventional_commits.json | 1 + go.sum | 11 --- sqlfields/base_types_test.go | 102 ++++++++++++++++++++++ sqlfields/builtin.go | 140 ++++++++++++++++++++++++++++++ sqlfields/json_raw_message.go | 38 ++++++++ sqlfields/nullable.go | 99 +++++++++++++++++++++ sqlfields/nullable_test.go | 95 ++++++++++++++++++++ sqlfields/stringslice.go | 69 +++++++++++++++ sqlfields/time.go | 103 ++++++++++++++++++++++ sqlxx/types.go | 12 ++- 10 files changed, 657 insertions(+), 13 deletions(-) create mode 100644 sqlfields/base_types_test.go create mode 100644 sqlfields/builtin.go create mode 100644 sqlfields/json_raw_message.go create mode 100644 sqlfields/nullable.go create mode 100644 sqlfields/nullable_test.go create mode 100644 sqlfields/stringslice.go create mode 100644 sqlfields/time.go diff --git a/.github/conventional_commits.json b/.github/conventional_commits.json index dfa16f85..3b6117e5 100644 --- a/.github/conventional_commits.json +++ b/.github/conventional_commits.json @@ -53,6 +53,7 @@ "sjsonx", "snapshotx", "sqlcon", + "sqlfields", "sqlxx", "stringslice", "stringsx", diff --git a/go.sum b/go.sum index cc22d83a..4032f857 100644 --- a/go.sum +++ b/go.sum @@ -190,7 +190,6 @@ github.com/bugsnag/panicwrap v0.0.0-20151223152923-e2c28503fcd0/go.mod h1:D/8v3k github.com/cenkalti/backoff/v4 v4.1.3 h1:cFAlzYUlVYDysBEH2T5hyJZMh3+5+WCBvSnK6Q8UtC4= github.com/cenkalti/backoff/v4 v4.1.3/go.mod h1:scbssz8iZGpm3xbr14ovlUdkxfGXNInqkPWOWmG2CLw= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= -github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko= github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= @@ -364,7 +363,6 @@ github.com/docker/distribution v2.7.1-0.20190205005809-0d3efadf0154+incompatible github.com/docker/distribution v2.7.1+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w= github.com/docker/distribution v2.8.1+incompatible h1:Q50tZOPR6T/hjNsyc9g8/syEs6bk8XXApsHjKukMl68= github.com/docker/distribution v2.8.1+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w= -github.com/docker/docker v20.10.7+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= github.com/docker/docker v20.10.9+incompatible h1:JlsVnETOjM2RLQa0Cc1XCIspUdXW3Zenq9P54uXBm6k= github.com/docker/docker v20.10.9+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= github.com/docker/go-connections v0.4.0 h1:El9xVISelRB7BuFusrZozjnkIM5YnzCViNKohAFqRJQ= @@ -697,7 +695,6 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= -github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-jsonnet v0.17.0 h1:/9NIEfhK1NQRKl3sP2536b2+x5HnZMdql7x3yK/l8JY= github.com/google/go-jsonnet v0.17.0/go.mod h1:sOcuej3UW1vpPTZOr8L7RQimqai1a57bt5j22LzGZCw= github.com/google/gofuzz v0.0.0-20161122191042-44d81051d367/go.mod h1:HP5RmnzzSNb993RKQDq4+1A4ia9nllfqcQFTQJedwGI= @@ -757,7 +754,6 @@ github.com/grpc-ecosystem/go-grpc-middleware v1.0.1-0.20190118093823-f849b5445de github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= github.com/grpc-ecosystem/grpc-gateway v1.9.0/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY= github.com/grpc-ecosystem/grpc-gateway v1.9.5/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY= -github.com/grpc-ecosystem/grpc-gateway v1.16.0 h1:gmcG1KaJ57LophUzW0Hy8NmPhnMZb4M0+kPpLofRdBo= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/grpc-ecosystem/grpc-gateway/v2 v2.7.0 h1:BZHcxBETFHIdVyhyEfOvn/RdU/QGdLI4y34qQGjGWO0= github.com/grpc-ecosystem/grpc-gateway/v2 v2.7.0/go.mod h1:hgWBS7lorOAVIJEQMi4ZsPv9hVvWI6+ch50m39Pf2Ks= @@ -840,7 +836,6 @@ github.com/instana/go-sensor v1.41.1/go.mod h1:E42MelHWFz11qqaLwvgt0j98v2s2O/bq2 github.com/instana/testify v1.6.2-0.20200721153833-94b1851f4d65 h1:T25FL3WEzgmKB0m6XCJNZ65nw09/QIp3T1yXr487D+A= github.com/instana/testify v1.6.2-0.20200721153833-94b1851f4d65/go.mod h1:nYhEREG/B7HUY7P+LKOrqy53TpIqmJ9JyUShcaEKtGw= github.com/j-keck/arping v0.0.0-20160618110441-2cf9dc699c56/go.mod h1:ymszkNOg6tORTn+6F6j+Jc8TOr5osrynvN6ivFWZ2GA= -github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= @@ -867,7 +862,6 @@ github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 h1:DadwsjnMwFjfWc9y5W github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65/go.mod h1:5R2h2EEX+qri8jOWMbJCtaPWkrrNc7OHwsp2TCqp7ak= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= @@ -1003,7 +997,6 @@ github.com/labstack/gommon v0.3.0/go.mod h1:MULnywXg0yavhxWKc+lOruYdAhDwPK9wf0OL github.com/labstack/gommon v0.3.1/go.mod h1:uW6kP17uPlLJsD3ijUYn3/M5bAxtlZhMI6m3MFxTMTM= github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y= github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= -github.com/lib/pq v0.0.0-20180327071824-d34b9ff171c2/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= @@ -1404,7 +1397,6 @@ github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5 github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals= github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/subosito/gotenv v1.2.0 h1:Slr1R9HxAlEKefgq5jn9U+DnETlIUa6HfgEzj0g5d7s= @@ -1880,8 +1872,6 @@ golang.org/x/sys v0.0.0-20211107104306-e0b2ad06fe42/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20211116061358-0a5406a5449c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220405210540-1e041c57c461/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220513210249-45d2b4557a2a h1:N2T1jUrTQE9Re6TFF5PhvEHXHCguynGhKjWVsIUt5cY= golang.org/x/sys v0.0.0-20220513210249-45d2b4557a2a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= @@ -2173,7 +2163,6 @@ gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81 gotest.tools/v3 v3.0.2/go.mod h1:3SzNCllyD9/Y+b5r9JIKQ474KzkZyqLqEfYqMsX94Bk= gotest.tools/v3 v3.0.3/go.mod h1:Z7Lb0S5l+klDB31fvDQX8ss/FlKDxtlFlw3Oa8Ymbl8= gotest.tools/v3 v3.2.0 h1:I0DwBVMGAx26dttAj1BtJLAkVGncrkkUXfJLC4Flt/I= -gotest.tools/v3 v3.2.0/go.mod h1:Mcr9QNxkg0uMvy/YElmo4SpXgJKWgQvYrT7Kw5RzJ1A= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/sqlfields/base_types_test.go b/sqlfields/base_types_test.go new file mode 100644 index 00000000..e1a1460c --- /dev/null +++ b/sqlfields/base_types_test.go @@ -0,0 +1,102 @@ +package sqlfields + +import ( + "encoding/json" + "fmt" + "testing" + "time" + + "github.com/jmoiron/sqlx" + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func testJSONEncoding[T any](t *testing.T, value T, expectedJSON string, isNullable bool) { + t.Run(fmt.Sprintf("type=%T", value), func(t *testing.T) { + t.Run("case=marshal", func(t *testing.T) { + actual, err := json.Marshal(value) + require.NoError(t, err) + assert.JSONEq(t, expectedJSON, string(actual)) + }) + t.Run("case=unmarshal", func(t *testing.T) { + var other T + require.NoError(t, json.Unmarshal([]byte(expectedJSON), &other)) + assert.EqualValues(t, value, other) + }) + t.Run("case=null", func(t *testing.T) { + var actual, expected T + assert.NoError(t, json.Unmarshal([]byte("null"), &actual)) + if _, ok := any(value).(JSONRawMessage); ok { + assert.Equal(t, JSONRawMessage("null"), actual) + } else { + assert.Equal(t, expected, actual) + } + + raw, err := json.Marshal(expected) + require.NoError(t, err) + if !isNullable { + assert.NotEqual(t, "null", string(raw)) + } else { + assert.Equal(t, "null", string(raw)) + } + }) + }) +} + +func TestJSONCompat(t *testing.T) { + testJSONEncoding(t, String("foo"), `"foo"`, false) + testJSONEncoding(t, Int(123), `123`, false) + testJSONEncoding(t, Int32(456), `456`, false) + testJSONEncoding(t, Int64(789), `789`, false) + testJSONEncoding(t, Float64(1.23), `1.23`, false) + testJSONEncoding(t, Bool(true), `true`, false) + testJSONEncoding(t, Duration(10*time.Second), `"10s"`, false) + testJSONEncoding(t, Time(time.Unix(123, 0).UTC()), `"1970-01-01T00:02:03Z"`, false) + testJSONEncoding(t, JSONRawMessage(`{"foo":"bar"}`), `{"foo":"bar"}`, true) + testJSONEncoding(t, StringSliceJSONFormat{"foo", "bar"}, `["foo","bar"]`, true) + testJSONEncoding(t, StringSlicePipeDelimiter{"foo", "bar"}, `["foo","bar"]`, true) +} + +func testSQLCompatibility[T any](t *testing.T, db *sqlx.DB, value T) { + insertValue := func(t *testing.T, value T) int64 { + res, err := db.Exec(`INSERT INTO "testing" ("value") VALUES (?)`, value) + require.NoError(t, err) + id, err := res.LastInsertId() + require.NoError(t, err) + return id + } + + t.Run(fmt.Sprintf("type=%T", value), func(t *testing.T) { + t.Run("case=insert and select", func(t *testing.T) { + var actual T + require.NoError(t, db.Get(&actual, `SELECT "value" FROM "testing" WHERE "id" = ?`, insertValue(t, value))) + assert.EqualValues(t, value, actual) + }) + }) +} + +func TestSQLCompat(t *testing.T) { + db, err := sqlx.Connect("sqlite3", "file::memory:") + require.NoError(t, err) + defer db.Close() + + // You have to hate the inconsistencies of SQLite. But for this test, it's great to have column that takes any data type. + _, err = db.Exec(`CREATE TABLE "testing" ( + "id" INTEGER PRIMARY KEY AUTOINCREMENT, + "value" BLOB +)`) + require.NoError(t, err) + + testSQLCompatibility(t, db, String("foo")) + testSQLCompatibility(t, db, Int(123)) + testSQLCompatibility(t, db, Int32(456)) + testSQLCompatibility(t, db, Int64(789)) + testSQLCompatibility(t, db, Float64(1.23)) + testSQLCompatibility(t, db, Bool(true)) + testSQLCompatibility(t, db, Duration(10*time.Second)) + testSQLCompatibility(t, db, Time(time.Unix(12345, 0).UTC())) + testSQLCompatibility(t, db, JSONRawMessage(`{"foo":"bar"}`)) + testSQLCompatibility(t, db, StringSliceJSONFormat{"foo", "bar"}) + testSQLCompatibility(t, db, StringSlicePipeDelimiter{"foo", "bar"}) +} diff --git a/sqlfields/builtin.go b/sqlfields/builtin.go new file mode 100644 index 00000000..33a26cdc --- /dev/null +++ b/sqlfields/builtin.go @@ -0,0 +1,140 @@ +package sqlfields + +import ( + "database/sql/driver" + "math" + + "github.com/pkg/errors" +) + +func NewNullString(s string) NullString { + return NullString{Val: String(s), Valid: true} +} + +type String string + +func (s *String) Scan(value any) error { + switch v := value.(type) { + case string: + *s = String(v) + case []byte: + *s = String(v) + default: + return errors.Errorf("String.Scan: cannot scan type %T into String", value) + } + return nil +} + +func (s *String) Value() (driver.Value, error) { + return string(*s), nil +} + +func NewNullInt64(i int64) NullInt64 { + return NullInt64{Val: Int64(i), Valid: true} +} + +type Int64 int64 + +func (i *Int64) Scan(value any) error { + switch v := value.(type) { + case int64: + *i = Int64(v) + case float64: + *i = Int64(v) + default: + return errors.Errorf("Int64.Scan: cannot scan type %T into Int64", value) + } + return nil +} + +func (i *Int64) Value() (driver.Value, error) { + return int64(*i), nil +} + +func NewNullInt32(i int32) NullInt32 { + return NullInt32{Val: Int32(i), Valid: true} +} + +type Int32 int32 + +func (i *Int32) Scan(value any) error { + var i64 Int64 + if err := i64.Scan(value); err != nil { + return err + } + if i64 > math.MaxInt32 { + return errors.Errorf("Int32.Scan: value %x does not fit into int32", i64) + } + *i = Int32(i64) + return nil +} + +func (i *Int32) Value() (driver.Value, error) { + return int64(*i), nil +} + +func NewNullInt(i int) NullInt { + return NullInt{Val: Int(i), Valid: true} +} + +type Int int + +func (i *Int) Scan(value any) error { + var i64 Int64 + if err := i64.Scan(value); err != nil { + return err + } + if i64 > math.MaxInt { + return errors.Errorf("Int.Scan: value %x does not fit into int", value) + } + *i = Int(i64) + return nil +} + +func (i *Int) Value() (driver.Value, error) { + return int64(*i), nil +} + +func NewNullFloat64(f float64) NullFloat64 { + return NullFloat64{Val: Float64(f), Valid: true} +} + +type Float64 float64 + +func (f *Float64) Scan(value any) error { + switch v := value.(type) { + case float64: + *f = Float64(v) + case int64: + *f = Float64(v) + default: + return errors.Errorf("Float64.Scan: cannot scan type %T into Float64", value) + } + return nil +} + +func (f *Float64) Value() (driver.Value, error) { + return float64(*f), nil +} + +func NewNullBool(b bool) NullBool { + return NullBool{Val: Bool(b), Valid: true} +} + +type Bool bool + +func (b *Bool) Scan(value any) error { + switch v := value.(type) { + case bool: + *b = Bool(v) + case int64: + *b = v != 0 + default: + return errors.Errorf("Bool.Scan: cannot scan type %T into Bool", value) + } + return nil +} + +func (b *Bool) Value() (driver.Value, error) { + return bool(*b), nil +} diff --git a/sqlfields/json_raw_message.go b/sqlfields/json_raw_message.go new file mode 100644 index 00000000..f47745dd --- /dev/null +++ b/sqlfields/json_raw_message.go @@ -0,0 +1,38 @@ +package sqlfields + +import ( + "database/sql/driver" + "encoding/json" + + "github.com/pkg/errors" +) + +func NewNullJSONRawMessage(data []byte) NullJSONRawMessage { + return NullJSONRawMessage{Val: data, Valid: true} +} + +type JSONRawMessage json.RawMessage + +func (j *JSONRawMessage) Scan(value any) error { + switch v := value.(type) { + case []byte: + *j = v + case string: + *j = JSONRawMessage(v) + default: + return errors.Errorf("JSONRawMessage.Scan: cannot scan type %T into JSONRawMessage", value) + } + return nil +} + +func (j *JSONRawMessage) Value() (driver.Value, error) { + return []byte(*j), nil +} + +func (j JSONRawMessage) MarshalJSON() ([]byte, error) { + return json.Marshal(json.RawMessage(j)) +} + +func (j *JSONRawMessage) UnmarshalJSON(data []byte) error { + return json.Unmarshal(data, (*json.RawMessage)(j)) +} diff --git a/sqlfields/nullable.go b/sqlfields/nullable.go new file mode 100644 index 00000000..e3d973ca --- /dev/null +++ b/sqlfields/nullable.go @@ -0,0 +1,99 @@ +package sqlfields + +import ( + "database/sql" + "database/sql/driver" + "encoding/json" + + "github.com/pkg/errors" +) + +type Nullable[T any, pointer interface { + *T + sql.Scanner + driver.Valuer +}] struct { + Val T + Valid bool +} + +// swagger:type string +// swagger:x-nullable true +type NullString = Nullable[String, *String] + +// swagger:type integer +// swagger:x-nullable true +type NullInt = Nullable[Int, *Int] + +// swagger:type integer +// swagger:x-nullable true +type NullInt32 = Nullable[Int32, *Int32] + +// swagger:type integer +// swagger:x-nullable true +type NullInt64 = Nullable[Int64, *Int64] + +// swagger:type number +// swagger:x-nullable true +type NullFloat64 = Nullable[Float64, *Float64] + +// swagger:type boolean +// swagger:x-nullable true +type NullBool = Nullable[Bool, *Bool] + +// swagger:type object +// swagger:x-nullable true +type NullJSONRawMessage = Nullable[JSONRawMessage, *JSONRawMessage] + +// swagger:type string +// swagger:x-nullable true +type NullDuration = Nullable[Duration, *Duration] + +// swagger:type string +// swagger:x-nullable true +type NullTime = Nullable[Time, *Time] + +func (n Nullable[T, pointer]) MarshalJSON() ([]byte, error) { + if !n.Valid { + return []byte("null"), nil + } + return json.Marshal(n.Val) +} + +func (n *Nullable[T, pointer]) UnmarshalJSON(data []byte) error { + if n == nil { + return errors.New("Nullable: UnmarshalJSON on nil pointer") + } + if len(data) == 0 || string(data) == "null" { + var zero T + n.Val, n.Valid = zero, false + return nil + } + err := json.Unmarshal(data, &n.Val) + if err != nil { + return errors.WithStack(err) + } + n.Valid = true + return nil +} + +func (n *Nullable[T, pointer]) Scan(value any) error { + if value == nil { + var zero T + n.Val, n.Valid = zero, false + return nil + } + pValue := any(&n.Val).(pointer) + if err := pValue.Scan(value); err != nil { + return errors.WithStack(err) + } + n.Valid = true + return nil +} + +func (n Nullable[T, pointer]) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return any(&(n.Val)).(pointer).Value() +} diff --git a/sqlfields/nullable_test.go b/sqlfields/nullable_test.go new file mode 100644 index 00000000..f89a1a4f --- /dev/null +++ b/sqlfields/nullable_test.go @@ -0,0 +1,95 @@ +package sqlfields + +import ( + "encoding/json" + "fmt" + "testing" + "time" + + "github.com/jmoiron/sqlx" + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func testNullJSONEncoding[T any](t *testing.T, value T, expectedJSON string) { + t.Run(fmt.Sprintf("type=%T", value), func(t *testing.T) { + t.Run("case=marshal", func(t *testing.T) { + actual, err := json.Marshal(value) + require.NoError(t, err) + assert.JSONEq(t, expectedJSON, string(actual)) + }) + t.Run("case=unmarshal", func(t *testing.T) { + var other T + require.NoError(t, json.Unmarshal([]byte(expectedJSON), &other)) + assert.EqualValues(t, value, other) + }) + t.Run("case=null", func(t *testing.T) { + var actual, expected T + require.NoError(t, json.Unmarshal([]byte("null"), &actual)) + assert.EqualValues(t, expected, actual) + + raw, err := json.Marshal(expected) + require.NoError(t, err) + assert.JSONEq(t, "null", string(raw)) + }) + }) +} + +func TestNullableJSON(t *testing.T) { + testNullJSONEncoding(t, NewNullString("foo"), `"foo"`) + testNullJSONEncoding(t, NewNullInt(123), `123`) + testNullJSONEncoding(t, NewNullInt32(456), `456`) + testNullJSONEncoding(t, NewNullInt64(789), `789`) + testNullJSONEncoding(t, NewNullFloat64(1.23), `1.23`) + testNullJSONEncoding(t, NewNullBool(true), `true`) + testNullJSONEncoding(t, NewNullDuration(10*time.Second), `"10s"`) + testNullJSONEncoding(t, NewNullTime(time.Unix(123, 0).UTC()), `"1970-01-01T00:02:03Z"`) + testNullJSONEncoding(t, NewNullJSONRawMessage([]byte(`{"foo":"bar"}`)), `{"foo":"bar"}`) +} + +func testNullSQLCompatibility[T any](t *testing.T, db *sqlx.DB, value T) { + insertValue := func(t *testing.T, value T) int64 { + res, err := db.Exec(`INSERT INTO "testing" ("value") VALUES (?)`, value) + require.NoError(t, err) + id, err := res.LastInsertId() + require.NoError(t, err) + return id + } + + t.Run(fmt.Sprintf("type=%T", value), func(t *testing.T) { + t.Run("case=insert and select non-null values", func(t *testing.T) { + var actual T + require.NoError(t, db.Get(&actual, `SELECT "value" FROM "testing" WHERE "id" = ?`, insertValue(t, value))) + assert.EqualValues(t, value, actual) + }) + + t.Run("case=insert and select null values", func(t *testing.T) { + var actual, null T + require.NoError(t, db.Get(&actual, `SELECT "value" FROM "testing" WHERE "id" = ?`, insertValue(t, null))) + assert.Equal(t, null, actual) + }) + }) +} + +func TestNullableSQL(t *testing.T) { + db, err := sqlx.Connect("sqlite3", "file::memory:?cache=shared") + require.NoError(t, err) + defer db.Close() + + // You have to hate the inconsistencies of SQLite. But for this test, it's great to have column that takes any data type. + _, err = db.Exec(`CREATE TABLE "testing" ( + "id" INTEGER PRIMARY KEY AUTOINCREMENT, + "value" BLOB +)`) + require.NoError(t, err) + + testNullSQLCompatibility(t, db, NewNullString("foo")) + testNullSQLCompatibility(t, db, NewNullInt(123)) + testNullSQLCompatibility(t, db, NewNullInt32(456)) + testNullSQLCompatibility(t, db, NewNullInt64(789)) + testNullSQLCompatibility(t, db, NewNullFloat64(1.23)) + testNullSQLCompatibility(t, db, NewNullBool(true)) + testNullSQLCompatibility(t, db, NewNullDuration(10*time.Second)) + testNullSQLCompatibility(t, db, NewNullJSONRawMessage([]byte(`{"foo":"bar"}`))) +} diff --git a/sqlfields/stringslice.go b/sqlfields/stringslice.go new file mode 100644 index 00000000..4bc46373 --- /dev/null +++ b/sqlfields/stringslice.go @@ -0,0 +1,69 @@ +package sqlfields + +import ( + "database/sql/driver" + "encoding/json" + "strings" + + "github.com/pkg/errors" +) + +type StringSliceJSONFormat []string + +type StringSlicePipeDelimiter []string + +func (s *StringSlicePipeDelimiter) Scan(value any) error { + switch v := value.(type) { + case string: + *s = scanStringSlice('|', v) + case []byte: + *s = scanStringSlice('|', string(v)) + default: + return errors.Errorf("StringSlicePipeDelimiter.Scan: cannot scan type %T into StringSlicePipeDelimiter", value) + } + return nil +} + +func (s StringSlicePipeDelimiter) Value() (driver.Value, error) { + return valueStringSlice('|', s), nil +} + +func scanStringSlice(delimiter rune, value string) []string { + escaped := false + splitted := strings.FieldsFunc(value, func(r rune) bool { + if r == '\\' { + escaped = !escaped + } else if escaped && r != delimiter { + escaped = false + } + return !escaped && r == delimiter + }) + for k, v := range splitted { + splitted[k] = strings.ReplaceAll(v, "\\"+string(delimiter), string(delimiter)) + } + return splitted +} + +func valueStringSlice(delimiter rune, value []string) string { + replace := make([]string, len(value)) + for k, v := range value { + replace[k] = strings.ReplaceAll(v, string(delimiter), "\\"+string(delimiter)) + } + return strings.Join(replace, string(delimiter)) +} + +func (s *StringSliceJSONFormat) Scan(value any) error { + switch v := value.(type) { + case string: + return errors.WithStack(json.Unmarshal([]byte(v), s)) + case []byte: + return errors.WithStack(json.Unmarshal(v, s)) + default: + return errors.Errorf("StringSliceJSONFormat.Scan: cannot scan type %T into StringSliceJSONFormat", value) + } +} + +func (s StringSliceJSONFormat) Value() (driver.Value, error) { + b, err := json.Marshal(s) + return string(b), errors.WithStack(err) +} diff --git a/sqlfields/time.go b/sqlfields/time.go new file mode 100644 index 00000000..7628f117 --- /dev/null +++ b/sqlfields/time.go @@ -0,0 +1,103 @@ +package sqlfields + +import ( + "database/sql/driver" + "encoding/json" + "fmt" + "time" + + "github.com/pkg/errors" +) + +type Duration time.Duration + +type Time time.Time + +func NewNullTime(t time.Time) NullTime { + return NullTime{Val: Time(t), Valid: true} +} + +func (t *Time) Scan(value any) error { + fmt.Printf("Scanning %#v\n", value) + switch v := value.(type) { + case time.Time: + *t = Time(v) + case string: + parsed, err := time.Parse(time.RFC3339, v) + if err != nil { + return errors.WithStack(err) + } + *t = Time(parsed) + default: + return errors.Errorf("Time.Scan: cannot scan type %T into Time", value) + } + return nil +} + +func (t Time) Value() (driver.Value, error) { + fmt.Printf("Valuing %s\n", time.Time(t).Format(time.RFC3339)) + return time.Time(t).Format(time.RFC3339), nil +} + +func (t Time) MarshalJSON() ([]byte, error) { + return (time.Time)(t).UTC().MarshalJSON() +} + +func (t *Time) UnmarshalJSON(data []byte) error { + var st time.Time + if err := json.Unmarshal(data, &st); err != nil { + return err + } + *t = Time(st) + return nil +} + +func NewNullDuration(d time.Duration) NullDuration { + return NullDuration{Val: Duration(d), Valid: true} +} + +func (d *Duration) Scan(value any) error { + switch v := value.(type) { + case time.Duration: + *d = Duration(v) + case int64: + *d = Duration(v) + case string: + parsed, err := time.ParseDuration(v) + if err != nil { + return errors.WithStack(err) + } + *d = Duration(parsed) + default: + return errors.Errorf("Duration.Scan: cannot scan type %T into Duration", value) + } + return nil +} + +func (d Duration) Value() (driver.Value, error) { + return int64(d), nil +} + +func (d Duration) MarshalJSON() ([]byte, error) { + return json.Marshal(time.Duration(d).String()) +} + +func (d *Duration) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return errors.WithStack(err) + } + if len(s) == 0 { + // set to zero value + *d = 0 + return nil + } + + p, err := time.ParseDuration(s) + if err != nil { + return errors.WithStack(err) + } + + *d = Duration(p) + return nil +} diff --git a/sqlxx/types.go b/sqlxx/types.go index 6e94f4e6..840a129c 100644 --- a/sqlxx/types.go +++ b/sqlxx/types.go @@ -40,6 +40,7 @@ func (ns *Duration) UnmarshalJSON(data []byte) error { } // StringSliceJSONFormat represents []string{} which is encoded to/from JSON for SQL storage. +// Deprecated: use sqlfields.StringSliceJSONFormat instead type StringSliceJSONFormat []string // Scan implements the Scanner interface. @@ -69,6 +70,7 @@ func (m StringSliceJSONFormat) Value() (driver.Value, error) { } // StringSlicePipeDelimiter de/encodes the string slice to/from a SQL string. +// Deprecated: use sqlfields.StringSlicePipeDelimiter instead type StringSlicePipeDelimiter []string // Scan implements the Scanner interface. @@ -111,10 +113,11 @@ func valueStringSlice(delimiter rune, value []string) string { return strings.Join(replace, string(delimiter)) } -// NullBool represents a bool that may be null. +// NullBool represents a bool that may be sqlfields. // NullBool implements the Scanner interface so // swagger:type bool // swagger:model nullBool +// Deprecated: use sqlfields.Bool instead type NullBool struct { Bool bool Valid bool // Valid is true if Bool is not NULL @@ -162,6 +165,7 @@ func (ns *NullBool) UnmarshalJSON(data []byte) error { // swagger:type string // swagger:model nullString +// Deprecated: use sqlfields.String instead type NullString string // MarshalJSON returns m as the JSON encoding of m. @@ -207,6 +211,7 @@ func (ns NullString) String() string { // // swagger:model nullTime // required: false +// Deprecated: use sqlfields.Time instead type NullTime time.Time // Scan implements the Scanner interface. @@ -337,6 +342,7 @@ func (m *JSONRawMessage) UnmarshalJSON(data []byte) error { // NullJSONRawMessage represents a json.RawMessage that works well with JSON, SQL, and Swagger and is NULLable- // // swagger:model nullJsonRawMessage +// Deprecated: use sqlfields.JSONRawMessage instead type NullJSONRawMessage json.RawMessage // Scan implements the Scanner interface. @@ -396,8 +402,9 @@ func JSONValue(src interface{}) (driver.Value, error) { return b.String(), nil } -// NullInt64 represents an int64 that may be null. +// NullInt64 represents an int64 that may be sqlfields. // swagger:model nullInt64 +// Deprecated: use sqlfields.Int64 instead type NullInt64 struct { Int int64 Valid bool // Valid is true if Duration is not NULL @@ -447,6 +454,7 @@ func (ns *NullInt64) UnmarshalJSON(data []byte) error { // // swagger:type string // swagger:model nullDuration +// Deprecated: use sqlfields.Duration instead type NullDuration struct { Duration time.Duration Valid bool From a3d1bc75439f202a711997784375624d4e1ce5ff Mon Sep 17 00:00:00 2001 From: Patrik Date: Tue, 25 Oct 2022 14:57:05 +0200 Subject: [PATCH 2/3] fix(sqlfields): handle nil values in `NewNullJSONRawMessage` --- sqlfields/base_types_test.go | 7 +++++++ sqlfields/json_raw_message.go | 3 +++ 2 files changed, 10 insertions(+) diff --git a/sqlfields/base_types_test.go b/sqlfields/base_types_test.go index e1a1460c..79a22112 100644 --- a/sqlfields/base_types_test.go +++ b/sqlfields/base_types_test.go @@ -100,3 +100,10 @@ func TestSQLCompat(t *testing.T) { testSQLCompatibility(t, db, StringSliceJSONFormat{"foo", "bar"}) testSQLCompatibility(t, db, StringSlicePipeDelimiter{"foo", "bar"}) } + +func TestFactoryFuncs(t *testing.T) { + t.Run("case=JSONRawMessage", func(t *testing.T) { + assert.True(t, NewNullJSONRawMessage([]byte(`"foo"`)).Valid) + assert.False(t, NewNullJSONRawMessage(nil).Valid) + }) +} diff --git a/sqlfields/json_raw_message.go b/sqlfields/json_raw_message.go index f47745dd..c94622e9 100644 --- a/sqlfields/json_raw_message.go +++ b/sqlfields/json_raw_message.go @@ -8,6 +8,9 @@ import ( ) func NewNullJSONRawMessage(data []byte) NullJSONRawMessage { + if data == nil { + return NullJSONRawMessage{} + } return NullJSONRawMessage{Val: data, Valid: true} } From fa2a63670069165b1a6135c664df851359d0bed7 Mon Sep 17 00:00:00 2001 From: Patrik Date: Tue, 25 Oct 2022 15:31:52 +0200 Subject: [PATCH 3/3] feat(pointerx): generic pointer and de-reference functions --- pointerx/pointerx.go | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/pointerx/pointerx.go b/pointerx/pointerx.go index 00ad3fbf..ce99294b 100644 --- a/pointerx/pointerx.go +++ b/pointerx/pointerx.go @@ -1,11 +1,27 @@ package pointerx +// Ptr returns the input value's pointer. +func Ptr[T any](v T) *T { + return &v +} + +// Deref returns the input values de-referenced value, or zero value if nil. +func Deref[T any](p *T) T { + if p == nil { + var zero T + return zero + } + return *p +} + // String returns the input value's pointer. +// Deprecated: use Ptr instead. func String(s string) *string { return &s } // StringR is the reverse to String. +// Deprecated: use Deref instead. func StringR(s *string) string { if s == nil { return "" @@ -14,11 +30,13 @@ func StringR(s *string) string { } // Int returns the input value's pointer. +// Deprecated: use Ptr instead. func Int(s int) *int { return &s } // IntR is the reverse to Int. +// Deprecated: use Deref instead. func IntR(s *int) int { if s == nil { return int(0) @@ -27,11 +45,13 @@ func IntR(s *int) int { } // Int32 returns the input value's pointer. -func Int32(s int32) *int32 { +// Deprecated: use Ptr instead. +func Int32[T any](s int32) *int32 { return &s } // Int32R is the reverse to Int32. +// Deprecated: use Deref instead. func Int32R(s *int32) int32 { if s == nil { return int32(0) @@ -40,11 +60,13 @@ func Int32R(s *int32) int32 { } // Int64 returns the input value's pointer. +// Deprecated: use Ptr instead. func Int64(s int64) *int64 { return &s } // Int64R is the reverse to Int64. +// Deprecated: use Deref instead. func Int64R(s *int64) int64 { if s == nil { return int64(0) @@ -53,11 +75,13 @@ func Int64R(s *int64) int64 { } // Float32 returns the input value's pointer. +// Deprecated: use Ptr instead. func Float32(s float32) *float32 { return &s } // Float32R is the reverse to Float32. +// Deprecated: use Deref instead. func Float32R(s *float32) float32 { if s == nil { return float32(0) @@ -66,11 +90,13 @@ func Float32R(s *float32) float32 { } // Float64 returns the input value's pointer. +// Deprecated: use Ptr instead. func Float64(s float64) *float64 { return &s } // Float64R is the reverse to Float64. +// Deprecated: use Deref instead. func Float64R(s *float64) float64 { if s == nil { return float64(0) @@ -79,11 +105,13 @@ func Float64R(s *float64) float64 { } // Bool returns the input value's pointer. +// Deprecated: use Ptr instead. func Bool(s bool) *bool { return &s } // BoolR is the reverse to Bool. +// Deprecated: use Deref instead. func BoolR(s *bool) bool { if s == nil { return false