Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 66edfad

Browse files
committed
nhance interpolateParams to correctly handle placeholders in queries with comments, strings, and backticks.
* Add `findParamPositions` to identify real parameter positions * Update and expand related tests.
1 parent 76c00e3 commit 66edfad

File tree

4 files changed

+313
-197
lines changed

4 files changed

+313
-197
lines changed

‎connection.go

Lines changed: 149 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ func (mc *mysqlConn) close() {
172172
}
173173

174174
// Closes the network connection and unsets internal variables. Do not call this
175-
// function after successfully authentication, call Close instead. This function
175+
// function after successful authentication, call Close instead. This function
176176
// is called before auth or on auth failure because MySQL will have already
177177
// closed the network connection.
178178
func (mc *mysqlConn) cleanup() {
@@ -246,100 +246,172 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
246246
}
247247

248248
func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
249-
// Number of ? should be same to len(args)
250-
if strings.Count(query, "?") != len(args) {
251-
return "", driver.ErrSkip
252-
}
249+
noBackslashEscapes := (mc.status & statusNoBackslashEscapes) != 0
250+
const (
251+
stateNormal = iota
252+
stateString
253+
stateEscape
254+
stateEOLComment
255+
stateSlashStarComment
256+
stateBacktick
257+
)
258+
259+
var (
260+
QUOTE_BYTE = byte('\'')
261+
DBL_QUOTE_BYTE = byte('"')
262+
BACKSLASH_BYTE = byte('\\')
263+
QUESTION_MARK_BYTE = byte('?')
264+
SLASH_BYTE = byte('/')
265+
STAR_BYTE = byte('*')
266+
HASH_BYTE = byte('#')
267+
MINUS_BYTE = byte('-')
268+
LINE_FEED_BYTE = byte('\n')
269+
RADICAL_BYTE = byte('`')
270+
)
253271

