Skip to content

Commit

Permalink
FIX sql rune iterator method for postgres quote.
Browse files Browse the repository at this point in the history
  • Loading branch information
karminski committed Feb 4, 2024
1 parent 81d562f commit df3c745
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 14 deletions.
29 changes: 18 additions & 11 deletions src/utils/parser/sql/escaper.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,24 +288,31 @@ func (sqlEscaper *SQLEscaper) EscapeSQLActionTemplate(sql string, args map[strin
ret.WriteString("}")
escapedBracketWithVariable = ""
}
// init sql rune serial list
sqlRuneSerialList := make([]int, 0)
sqlRuneList := []rune(sql)
// convert to rune slice
for i, _ := range sqlRuneList {
sqlRuneSerialList = append(sqlRuneSerialList, i)
}

// get next char method.
getNextChar := func(serial int) (rune, error) {
if len(sql)-1 <= serial {
if len(sqlRuneSerialList)-1 <= serial {
return rune(0), errors.New("over range")
}
return rune(sql[serial+1]), nil
return sqlRuneList[sqlRuneSerialList[serial+1]], nil
}

// convert to rune slice
sqlRuneList := make([]rune, 0)
for _, j := range sql {
sqlRuneList = append(sqlRuneList, j)
}

charSerial := 0
for _, c := range sql {
charSerial := -1
for {
charSerial++
if charSerial > len(sqlRuneSerialList)-1 {
break
}
c := sqlRuneList[sqlRuneSerialList[charSerial]]

// fmt.Printf("[%d] char: %s\n", charSerial, string(c))
fmt.Printf("[%d:%d] char: %s\n", sqlRuneSerialList[charSerial], charSerial, string(c))
// process bracket
// '' + '{' or '{' + '{'
if c == '{' && leftBraketCounter <= 1 {
Expand Down
6 changes: 3 additions & 3 deletions src/utils/parser/sql/escaper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -295,15 +295,15 @@ func TestEscapeMySQLSQLInStatementQueryInIntStringInUnsafeMode(t *testing.T) {
// SELECT * FROM apps WHERE uid = ANY (VALUES ('ca5e3145-f9b4-4610-bd25-0ffbf258cce7'::uuid), ('feb398fa-e5eb-43f6-8488-82a9f4806570'::uuid));
// ```
func TestEscapePostgresSQLAnyStatementQuery(t *testing.T) {
sql_1 := `select * from users where name = ANY('{{multiselect1.value.map(b => Number(b))}}')`
sql_1 := `select * from users where name = ANY(ARRAY[{{multiselect1.value.map(b => Number(b))}}])`
args := map[string]interface{}{
`multiselect1.value.map(b => Number(b))`: []interface{}{"a", "b", "c"},
}
sqlEscaper := NewSQLEscaper(resourcelist.TYPE_POSTGRESQL_ID)
escapedSQL, usedArgs, errInEscape := sqlEscaper.EscapeSQLActionTemplate(sql_1, args, true)
assert.Nil(t, errInEscape)
assert.Equal(t, []interface{}{}, usedArgs, "the usedArgs should be equal")
assert.Equal(t, "select * from users where id in ('a', 'b', 'c')", escapedSQL, "the token should be equal")
assert.Equal(t, []interface{}{"a", "b", "c"}, usedArgs, "the usedArgs should be equal")
assert.Equal(t, "select * from users where name = ANY(ARRAY[$1, $2, $3])", escapedSQL, "the token should be equal")
}

func TestEscapePostgresSQLInvaliedLengthUTF8Case(t *testing.T) {
Expand Down

0 comments on commit df3c745

Please sign in to comment.