@@ -172,7 +172,7 @@ func (mc *mysqlConn) close() {
172172}
173173
174174// Closes the network connection and unsets internal variables. Do not call this
175- // function after successfully authentication, call Close instead. This function
175+ // function after successful authentication, call Close instead. This function
176176// is called before auth or on auth failure because MySQL will have already
177177// closed the network connection.
178178func (mc * mysqlConn ) cleanup () {
@@ -246,100 +246,172 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
246246}
247247
248248func (mc * mysqlConn ) interpolateParams (query string , args []driver.Value ) (string , error ) {
249- // Number of ? should be same to len(args)
250- if strings .Count (query , "?" ) != len (args ) {
251- return "" , driver .ErrSkip
252- }
249+ noBackslashEscapes := (mc .status & statusNoBackslashEscapes ) != 0
250+ const (
251+ stateNormal = iota
252+ stateString
253+ stateEscape
254+ stateEOLComment
255+ stateSlashStarComment
256+ stateBacktick
257+ )
258+ 259+ var (
260+ QUOTE_BYTE = byte ('\'' )
261+ DBL_QUOTE_BYTE = byte ('"' )
262+ BACKSLASH_BYTE = byte ('\\' )
263+ QUESTION_MARK_BYTE = byte ('?' )
264+ SLASH_BYTE = byte ('/' )
265+ STAR_BYTE = byte ('*' )
266+ HASH_BYTE = byte ('#' )
267+ MINUS_BYTE = byte ('-' )
268+ LINE_FEED_BYTE = byte ('\n' )
269+ RADICAL_BYTE = byte ('`' )
270+ )
253271
254272 buf , err := mc .buf .takeCompleteBuffer ()
255273 if err != nil {
256- // can not take the buffer. Something must be wrong with the connection
257274 mc .cleanup ()
258- // interpolateParams would be called before sending any query.
259- // So its safe to retry.
260275 return "" , driver .ErrBadConn
261276 }
262277 buf = buf [:0 ]
278+ state := stateNormal
279+ singleQuotes := false
280+ lastChar := byte (0 )
263281 argPos := 0
264- 265- for i := 0 ; i < len (query ); i ++ {
266- q := strings .IndexByte (query [i :], '?' )
267- if q == - 1 {
268- buf = append (buf , query [i :]... )
269- break
270- }
271- buf = append (buf , query [i :i + q ]... )
272- i += q
273- 274- arg := args [argPos ]
275- argPos ++
276- 277- if arg == nil {
278- buf = append (buf , "NULL" ... )
282+ lenQuery := len (query )
283+ lastIdx := 0
284+ 285+ for i := 0 ; i < lenQuery ; i ++ {
286+ currentChar := query [i ]
287+ if state == stateEscape && ! ((currentChar == QUOTE_BYTE && singleQuotes ) || (currentChar == DBL_QUOTE_BYTE && ! singleQuotes )) {
288+ state = stateString
289+ lastChar = currentChar
279290 continue
280291 }
281- 282- switch v := arg .(type ) {
283- case int64 :
284- buf = strconv .AppendInt (buf , v , 10 )
285- case uint64 :
286- // Handle uint64 explicitly because our custom ConvertValue emits unsigned values
287- buf = strconv .AppendUint (buf , v , 10 )
288- case float64 :
289- buf = strconv .AppendFloat (buf , v , 'g' , - 1 , 64 )
290- case bool :
291- if v {
292- buf = append (buf , '1' )
293- } else {
294- buf = append (buf , '0' )
292+ switch currentChar {
293+ case STAR_BYTE :
294+ if state == stateNormal && lastChar == SLASH_BYTE {
295+ state = stateSlashStarComment
295296 }
296- case time.Time :
297- if v .IsZero () {
298- buf = append (buf , "'0000-00-00'" ... )
299- } else {
300- buf = append (buf , '\'' )
301- buf , err = appendDateTime (buf , v .In (mc .cfg .Loc ), mc .cfg .timeTruncate )
302- if err != nil {
303- return "" , err
304- }
305- buf = append (buf , '\'' )
297+ case SLASH_BYTE :
298+ if state == stateSlashStarComment && lastChar == STAR_BYTE {
299+ state = stateNormal
306300 }
307- case json.RawMessage :
308- buf = append (buf , '\'' )
309- if mc .status & statusNoBackslashEscapes == 0 {
310- buf = escapeBytesBackslash (buf , v )
311- } else {
312- buf = escapeBytesQuotes (buf , v )
301+ case HASH_BYTE :
302+ if state == stateNormal {
303+ state = stateEOLComment
313304 }
314- buf = append (buf , '\'' )
315- case []byte :
316- if v == nil {
317- buf = append (buf , "NULL" ... )
318- } else {
319- buf = append (buf , "_binary'" ... )
320- if mc .status & statusNoBackslashEscapes == 0 {
321- buf = escapeBytesBackslash (buf , v )
322- } else {
323- buf = escapeBytesQuotes (buf , v )
324- }
325- buf = append (buf , '\'' )
305+ case MINUS_BYTE :
306+ if state == stateNormal && lastChar == MINUS_BYTE {
307+ state = stateEOLComment
326308 }
327- case string :
328- buf = append (buf , '\'' )
329- if mc .status & statusNoBackslashEscapes == 0 {
330- buf = escapeStringBackslash (buf , v )
331- } else {
332- buf = escapeStringQuotes (buf , v )
309+ case LINE_FEED_BYTE :
310+ if state == stateEOLComment {
311+ state = stateNormal
333312 }
334- buf = append (buf , '\'' )
335- default :
336- return "" , driver .ErrSkip
337- }
313+ case DBL_QUOTE_BYTE :
314+ if state == stateNormal {
315+ state = stateString
316+ singleQuotes = false
317+ } else if state == stateString && ! singleQuotes {
318+ state = stateNormal
319+ } else if state == stateEscape {
320+ state = stateString
321+ }
322+ case QUOTE_BYTE :
323+ if state == stateNormal {
324+ state = stateString
325+ singleQuotes = true
326+ } else if state == stateString && singleQuotes {
327+ state = stateNormal
328+ } else if state == stateEscape {
329+ state = stateString
330+ }
331+ case BACKSLASH_BYTE :
332+ if state == stateString && ! noBackslashEscapes {
333+ state = stateEscape
334+ }
335+ case QUESTION_MARK_BYTE :
336+ if state == stateNormal {
337+ if argPos >= len (args ) {
338+ return "" , driver .ErrSkip
339+ }
340+ buf = append (buf , query [lastIdx :i ]... )
341+ arg := args [argPos ]
342+ argPos ++
343+ 344+ if arg == nil {
345+ buf = append (buf , "NULL" ... )
346+ lastIdx = i + 1
347+ break
348+ }
349+ 350+ switch v := arg .(type ) {
351+ case int64 :
352+ buf = strconv .AppendInt (buf , v , 10 )
353+ case uint64 :
354+ buf = strconv .AppendUint (buf , v , 10 )
355+ case float64 :
356+ buf = strconv .AppendFloat (buf , v , 'g' , - 1 , 64 )
357+ case bool :
358+ if v {
359+ buf = append (buf , '1' )
360+ } else {
361+ buf = append (buf , '0' )
362+ }
363+ case time.Time :
364+ if v .IsZero () {
365+ buf = append (buf , "'0000-00-00'" ... )
366+ } else {
367+ buf = append (buf , '\'' )
368+ buf , err = appendDateTime (buf , v .In (mc .cfg .Loc ), mc .cfg .timeTruncate )
369+ if err != nil {
370+ return "" , err
371+ }
372+ buf = append (buf , '\'' )
373+ }
374+ case json.RawMessage :
375+ if noBackslashEscapes {
376+ buf = escapeBytesQuotes (buf , v , false )
377+ } else {
378+ buf = escapeBytesBackslash (buf , v , false )
379+ }
380+ case []byte :
381+ if v == nil {
382+ buf = append (buf , "NULL" ... )
383+ } else {
384+ if noBackslashEscapes {
385+ buf = escapeBytesQuotes (buf , v , true )
386+ } else {
387+ buf = escapeBytesBackslash (buf , v , true )
388+ }
389+ }
390+ case string :
391+ if noBackslashEscapes {
392+ buf = escapeStringQuotes (buf , v )
393+ } else {
394+ buf = escapeStringBackslash (buf , v )
395+ }
396+ default :
397+ return "" , driver .ErrSkip
398+ }
338399
339- if len (buf )+ 4 > mc .maxAllowedPacket {
340- return "" , driver .ErrSkip
400+ if len (buf )+ 4 > mc .maxAllowedPacket {
401+ return "" , driver .ErrSkip
402+ }
403+ lastIdx = i + 1
404+ }
405+ case RADICAL_BYTE :
406+ if state == stateBacktick {
407+ state = stateNormal
408+ } else if state == stateNormal {
409+ state = stateBacktick
410+ }
341411 }
412+ lastChar = currentChar
342413 }
414+ buf = append (buf , query [lastIdx :]... )
343415 if argPos != len (args ) {
344416 return "" , driver .ErrSkip
345417 }
0 commit comments