Skip to content

Commit 66edfad

Browse files
committed
nhance interpolateParams to correctly handle placeholders in queries with comments, strings, and backticks.
* Add `findParamPositions` to identify real parameter positions * Update and expand related tests.
1 parent 76c00e3 commit 66edfad

File tree

4 files changed

+313
-197
lines changed

4 files changed

+313
-197
lines changed

connection.go

Lines changed: 149 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ func (mc *mysqlConn) close() {
172172
}
173173

174174
// Closes the network connection and unsets internal variables. Do not call this
175-
// function after successfully authentication, call Close instead. This function
175+
// function after successful authentication, call Close instead. This function
176176
// is called before auth or on auth failure because MySQL will have already
177177
// closed the network connection.
178178
func (mc *mysqlConn) cleanup() {
@@ -246,100 +246,172 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
246246
}
247247

248248
func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
249-
// Number of ? should be same to len(args)
250-
if strings.Count(query, "?") != len(args) {
251-
return "", driver.ErrSkip
252-
}
249+
noBackslashEscapes := (mc.status & statusNoBackslashEscapes) != 0
250+
const (
251+
stateNormal = iota
252+
stateString
253+
stateEscape
254+
stateEOLComment
255+
stateSlashStarComment
256+
stateBacktick
257+
)
258+
259+
var (
260+
QUOTE_BYTE = byte('\'')
261+
DBL_QUOTE_BYTE = byte('"')
262+
BACKSLASH_BYTE = byte('\\')
263+
QUESTION_MARK_BYTE = byte('?')
264+
SLASH_BYTE = byte('/')
265+
STAR_BYTE = byte('*')
266+
HASH_BYTE = byte('#')
267+
MINUS_BYTE = byte('-')
268+
LINE_FEED_BYTE = byte('\n')
269+
RADICAL_BYTE = byte('`')
270+
)
253271

254272
buf, err := mc.buf.takeCompleteBuffer()
255273
if err != nil {
256-
// can not take the buffer. Something must be wrong with the connection
257274
mc.cleanup()
258-
// interpolateParams would be called before sending any query.
259-
// So its safe to retry.
260275
return "", driver.ErrBadConn
261276
}
262277
buf = buf[:0]
278+
state := stateNormal
279+
singleQuotes := false
280+
lastChar := byte(0)
263281
argPos := 0
264-
265-
for i := 0; i < len(query); i++ {
266-
q := strings.IndexByte(query[i:], '?')
267-
if q == -1 {
268-
buf = append(buf, query[i:]...)
269-
break
270-
}
271-
buf = append(buf, query[i:i+q]...)
272-
i += q
273-
274-
arg := args[argPos]
275-
argPos++
276-
277-
if arg == nil {
278-
buf = append(buf, "NULL"...)
282+
lenQuery := len(query)
283+
lastIdx := 0
284+
285+
for i := 0; i < lenQuery; i++ {
286+
currentChar := query[i]
287+
if state == stateEscape && !((currentChar == QUOTE_BYTE && singleQuotes) || (currentChar == DBL_QUOTE_BYTE && !singleQuotes)) {
288+
state = stateString
289+
lastChar = currentChar
279290
continue
280291
}
281-
282-
switch v := arg.(type) {
283-
case int64:
284-
buf = strconv.AppendInt(buf, v, 10)
285-
case uint64:
286-
// Handle uint64 explicitly because our custom ConvertValue emits unsigned values
287-
buf = strconv.AppendUint(buf, v, 10)
288-
case float64:
289-
buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
290-
case bool:
291-
if v {
292-
buf = append(buf, '1')
293-
} else {
294-
buf = append(buf, '0')
292+
switch currentChar {
293+
case STAR_BYTE:
294+
if state == stateNormal && lastChar == SLASH_BYTE {
295+
state = stateSlashStarComment
295296
}
296-
case time.Time:
297-
if v.IsZero() {
298-
buf = append(buf, "'0000-00-00'"...)
299-
} else {
300-
buf = append(buf, '\'')
301-
buf, err = appendDateTime(buf, v.In(mc.cfg.Loc), mc.cfg.timeTruncate)
302-
if err != nil {
303-
return "", err
304-
}
305-
buf = append(buf, '\'')
297+
case SLASH_BYTE:
298+
if state == stateSlashStarComment && lastChar == STAR_BYTE {
299+
state = stateNormal
306300
}
307-
case json.RawMessage:
308-
buf = append(buf, '\'')
309-
if mc.status&statusNoBackslashEscapes == 0 {
310-
buf = escapeBytesBackslash(buf, v)
311-
} else {
312-
buf = escapeBytesQuotes(buf, v)
301+
case HASH_BYTE:
302+
if state == stateNormal {
303+
state = stateEOLComment
313304
}
314-
buf = append(buf, '\'')
315-
case []byte:
316-
if v == nil {
317-
buf = append(buf, "NULL"...)
318-
} else {
319-
buf = append(buf, "_binary'"...)
320-
if mc.status&statusNoBackslashEscapes == 0 {
321-
buf = escapeBytesBackslash(buf, v)
322-
} else {
323-
buf = escapeBytesQuotes(buf, v)
324-
}
325-
buf = append(buf, '\'')
305+
case MINUS_BYTE:
306+
if state == stateNormal && lastChar == MINUS_BYTE {
307+
state = stateEOLComment
326308
}
327-
case string:
328-
buf = append(buf, '\'')
329-
if mc.status&statusNoBackslashEscapes == 0 {
330-
buf = escapeStringBackslash(buf, v)
331-
} else {
332-
buf = escapeStringQuotes(buf, v)
309+
case LINE_FEED_BYTE:
310+
if state == stateEOLComment {
311+
state = stateNormal
333312
}
334-
buf = append(buf, '\'')
335-
default:
336-
return "", driver.ErrSkip
337-
}
313+
case DBL_QUOTE_BYTE:
314+
if state == stateNormal {
315+
state = stateString
316+
singleQuotes = false
317+
} else if state == stateString && !singleQuotes {
318+
state = stateNormal
319+
} else if state == stateEscape {
320+
state = stateString
321+
}
322+
case QUOTE_BYTE:
323+
if state == stateNormal {
324+
state = stateString
325+
singleQuotes = true
326+
} else if state == stateString && singleQuotes {
327+
state = stateNormal
328+
} else if state == stateEscape {
329+
state = stateString
330+
}
331+
case BACKSLASH_BYTE:
332+
if state == stateString && !noBackslashEscapes {
333+
state = stateEscape
334+
}
335+
case QUESTION_MARK_BYTE:
336+
if state == stateNormal {
337+
if argPos >= len(args) {
338+
return "", driver.ErrSkip
339+
}
340+
buf = append(buf, query[lastIdx:i]...)
341+
arg := args[argPos]
342+
argPos++
343+
344+
if arg == nil {
345+
buf = append(buf, "NULL"...)
346+
lastIdx = i + 1
347+
break
348+
}
349+
350+
switch v := arg.(type) {
351+
case int64:
352+
buf = strconv.AppendInt(buf, v, 10)
353+
case uint64:
354+
buf = strconv.AppendUint(buf, v, 10)
355+
case float64:
356+
buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
357+
case bool:
358+
if v {
359+
buf = append(buf, '1')
360+
} else {
361+
buf = append(buf, '0')
362+
}
363+
case time.Time:
364+
if v.IsZero() {
365+
buf = append(buf, "'0000-00-00'"...)
366+
} else {
367+
buf = append(buf, '\'')
368+
buf, err = appendDateTime(buf, v.In(mc.cfg.Loc), mc.cfg.timeTruncate)
369+
if err != nil {
370+
return "", err
371+
}
372+
buf = append(buf, '\'')
373+
}
374+
case json.RawMessage:
375+
if noBackslashEscapes {
376+
buf = escapeBytesQuotes(buf, v, false)
377+
} else {
378+
buf = escapeBytesBackslash(buf, v, false)
379+
}
380+
case []byte:
381+
if v == nil {
382+
buf = append(buf, "NULL"...)
383+
} else {
384+
if noBackslashEscapes {
385+
buf = escapeBytesQuotes(buf, v, true)
386+
} else {
387+
buf = escapeBytesBackslash(buf, v, true)
388+
}
389+
}
390+
case string:
391+
if noBackslashEscapes {
392+
buf = escapeStringQuotes(buf, v)
393+
} else {
394+
buf = escapeStringBackslash(buf, v)
395+
}
396+
default:
397+
return "", driver.ErrSkip
398+
}
338399

339-
if len(buf)+4 > mc.maxAllowedPacket {
340-
return "", driver.ErrSkip
400+
if len(buf)+4 > mc.maxAllowedPacket {
401+
return "", driver.ErrSkip
402+
}
403+
lastIdx = i + 1
404+
}
405+
case RADICAL_BYTE:
406+
if state == stateBacktick {
407+
state = stateNormal
408+
} else if state == stateNormal {
409+
state = stateBacktick
410+
}
341411
}
412+
lastChar = currentChar
342413
}
414+
buf = append(buf, query[lastIdx:]...)
343415
if argPos != len(args) {
344416
return "", driver.ErrSkip
345417
}

connection_test.go

Lines changed: 52 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -79,24 +79,6 @@ func TestInterpolateParamsTooManyPlaceholders(t *testing.T) {
7979
}
8080
}
8181

82-
// We don't support placeholder in string literal for now.
83-
// https://github.com/go-sql-driver/mysql/pull/490
84-
func TestInterpolateParamsPlaceholderInString(t *testing.T) {
85-
mc := &mysqlConn{
86-
buf: newBuffer(),
87-
maxAllowedPacket: maxPacketSize,
88-
cfg: &Config{
89-
InterpolateParams: true,
90-
},
91-
}
92-
93-
q, err := mc.interpolateParams("SELECT 'abc?xyz',?", []driver.Value{int64(42)})
94-
// When InterpolateParams support string literal, this should return `"SELECT 'abc?xyz', 42`
95-
if err != driver.ErrSkip {
96-
t.Errorf("Expected err=driver.ErrSkip, got err=%#v, q=%#v", err, q)
97-
}
98-
}
99-
10082
func TestInterpolateParamsUint64(t *testing.T) {
10183
mc := &mysqlConn{
10284
buf: newBuffer(),
@@ -204,3 +186,55 @@ func (bc badConnection) Write(b []byte) (n int, err error) {
204186
func (bc badConnection) Close() error {
205187
return nil
206188
}
189+
190+
func TestInterpolateParamsWithComments(t *testing.T) {
191+
mc := &mysqlConn{
192+
buf: newBuffer(),
193+
maxAllowedPacket: maxPacketSize,
194+
cfg: &Config{
195+
InterpolateParams: true,
196+
},
197+
}
198+
199+
tests := []struct {
200+
query string
201+
args []driver.Value
202+
expected string
203+
shouldSkip bool
204+
}{
205+
// ? in single-line comment (--) should not be replaced
206+
{"SELECT 1 -- ?\n, ?", []driver.Value{int64(42)}, "SELECT 1 -- ?\n, 42", false},
207+
// ? in single-line comment (#) should not be replaced
208+
{"SELECT 1 # ?\n, ?", []driver.Value{int64(42)}, "SELECT 1 # ?\n, 42", false},
209+
// ? in multi-line comment should not be replaced
210+
{"SELECT /* ? */ ?", []driver.Value{int64(42)}, "SELECT /* ? */ 42", false},
211+
// ? in string literal should not be replaced
212+
{"SELECT '?', ?", []driver.Value{int64(42)}, "SELECT '?', 42", false},
213+
// ? in backtick identifier should not be replaced
214+
{"SELECT `?`, ?", []driver.Value{int64(42)}, "SELECT `?`, 42", false},
215+
// ? in backslash-escaped string literal should not be replaced
216+
{"SELECT 'C:\\path\\?x.txt', ?", []driver.Value{int64(42)}, "SELECT 'C:\\path\\?x.txt', 42", false},
217+
// ? in backslash-escaped string literal should not be replaced
218+
{"SELECT '\\'?', col FROM tbl WHERE id = ? AND desc = 'foo\\'bar?'", []driver.Value{int64(42)}, "SELECT '\\'?', col FROM tbl WHERE id = 42 AND desc = 'foo\\'bar?'", false},
219+
// Multiple comments and real placeholders
220+
{"SELECT ? -- comment ?\n, ? /* ? */ , ? # ?\n, ?", []driver.Value{int64(1), int64(2), int64(3)}, "SELECT 1 -- comment ?\n, 2 /* ? */ , 3 # ?\n, ?", true},
221+
}
222+
223+
for i, test := range tests {
224+
225+
q, err := mc.interpolateParams(test.query, test.args)
226+
if test.shouldSkip {
227+
if err != driver.ErrSkip {
228+
t.Errorf("Test %d: Expected driver.ErrSkip, got err=%#v, q=%#v", i, err, q)
229+
}
230+
continue
231+
}
232+
if err != nil {
233+
t.Errorf("Test %d: Expected err=nil, got %#v", i, err)
234+
continue
235+
}
236+
if q != test.expected {
237+
t.Errorf("Test %d: Expected: %q\nGot: %q", i, test.expected, q)
238+
}
239+
}
240+
}

0 commit comments

Comments
 (0)