Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 883c78c

Browse files
committed
Enhance interpolateParams to correctly handle placeholders in queries with comments, strings, and backticks.
* Add `findParamPositions` to identify real parameter positions * Update and expand related tests.
1 parent 76c00e3 commit 883c78c

File tree

2 files changed

+154
-30
lines changed

2 files changed

+154
-30
lines changed

‎connection.go

Lines changed: 106 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ func (mc *mysqlConn) close() {
172172
}
173173

174174
// Closes the network connection and unsets internal variables. Do not call this
175-
// function after successfully authentication, call Close instead. This function
175+
// function after successful authentication, call Close instead. This function
176176
// is called before auth or on auth failure because MySQL will have already
177177
// closed the network connection.
178178
func (mc *mysqlConn) cleanup() {
@@ -245,9 +245,106 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
245245
return stmt, err
246246
}
247247

248+
// findParamPositions returns the positions of real parameter holders ('?') in the query, ignoring those in comments, strings, or backticks.
249+
func findParamPositions(query string) []int {
250+
const (
251+
stateNormal = iota
252+
stateString
253+
stateEscape
254+
stateEOLComment
255+
stateSlashStarComment
256+
stateBacktick
257+
)
258+
259+
var (
260+
QUOTE_BYTE = byte('\'')
261+
DBL_QUOTE_BYTE = byte('"')
262+
BACKSLASH_BYTE = byte('\\')
263+
QUESTION_MARK_BYTE = byte('?')
264+
SLASH_BYTE = byte('/')
265+
STAR_BYTE = byte('*')
266+
HASH_BYTE = byte('#')
267+
MINUS_BYTE = byte('-')
268+
LINE_FEED_BYTE = byte('\n')
269+
RADICAL_BYTE = byte('`')
270+
)
271+
272+
paramPositions := make([]int, 0)
273+
state := stateNormal
274+
singleQuotes := false
275+
lastChar := byte(0)
276+
lenq := len(query)
277+
for i := 0; i < lenq; i++ {
278+
currentChar := query[i]
279+
if state == stateEscape && !((currentChar == QUOTE_BYTE && singleQuotes) || (currentChar == DBL_QUOTE_BYTE && !singleQuotes)) {
280+
state = stateString
281+
lastChar = currentChar
282+
continue
283+
}
284+
switch currentChar {
285+
case STAR_BYTE:
286+
if state == stateNormal && lastChar == SLASH_BYTE {
287+
state = stateSlashStarComment
288+
}
289+
case SLASH_BYTE:
290+
if state == stateSlashStarComment && lastChar == STAR_BYTE {
291+
state = stateNormal
292+
} else if state == stateNormal && lastChar == SLASH_BYTE {
293+
state = stateEOLComment
294+
}
295+
case HASH_BYTE:
296+
if state == stateNormal {
297+
state = stateEOLComment
298+
}
299+
case MINUS_BYTE:
300+
if state == stateNormal && lastChar == MINUS_BYTE {
301+
state = stateEOLComment
302+
}
303+
case LINE_FEED_BYTE:
304+
if state == stateEOLComment {
305+
state = stateNormal
306+
}
307+
case DBL_QUOTE_BYTE:
308+
if state == stateNormal {
309+
state = stateString
310+
singleQuotes = false
311+
} else if state == stateString && !singleQuotes {
312+
state = stateNormal
313+
} else if state == stateEscape {
314+
state = stateString
315+
}
316+
case QUOTE_BYTE:
317+
if state == stateNormal {
318+
state = stateString
319+
singleQuotes = true
320+
} else if state == stateString && singleQuotes {
321+
state = stateNormal
322+
} else if state == stateEscape {
323+
state = stateString
324+
}
325+
case BACKSLASH_BYTE:
326+
if state == stateString {
327+
state = stateEscape
328+
}
329+
case QUESTION_MARK_BYTE:
330+
if state == stateNormal {
331+
paramPositions = append(paramPositions, i)
332+
}
333+
case RADICAL_BYTE:
334+
if state == stateBacktick {
335+
state = stateNormal
336+
} else if state == stateNormal {
337+
state = stateBacktick
338+
}
339+
}
340+
lastChar = currentChar
341+
}
342+
return paramPositions
343+
}
344+
248345
func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
249-
// Number of ? should be same to len(args)
250-
if strings.Count(query, "?") != len(args) {
346+
paramPositions:=findParamPositions(query)
347+
if len(paramPositions) != len(args) {
251348
return "", driver.ErrSkip
252349
}
253350

@@ -261,21 +358,16 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
261358
}
262359
buf = buf[:0]
263360
argPos := 0
361+
lastIdx := 0
264362

265-
for i := 0; i < len(query); i++ {
266-
q := strings.IndexByte(query[i:], '?')
267-
if q == -1 {
268-
buf = append(buf, query[i:]...)
269-
break
270-
}
271-
buf = append(buf, query[i:i+q]...)
272-
i += q
273-
363+
for _, qmIdx := range paramPositions {
364+
buf = append(buf, query[lastIdx:qmIdx]...)
274365
arg := args[argPos]
275366
argPos++
276367

277368
if arg == nil {
278369
buf = append(buf, "NULL"...)
370+
lastIdx = qmIdx + 1
279371
continue
280372
}
281373

@@ -339,7 +431,9 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
339431
if len(buf)+4 > mc.maxAllowedPacket {
340432
return "", driver.ErrSkip
341433
}
434+
lastIdx = qmIdx + 1
342435
}
436+
buf = append(buf, query[lastIdx:]...)
343437
if argPos != len(args) {
344438
return "", driver.ErrSkip
345439
}

‎connection_test.go

Lines changed: 48 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -79,24 +79,6 @@ func TestInterpolateParamsTooManyPlaceholders(t *testing.T) {
7979
}
8080
}
8181

82-
// We don't support placeholder in string literal for now.
83-
// https://github.com/go-sql-driver/mysql/pull/490
84-
func TestInterpolateParamsPlaceholderInString(t *testing.T) {
85-
mc := &mysqlConn{
86-
buf: newBuffer(),
87-
maxAllowedPacket: maxPacketSize,
88-
cfg: &Config{
89-
InterpolateParams: true,
90-
},
91-
}
92-
93-
q, err := mc.interpolateParams("SELECT 'abc?xyz',?", []driver.Value{int64(42)})
94-
// When InterpolateParams support string literal, this should return `"SELECT 'abc?xyz', 42`
95-
if err != driver.ErrSkip {
96-
t.Errorf("Expected err=driver.ErrSkip, got err=%#v, q=%#v", err, q)
97-
}
98-
}
99-
10082
func TestInterpolateParamsUint64(t *testing.T) {
10183
mc := &mysqlConn{
10284
buf: newBuffer(),
@@ -204,3 +186,51 @@ func (bc badConnection) Write(b []byte) (n int, err error) {
204186
func (bc badConnection) Close() error {
205187
return nil
206188
}
189+
190+
func TestInterpolateParamsWithComments(t *testing.T) {
191+
mc := &mysqlConn{
192+
buf: newBuffer(),
193+
maxAllowedPacket: maxPacketSize,
194+
cfg: &Config{
195+
InterpolateParams: true,
196+
},
197+
}
198+
199+
tests := []struct {
200+
query string
201+
args []driver.Value
202+
expected string
203+
shouldSkip bool
204+
}{
205+
// ? in single-line comment (--) should not be replaced
206+
{"SELECT 1 -- ?\n, ?", []driver.Value{int64(42)}, "SELECT 1 -- ?\n, 42", false},
207+
// ? in single-line comment (#) should not be replaced
208+
{"SELECT 1 # ?\n, ?", []driver.Value{int64(42)}, "SELECT 1 # ?\n, 42", false},
209+
// ? in multi-line comment should not be replaced
210+
{"SELECT /* ? */ ?", []driver.Value{int64(42)}, "SELECT /* ? */ 42", false},
211+
// ? in string literal should not be replaced
212+
{"SELECT '?', ?", []driver.Value{int64(42)}, "SELECT '?', 42", false},
213+
// ? in backtick identifier should not be replaced
214+
{"SELECT `?`, ?", []driver.Value{int64(42)}, "SELECT `?`, 42", false},
215+
// Multiple comments and real placeholders
216+
{"SELECT ? -- comment ?\n, ? /* ? */ , ? # ?\n, ?", []driver.Value{int64(1), int64(2), int64(3)}, "SELECT 1 -- comment ?\n, 2 /* ? */ , 3 # ?\n, ?", true},
217+
}
218+
219+
for i, test := range tests {
220+
221+
q, err := mc.interpolateParams(test.query, test.args)
222+
if test.shouldSkip {
223+
if err != driver.ErrSkip {
224+
t.Errorf("Test %d: Expected driver.ErrSkip, got err=%#v, q=%#v", i, err, q)
225+
}
226+
continue
227+
}
228+
if err != nil {
229+
t.Errorf("Test %d: Expected err=nil, got %#v", i, err)
230+
continue
231+
}
232+
if q != test.expected {
233+
t.Errorf("Test %d: Expected: %q\nGot: %q", i, test.expected, q)
234+
}
235+
}
236+
}

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /