Skip to content

Commit

Permalink
queryx: add binding transformer
Browse files Browse the repository at this point in the history
  • Loading branch information
N1cOs authored and mmatczuk committed Nov 26, 2021
1 parent 5e98fb6 commit 504f652
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 1 deletion.
20 changes: 19 additions & 1 deletion queryx.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,9 @@ type Queryx struct {
*gocql.Query
Names []string
Mapper *reflectx.Mapper
err error

tr Transformer
err error
}

// Query creates a new Queryx from gocql.Query using a default mapper.
Expand All @@ -104,9 +106,17 @@ func Query(q *gocql.Query, names []string) *Queryx {
Query: q,
Names: names,
Mapper: DefaultMapper,
tr: DefaultBindTransformer,
}
}

// WithBindTransformer sets the query bind transformer.
// The transformer is called right before binding a value to a named parameter.
func (q *Queryx) WithBindTransformer(tr Transformer) *Queryx {
q.tr = tr
return q
}

// BindStruct binds query named parameters to values from arg using mapper. If
// value cannot be found error is reported.
func (q *Queryx) BindStruct(arg interface{}) *Queryx {
Expand Down Expand Up @@ -157,6 +167,10 @@ func (q *Queryx) bindStructArgs(arg0 interface{}, arg1 map[string]interface{}) (
arglist = append(arglist, val)
}

if q.tr != nil {
arglist[i] = q.tr(q.Names[i], arglist[i])
}

return nil
})

Expand Down Expand Up @@ -184,6 +198,10 @@ func (q *Queryx) bindMapArgs(arg map[string]interface{}) ([]interface{}, error)
if !ok {
return arglist, fmt.Errorf("could not find name %q in %#v", name, arg)
}

if q.tr != nil {
val = q.tr(name, val)
}
arglist = append(arglist, val)
}
return arglist, nil
Expand Down
60 changes: 60 additions & 0 deletions queryx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,25 @@ func TestQueryxBindStruct(t *testing.T) {
}
})

t.Run("with transformer", func(t *testing.T) {
tr := func(name string, val interface{}) interface{} {
if name == "age" {
return 42
}
return val
}

names := []string{"name", "age", "first", "last"}
args, err := Query(nil, names).WithBindTransformer(tr).bindStructArgs(v, nil)
if err != nil {
t.Fatal(err)
}

if diff := cmp.Diff(args, []interface{}{"name", 42, "first", "last"}); diff != "" {
t.Error("args mismatch", diff)
}
})

t.Run("error", func(t *testing.T) {
names := []string{"name", "age", "first", "not_found"}
_, err := Query(nil, names).bindStructArgs(v, nil)
Expand All @@ -111,6 +130,28 @@ func TestQueryxBindStruct(t *testing.T) {
}
})

t.Run("fallback with transformer", func(t *testing.T) {
tr := func(name string, val interface{}) interface{} {
if name == "not_found" {
return "map_found"
}
return val
}

names := []string{"name", "age", "first", "not_found"}
m := map[string]interface{}{
"not_found": "last",
}
args, err := Query(nil, names).WithBindTransformer(tr).bindStructArgs(v, m)
if err != nil {
t.Fatal(err)
}

if diff := cmp.Diff(args, []interface{}{"name", 30, "first", "map_found"}); diff != "" {
t.Error("args mismatch", diff)
}
})

t.Run("fallback error", func(t *testing.T) {
names := []string{"name", "age", "first", "not_found", "really_not_found"}
m := map[string]interface{}{
Expand Down Expand Up @@ -143,6 +184,25 @@ func TestQueryxBindMap(t *testing.T) {
}
})

t.Run("with transformer", func(t *testing.T) {
tr := func(name string, val interface{}) interface{} {
if name == "age" {
return 42
}
return val
}

names := []string{"name", "age", "first", "last"}
args, err := Query(nil, names).WithBindTransformer(tr).bindMapArgs(v)
if err != nil {
t.Fatal(err)
}

if diff := cmp.Diff(args, []interface{}{"name", 42, "first", "last"}); diff != "" {
t.Error("args mismatch", diff)
}
})

t.Run("error", func(t *testing.T) {
names := []string{"name", "first", "not_found"}
_, err := Query(nil, names).bindMapArgs(v)
Expand Down
2 changes: 2 additions & 0 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ func (s Session) ContextQuery(ctx context.Context, stmt string, names []string)
Query: s.Session.Query(stmt).WithContext(ctx),
Names: names,
Mapper: s.Mapper,
tr: DefaultBindTransformer,
}
}

Expand All @@ -62,6 +63,7 @@ func (s Session) Query(stmt string, names []string) *Queryx {
Query: s.Session.Query(stmt),
Names: names,
Mapper: s.Mapper,
tr: DefaultBindTransformer,
}
}

Expand Down
13 changes: 13 additions & 0 deletions transformer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// Copyright (C) 2017 ScyllaDB
// Use of this source code is governed by a ALv2-style
// license that can be found in the LICENSE file.

package gocqlx

// Transformer transforms the value of the named parameter to another value.
type Transformer func(name string, val interface{}) interface{}

// DefaultBindTransformer just do nothing.
//
// A custom transformer can always be set per Query.
var DefaultBindTransformer Transformer

0 comments on commit 504f652

Please sign in to comment.