From 504f6523d9a13cc7670d56d4c194293cf7c668bf Mon Sep 17 00:00:00 2001 From: Nikita Karmatskikh Date: Tue, 23 Nov 2021 21:58:26 +0300 Subject: [PATCH] queryx: add binding transformer --- queryx.go | 20 ++++++++++++++++- queryx_test.go | 60 ++++++++++++++++++++++++++++++++++++++++++++++++++ session.go | 2 ++ transformer.go | 13 +++++++++++ 4 files changed, 94 insertions(+), 1 deletion(-) create mode 100644 transformer.go diff --git a/queryx.go b/queryx.go index 5a38852..175132b 100644 --- a/queryx.go +++ b/queryx.go @@ -93,7 +93,9 @@ type Queryx struct { *gocql.Query Names []string Mapper *reflectx.Mapper - err error + + tr Transformer + err error } // Query creates a new Queryx from gocql.Query using a default mapper. @@ -104,9 +106,17 @@ func Query(q *gocql.Query, names []string) *Queryx { Query: q, Names: names, Mapper: DefaultMapper, + tr: DefaultBindTransformer, } } +// WithBindTransformer sets the query bind transformer. +// The transformer is called right before binding a value to a named parameter. +func (q *Queryx) WithBindTransformer(tr Transformer) *Queryx { + q.tr = tr + return q +} + // BindStruct binds query named parameters to values from arg using mapper. If // value cannot be found error is reported. func (q *Queryx) BindStruct(arg interface{}) *Queryx { @@ -157,6 +167,10 @@ func (q *Queryx) bindStructArgs(arg0 interface{}, arg1 map[string]interface{}) ( arglist = append(arglist, val) } + if q.tr != nil { + arglist[i] = q.tr(q.Names[i], arglist[i]) + } + return nil }) @@ -184,6 +198,10 @@ func (q *Queryx) bindMapArgs(arg map[string]interface{}) ([]interface{}, error) if !ok { return arglist, fmt.Errorf("could not find name %q in %#v", name, arg) } + + if q.tr != nil { + val = q.tr(name, val) + } arglist = append(arglist, val) } return arglist, nil diff --git a/queryx_test.go b/queryx_test.go index 27535b1..649d0d0 100644 --- a/queryx_test.go +++ b/queryx_test.go @@ -88,6 +88,25 @@ func TestQueryxBindStruct(t *testing.T) { } }) + t.Run("with transformer", func(t *testing.T) { + tr := func(name string, val interface{}) interface{} { + if name == "age" { + return 42 + } + return val + } + + names := []string{"name", "age", "first", "last"} + args, err := Query(nil, names).WithBindTransformer(tr).bindStructArgs(v, nil) + if err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(args, []interface{}{"name", 42, "first", "last"}); diff != "" { + t.Error("args mismatch", diff) + } + }) + t.Run("error", func(t *testing.T) { names := []string{"name", "age", "first", "not_found"} _, err := Query(nil, names).bindStructArgs(v, nil) @@ -111,6 +130,28 @@ func TestQueryxBindStruct(t *testing.T) { } }) + t.Run("fallback with transformer", func(t *testing.T) { + tr := func(name string, val interface{}) interface{} { + if name == "not_found" { + return "map_found" + } + return val + } + + names := []string{"name", "age", "first", "not_found"} + m := map[string]interface{}{ + "not_found": "last", + } + args, err := Query(nil, names).WithBindTransformer(tr).bindStructArgs(v, m) + if err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(args, []interface{}{"name", 30, "first", "map_found"}); diff != "" { + t.Error("args mismatch", diff) + } + }) + t.Run("fallback error", func(t *testing.T) { names := []string{"name", "age", "first", "not_found", "really_not_found"} m := map[string]interface{}{ @@ -143,6 +184,25 @@ func TestQueryxBindMap(t *testing.T) { } }) + t.Run("with transformer", func(t *testing.T) { + tr := func(name string, val interface{}) interface{} { + if name == "age" { + return 42 + } + return val + } + + names := []string{"name", "age", "first", "last"} + args, err := Query(nil, names).WithBindTransformer(tr).bindMapArgs(v) + if err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(args, []interface{}{"name", 42, "first", "last"}); diff != "" { + t.Error("args mismatch", diff) + } + }) + t.Run("error", func(t *testing.T) { names := []string{"name", "first", "not_found"} _, err := Query(nil, names).bindMapArgs(v) diff --git a/session.go b/session.go index 66e86da..608a3ae 100644 --- a/session.go +++ b/session.go @@ -49,6 +49,7 @@ func (s Session) ContextQuery(ctx context.Context, stmt string, names []string) Query: s.Session.Query(stmt).WithContext(ctx), Names: names, Mapper: s.Mapper, + tr: DefaultBindTransformer, } } @@ -62,6 +63,7 @@ func (s Session) Query(stmt string, names []string) *Queryx { Query: s.Session.Query(stmt), Names: names, Mapper: s.Mapper, + tr: DefaultBindTransformer, } } diff --git a/transformer.go b/transformer.go new file mode 100644 index 0000000..bb90d0f --- /dev/null +++ b/transformer.go @@ -0,0 +1,13 @@ +// Copyright (C) 2017 ScyllaDB +// Use of this source code is governed by a ALv2-style +// license that can be found in the LICENSE file. + +package gocqlx + +// Transformer transforms the value of the named parameter to another value. +type Transformer func(name string, val interface{}) interface{} + +// DefaultBindTransformer just do nothing. +// +// A custom transformer can always be set per Query. +var DefaultBindTransformer Transformer