diff --git a/src/utils/parser/sql/escaper.go b/src/utils/parser/sql/escaper.go index bd8eabee..61c6ee5a 100644 --- a/src/utils/parser/sql/escaper.go +++ b/src/utils/parser/sql/escaper.go @@ -288,24 +288,31 @@ func (sqlEscaper *SQLEscaper) EscapeSQLActionTemplate(sql string, args map[strin ret.WriteString("}") escapedBracketWithVariable = "" } + // init sql rune serial list + sqlRuneSerialList := make([]int, 0) + sqlRuneList := []rune(sql) + // convert to rune slice + for i, _ := range sqlRuneList { + sqlRuneSerialList = append(sqlRuneSerialList, i) + } + + // get next char method. getNextChar := func(serial int) (rune, error) { - if len(sql)-1 <= serial { + if len(sqlRuneSerialList)-1 <= serial { return rune(0), errors.New("over range") } - return rune(sql[serial+1]), nil + return sqlRuneList[sqlRuneSerialList[serial+1]], nil } - // convert to rune slice - sqlRuneList := make([]rune, 0) - for _, j := range sql { - sqlRuneList = append(sqlRuneList, j) - } - - charSerial := 0 - for _, c := range sql { + charSerial := -1 + for { charSerial++ + if charSerial > len(sqlRuneSerialList)-1 { + break + } + c := sqlRuneList[sqlRuneSerialList[charSerial]] - // fmt.Printf("[%d] char: %s\n", charSerial, string(c)) + fmt.Printf("[%d:%d] char: %s\n", sqlRuneSerialList[charSerial], charSerial, string(c)) // process bracket // '' + '{' or '{' + '{' if c == '{' && leftBraketCounter <= 1 { diff --git a/src/utils/parser/sql/escaper_test.go b/src/utils/parser/sql/escaper_test.go index 83f1bf9f..e4249a78 100644 --- a/src/utils/parser/sql/escaper_test.go +++ b/src/utils/parser/sql/escaper_test.go @@ -295,15 +295,15 @@ func TestEscapeMySQLSQLInStatementQueryInIntStringInUnsafeMode(t *testing.T) { // SELECT * FROM apps WHERE uid = ANY (VALUES ('ca5e3145-f9b4-4610-bd25-0ffbf258cce7'::uuid), ('feb398fa-e5eb-43f6-8488-82a9f4806570'::uuid)); // ``` func TestEscapePostgresSQLAnyStatementQuery(t *testing.T) { - sql_1 := `select * from users where name = ANY('{{multiselect1.value.map(b => Number(b))}}')` + sql_1 := `select * from users where name = ANY(ARRAY[{{multiselect1.value.map(b => Number(b))}}])` args := map[string]interface{}{ `multiselect1.value.map(b => Number(b))`: []interface{}{"a", "b", "c"}, } sqlEscaper := NewSQLEscaper(resourcelist.TYPE_POSTGRESQL_ID) escapedSQL, usedArgs, errInEscape := sqlEscaper.EscapeSQLActionTemplate(sql_1, args, true) assert.Nil(t, errInEscape) - assert.Equal(t, []interface{}{}, usedArgs, "the usedArgs should be equal") - assert.Equal(t, "select * from users where id in ('a', 'b', 'c')", escapedSQL, "the token should be equal") + assert.Equal(t, []interface{}{"a", "b", "c"}, usedArgs, "the usedArgs should be equal") + assert.Equal(t, "select * from users where name = ANY(ARRAY[$1, $2, $3])", escapedSQL, "the token should be equal") } func TestEscapePostgresSQLInvaliedLengthUTF8Case(t *testing.T) {