254272
buf, err := mc.buf.takeCompleteBuffer()
255273
if err != nil {
256-
// can not take the buffer. Something must be wrong with the connection
257274
mc.cleanup()
258-
// interpolateParams would be called before sending any query.
259-
// So its safe to retry.
260275
return "", driver.ErrBadConn
261276
}
262277
buf = buf[:0]
278+
state := stateNormal
279+
singleQuotes := false
280+
lastChar := byte(0)
263281
argPos := 0
264-
265-
for i := 0; i < len(query); i++ {
266-
q := strings.IndexByte(query[i:], '?')
267-
if q == -1 {
268-
buf = append(buf, query[i:]...)
269-
break
270-
}
271-
buf = append(buf, query[i:i+q]...)
272-
i += q
273-
274-
arg := args[argPos]
275-
argPos++
276-
277-
if arg == nil {
278-
buf = append(buf, "NULL"...)
282+
lenQuery := len(query)
283+
lastIdx := 0
284+
285+
for i := 0; i < lenQuery; i++ {
286+
currentChar := query[i]
287+
if state == stateEscape && !((currentChar == QUOTE_BYTE && singleQuotes) || (currentChar == DBL_QUOTE_BYTE && !singleQuotes)) {
288+
state = stateString
289+
lastChar = currentChar
279290
continue
280291
}
281-
282-
switch v := arg.(type) {
283-
case int64:
284-
buf = strconv.AppendInt(buf, v, 10)
285-
case uint64:
286-
// Handle uint64 explicitly because our custom ConvertValue emits unsigned values
287-
buf = strconv.AppendUint(buf, v, 10)
288-
case float64:
289-
buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
290-
case bool:
291-
if v {
292-
buf = append(buf, '1')
293-
} else {
294-
buf = append(buf, '0')
292+
switch currentChar {
293+
case STAR_BYTE:
294+
if state == stateNormal && lastChar == SLASH_BYTE {
295+
state = stateSlashStarComment
295296
}
296-
case time.Time:
297-
if v.IsZero() {
298-
buf = append(buf, "'0000-00-00'"...)
299-
} else {
300-
buf = append(buf, '\'')
301-
buf, err = appendDateTime(buf, v.In(mc.cfg.Loc), mc.cfg.timeTruncate)
302-
if err != nil {
303-
return "", err
304-
}
305-
buf = append(buf, '\'')
297+
case SLASH_BYTE:
298+
if state == stateSlashStarComment && lastChar == STAR_BYTE {
299+
state = stateNormal
306300
}
307-
case json.RawMessage:
308-
buf = append(buf, '\'')
309-
if mc.status&statusNoBackslashEscapes == 0 {
310-
buf = escapeBytesBackslash(buf, v)
311-
} else {
312-
buf = escapeBytesQuotes(buf, v)
301+
case HASH_BYTE:
302+
if state == stateNormal {
303+
state = stateEOLComment
313304
}
314-
buf = append(buf, '\'')
315-
case []byte:
316-
if v == nil {
317-
buf = append(buf, "NULL"...)
318-
} else {
319-
buf = append(buf, "_binary'"...)
320-
if mc.status&statusNoBackslashEscapes == 0 {
321-
buf = escapeBytesBackslash(buf, v)
322-
} else {
323-
buf = escapeBytesQuotes(buf, v)
324-
}
325-
buf = append(buf, '\'')
305+
case MINUS_BYTE:
306+
if state == stateNormal && lastChar == MINUS_BYTE {
307+
state = stateEOLComment
326308
}
327-
case string:
328-
buf = append(buf, '\'')
329-
if mc.status&statusNoBackslashEscapes == 0 {
330-
buf = escapeStringBackslash(buf, v)
331-
} else {
332-
buf = escapeStringQuotes(buf, v)
309+
case LINE_FEED_BYTE:
310+
if state == stateEOLComment {
311+
state = stateNormal
333312
}
334-
buf = append(buf, '\'')
335-
default:
336-
return "", driver.ErrSkip
337-
}
313+
case DBL_QUOTE_BYTE:
314+
if state == stateNormal {
315+
state = stateString
316+
singleQuotes = false
317+
} else if state == stateString && !singleQuotes {
318+
state = stateNormal
319+
} else if state == stateEscape {
320+
state = stateString
321+
}
322+
case QUOTE_BYTE:
323+
if state == stateNormal {
324+
state = stateString
325+
singleQuotes = true
326+
} else if state == stateString && singleQuotes {
327+
state = stateNormal
328+
} else if state == stateEscape {
329+
state = stateString
330+
}
331+
case BACKSLASH_BYTE:
332+
if state == stateString && !noBackslashEscapes {
333+
state = stateEscape
334+
}
335+
case QUESTION_MARK_BYTE:
336+
if state == stateNormal {
337+
if argPos >= len(args) {
338+
return "", driver.ErrSkip
339+
}
340+
buf = append(buf, query[lastIdx:i]...)
341+
arg := args[argPos]
342+
argPos++
343+
344+
if arg == nil {
345+
buf = append(buf, "NULL"...)
346+
lastIdx = i + 1
347+
break
348+
}
349+
350+
switch v := arg.(type) {
351+
case int64:
352+
buf = strconv.AppendInt(buf, v, 10)
353+
case uint64:
354+
buf = strconv.AppendUint(buf, v, 10)
355+
case float64:
356+
buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
357+
case bool:
358+
if v {
359+
buf = append(buf, '1')
360+
} else {
361+
buf = append(buf, '0')
362+
}
363+
case time.Time:
364+
if v.IsZero() {
365+
buf = append(buf, "'0000-00-00'"...)
366+
} else {
367+
buf = append(buf, '\'')
368+
buf, err = appendDateTime(buf, v.In(mc.cfg.Loc), mc.cfg.timeTruncate)
369+
if err != nil {
370+
return "", err
371+
}
372+
buf = append(buf, '\'')
373+
}
374+
case json.RawMessage:
375+
if noBackslashEscapes {
376+
buf = escapeBytesQuotes(buf, v, false)
377+
} else {
378+
buf = escapeBytesBackslash(buf, v, false)
379+
}
380+
case []byte:
381+
if v == nil {
382+
buf = append(buf, "NULL"...)
383+
} else {
384+
if noBackslashEscapes {
385+
buf = escapeBytesQuotes(buf, v, true)
386+
} else {
387+
buf = escapeBytesBackslash(buf, v, true)
388+
}
389+
}
390+
case string:
391+
if noBackslashEscapes {
392+
buf = escapeStringQuotes(buf, v)
393+
} else {
394+
buf = escapeStringBackslash(buf, v)
395+
}
396+
default:
397+
return "", driver.ErrSkip
398+
}
338399

339-
if len(buf)+4 > mc.maxAllowedPacket {
340-
return "", driver.ErrSkip
400+
if len(buf)+4 > mc.maxAllowedPacket {
401+
return "", driver.ErrSkip
402+
}
403+
lastIdx = i + 1
404+
}
405+
case RADICAL_BYTE:
406+
if state == stateBacktick {
407+
state = stateNormal
408+
} else if state == stateNormal {
409+
state = stateBacktick
410+
}
341411
}
412+
lastChar = currentChar
342413
}
414+
buf = append(buf, query[lastIdx:]...)
343415
if argPos != len(args) {
344416
return "", driver.ErrSkip
345417
}

‎connection_test.go

Lines changed: 52 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -79,24 +79,6 @@ func TestInterpolateParamsTooManyPlaceholders(t *testing.T) {
7979
}
8080
}
8181

82-
// We don't support placeholder in string literal for now.
83-
// https://github.com/go-sql-driver/mysql/pull/490
84-
func TestInterpolateParamsPlaceholderInString(t *testing.T) {
85-
mc := &mysqlConn{
86-
buf: newBuffer(),
87-
maxAllowedPacket: maxPacketSize,
88-
cfg: &Config{
89-
InterpolateParams: true,
90-
},
91-
}
92-
93-
q, err := mc.interpolateParams("SELECT 'abc?xyz',?", []driver.Value{int64(42)})
94-
// When InterpolateParams support string literal, this should return `"SELECT 'abc?xyz', 42`
95-
if err != driver.ErrSkip {
96-
t.Errorf("Expected err=driver.ErrSkip, got err=%#v, q=%#v", err, q)
97-
}
98-
}
99-
10082
func TestInterpolateParamsUint64(t *testing.T) {
10183
mc := &mysqlConn{
10284
buf: newBuffer(),
@@ -204,3 +186,55 @@ func (bc badConnection) Write(b []byte) (n int, err error) {
204186
func (bc badConnection) Close() error {
205187
return nil
206188
}
189+
190+
func TestInterpolateParamsWithComments(t *testing.T) {
191+
mc := &mysqlConn{
192+
buf: newBuffer(),
193+
maxAllowedPacket: maxPacketSize,
194+
cfg: &Config{
195+
InterpolateParams: true,
196+
},
197+
}
198+
199+
tests := []struct {
200+
query string
201+
args []driver.Value
202+
expected string
203+
shouldSkip bool
204+
}{
205+
// ? in single-line comment (--) should not be replaced
206+
{"SELECT 1 -- ?\n, ?", []driver.Value{int64(42)}, "SELECT 1 -- ?\n, 42", false},
207+
// ? in single-line comment (#) should not be replaced
208+
{"SELECT 1 # ?\n, ?", []driver.Value{int64(42)}, "SELECT 1 # ?\n, 42", false},
209+
// ? in multi-line comment should not be replaced
210+
{"SELECT /* ? */ ?", []driver.Value{int64(42)}, "SELECT /* ? */ 42", false},
211+
// ? in string literal should not be replaced
212+
{"SELECT '?', ?", []driver.Value{int64(42)}, "SELECT '?', 42", false},
213+
// ? in backtick identifier should not be replaced
214+
{"SELECT `?`, ?", []driver.Value{int64(42)}, "SELECT `?`, 42", false},
215+
// ? in backslash-escaped string literal should not be replaced
216+
{"SELECT 'C:\\path\\?x.txt', ?", []driver.Value{int64(42)}, "SELECT 'C:\\path\\?x.txt', 42", false},
217+
// ? in backslash-escaped string literal should not be replaced
218+
{"SELECT '\\'?', col FROM tbl WHERE id = ? AND desc = 'foo\\'bar?'", []driver.Value{int64(42)}, "SELECT '\\'?', col FROM tbl WHERE id = 42 AND desc = 'foo\\'bar?'", false},
219+
// Multiple comments and real placeholders
220+
{"SELECT ? -- comment ?\n, ? /* ? */ , ? # ?\n, ?", []driver.Value{int64(1), int64(2), int64(3)}, "SELECT 1 -- comment ?\n, 2 /* ? */ , 3 # ?\n, ?", true},
221+
}
222+
223+
for i, test := range tests {
224+
225+
q, err := mc.interpolateParams(test.query, test.args)
226+
if test.shouldSkip {
227+
if err != driver.ErrSkip {
228+
t.Errorf("Test %d: Expected driver.ErrSkip, got err=%#v, q=%#v", i, err, q)
229+
}
230+
continue
231+
}
232+
if err != nil {
233+
t.Errorf("Test %d: Expected err=nil, got %#v", i, err)
234+
continue
235+
}
236+
if q != test.expected {
237+
t.Errorf("Test %d: Expected: %q\nGot: %q", i, test.expected, q)
238+
}
239+
}
240+
}

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /