@@ -172,7 +172,7 @@ func (mc *mysqlConn) close() {
172
172
}
173
173
174
174
// Closes the network connection and unsets internal variables. Do not call this
175
- // function after successfully authentication, call Close instead. This function
175
+ // function after successful authentication, call Close instead. This function
176
176
// is called before auth or on auth failure because MySQL will have already
177
177
// closed the network connection.
178
178
func (mc * mysqlConn ) cleanup () {
@@ -246,100 +246,172 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
246
246
}
247
247
248
248
func (mc * mysqlConn ) interpolateParams (query string , args []driver.Value ) (string , error ) {
249
- // Number of ? should be same to len(args)
250
- if strings .Count (query , "?" ) != len (args ) {
251
- return "" , driver .ErrSkip
252
- }
249
+ noBackslashEscapes := (mc .status & statusNoBackslashEscapes ) != 0
250
+ const (
251
+ stateNormal = iota
252
+ stateString
253
+ stateEscape
254
+ stateEOLComment
255
+ stateSlashStarComment
256
+ stateBacktick
257
+ )
258
+
259
+ var (
260
+ QUOTE_BYTE = byte ('\'' )
261
+ DBL_QUOTE_BYTE = byte ('"' )
262
+ BACKSLASH_BYTE = byte ('\\' )
263
+ QUESTION_MARK_BYTE = byte ('?' )
264
+ SLASH_BYTE = byte ('/' )
265
+ STAR_BYTE = byte ('*' )
266
+ HASH_BYTE = byte ('#' )
267
+ MINUS_BYTE = byte ('-' )
268
+ LINE_FEED_BYTE = byte ('\n' )
269
+ RADICAL_BYTE = byte ('`' )
270
+ )
253
271
254
272
buf , err := mc .buf .takeCompleteBuffer ()
255
273
if err != nil {
256
- // can not take the buffer. Something must be wrong with the connection
257
274
mc .cleanup ()
258
- // interpolateParams would be called before sending any query.
259
- // So its safe to retry.
260
275
return "" , driver .ErrBadConn
261
276
}
262
277
buf = buf [:0 ]
278
+ state := stateNormal
279
+ singleQuotes := false
280
+ lastChar := byte (0 )
263
281
argPos := 0
264
-
265
- for i := 0 ; i < len (query ); i ++ {
266
- q := strings .IndexByte (query [i :], '?' )
267
- if q == - 1 {
268
- buf = append (buf , query [i :]... )
269
- break
270
- }
271
- buf = append (buf , query [i :i + q ]... )
272
- i += q
273
-
274
- arg := args [argPos ]
275
- argPos ++
276
-
277
- if arg == nil {
278
- buf = append (buf , "NULL" ... )
282
+ lenQuery := len (query )
283
+ lastIdx := 0
284
+
285
+ for i := 0 ; i < lenQuery ; i ++ {
286
+ currentChar := query [i ]
287
+ if state == stateEscape && ! ((currentChar == QUOTE_BYTE && singleQuotes ) || (currentChar == DBL_QUOTE_BYTE && ! singleQuotes )) {
288
+ state = stateString
289
+ lastChar = currentChar
279
290
continue
280
291
}
281
-
282
- switch v := arg .(type ) {
283
- case int64 :
284
- buf = strconv .AppendInt (buf , v , 10 )
285
- case uint64 :
286
- // Handle uint64 explicitly because our custom ConvertValue emits unsigned values
287
- buf = strconv .AppendUint (buf , v , 10 )
288
- case float64 :
289
- buf = strconv .AppendFloat (buf , v , 'g' , - 1 , 64 )
290
- case bool :
291
- if v {
292
- buf = append (buf , '1' )
293
- } else {
294
- buf = append (buf , '0' )
292
+ switch currentChar {
293
+ case STAR_BYTE :
294
+ if state == stateNormal && lastChar == SLASH_BYTE {
295
+ state = stateSlashStarComment
295
296
}
296
- case time.Time :
297
- if v .IsZero () {
298
- buf = append (buf , "'0000-00-00'" ... )
299
- } else {
300
- buf = append (buf , '\'' )
301
- buf , err = appendDateTime (buf , v .In (mc .cfg .Loc ), mc .cfg .timeTruncate )
302
- if err != nil {
303
- return "" , err
304
- }
305
- buf = append (buf , '\'' )
297
+ case SLASH_BYTE :
298
+ if state == stateSlashStarComment && lastChar == STAR_BYTE {
299
+ state = stateNormal
306
300
}
307
- case json.RawMessage :
308
- buf = append (buf , '\'' )
309
- if mc .status & statusNoBackslashEscapes == 0 {
310
- buf = escapeBytesBackslash (buf , v )
311
- } else {
312
- buf = escapeBytesQuotes (buf , v )
301
+ case HASH_BYTE :
302
+ if state == stateNormal {
303
+ state = stateEOLComment
313
304
}
314
- buf = append (buf , '\'' )
315
- case []byte :
316
- if v == nil {
317
- buf = append (buf , "NULL" ... )
318
- } else {
319
- buf = append (buf , "_binary'" ... )
320
- if mc .status & statusNoBackslashEscapes == 0 {
321
- buf = escapeBytesBackslash (buf , v )
322
- } else {
323
- buf = escapeBytesQuotes (buf , v )
324
- }
325
- buf = append (buf , '\'' )
305
+ case MINUS_BYTE :
306
+ if state == stateNormal && lastChar == MINUS_BYTE {
307
+ state = stateEOLComment
326
308
}
327
- case string :
328
- buf = append (buf , '\'' )
329
- if mc .status & statusNoBackslashEscapes == 0 {
330
- buf = escapeStringBackslash (buf , v )
331
- } else {
332
- buf = escapeStringQuotes (buf , v )
309
+ case LINE_FEED_BYTE :
310
+ if state == stateEOLComment {
311
+ state = stateNormal
333
312
}
334
- buf = append (buf , '\'' )
335
- default :
336
- return "" , driver .ErrSkip
337
- }
313
+ case DBL_QUOTE_BYTE :
314
+ if state == stateNormal {
315
+ state = stateString
316
+ singleQuotes = false
317
+ } else if state == stateString && ! singleQuotes {
318
+ state = stateNormal
319
+ } else if state == stateEscape {
320
+ state = stateString
321
+ }
322
+ case QUOTE_BYTE :
323
+ if state == stateNormal {
324
+ state = stateString
325
+ singleQuotes = true
326
+ } else if state == stateString && singleQuotes {
327
+ state = stateNormal
328
+ } else if state == stateEscape {
329
+ state = stateString
330
+ }
331
+ case BACKSLASH_BYTE :
332
+ if state == stateString && ! noBackslashEscapes {
333
+ state = stateEscape
334
+ }
335
+ case QUESTION_MARK_BYTE :
336
+ if state == stateNormal {
337
+ if argPos >= len (args ) {
338
+ return "" , driver .ErrSkip
339
+ }
340
+ buf = append (buf , query [lastIdx :i ]... )
341
+ arg := args [argPos ]
342
+ argPos ++
343
+
344
+ if arg == nil {
345
+ buf = append (buf , "NULL" ... )
346
+ lastIdx = i + 1
347
+ break
348
+ }
349
+
350
+ switch v := arg .(type ) {
351
+ case int64 :
352
+ buf = strconv .AppendInt (buf , v , 10 )
353
+ case uint64 :
354
+ buf = strconv .AppendUint (buf , v , 10 )
355
+ case float64 :
356
+ buf = strconv .AppendFloat (buf , v , 'g' , - 1 , 64 )
357
+ case bool :
358
+ if v {
359
+ buf = append (buf , '1' )
360
+ } else {
361
+ buf = append (buf , '0' )
362
+ }
363
+ case time.Time :
364
+ if v .IsZero () {
365
+ buf = append (buf , "'0000-00-00'" ... )
366
+ } else {
367
+ buf = append (buf , '\'' )
368
+ buf , err = appendDateTime (buf , v .In (mc .cfg .Loc ), mc .cfg .timeTruncate )
369
+ if err != nil {
370
+ return "" , err
371
+ }
372
+ buf = append (buf , '\'' )
373
+ }
374
+ case json.RawMessage :
375
+ if noBackslashEscapes {
376
+ buf = escapeBytesQuotes (buf , v , false )
377
+ } else {
378
+ buf = escapeBytesBackslash (buf , v , false )
379
+ }
380
+ case []byte :
381
+ if v == nil {
382
+ buf = append (buf , "NULL" ... )
383
+ } else {
384
+ if noBackslashEscapes {
385
+ buf = escapeBytesQuotes (buf , v , true )
386
+ } else {
387
+ buf = escapeBytesBackslash (buf , v , true )
388
+ }
389
+ }
390
+ case string :
391
+ if noBackslashEscapes {
392
+ buf = escapeStringQuotes (buf , v )
393
+ } else {
394
+ buf = escapeStringBackslash (buf , v )
395
+ }
396
+ default :
397
+ return "" , driver .ErrSkip
398
+ }
338
399
339
- if len (buf )+ 4 > mc .maxAllowedPacket {
340
- return "" , driver .ErrSkip
400
+ if len (buf )+ 4 > mc .maxAllowedPacket {
401
+ return "" , driver .ErrSkip
402
+ }
403
+ lastIdx = i + 1
404
+ }
405
+ case RADICAL_BYTE :
406
+ if state == stateBacktick {
407
+ state = stateNormal
408
+ } else if state == stateNormal {
409
+ state = stateBacktick
410
+ }
341
411
}
412
+ lastChar = currentChar
342
413
}
414
+ buf = append (buf , query [lastIdx :]... )
343
415
if argPos != len (args ) {
344
416
return "" , driver .ErrSkip
345
417
}
0 commit comments