|
| 1 | +use std::collections::HashMap; |
1 | 2 | use std::num::NonZeroUsize;
|
2 | 3 | use std::sync::{Arc, Mutex};
|
3 | 4 |
|
4 | 5 | use lru::LruCache;
|
5 | 6 | use pgt_query_ext::diagnostics::*;
|
6 | 7 | use pgt_text_size::TextRange;
|
| 8 | +use pgt_tokenizer::tokenize; |
7 | 9 |
|
8 | 10 | use super::statement_identifier::StatementId;
|
9 | 11 |
|
@@ -37,7 +39,7 @@ impl PgQueryStore {
|
37 | 39 | }
|
38 | 40 |
|
39 | 41 | let r = Arc::new(
|
40 | | - pgt_query::parse(statement.content()) |
| 42 | + pgt_query::parse(&convert_to_positional_params(statement.content())) |
41 | 43 | .map_err(SyntaxDiagnostic::from)
|
42 | 44 | .and_then(|ast| {
|
43 | 45 | ast.into_root().ok_or_else(|| {
|
@@ -87,10 +89,79 @@ impl PgQueryStore {
|
87 | 89 | }
|
88 | 90 | }
|
89 | 91 |
|
| 92 | +/// Converts named parameters in a SQL query string to positional parameters. |
| 93 | +/// |
| 94 | +/// This function scans the input SQL string for named parameters (e.g., `@param`, `:param`, `:'param'`) |
| 95 | +/// and replaces them with positional parameters (e.g., `1ドル`, `2ドル`, etc.). |
| 96 | +/// |
| 97 | +/// It maintains the original spacing of the named parameters in the output string. |
| 98 | +/// |
| 99 | +/// Useful for preparing SQL queries for parsing or execution where named paramters are not supported. |
| 100 | +pub fn convert_to_positional_params(text: &str) -> String { |
| 101 | + let mut result = String::with_capacity(text.len()); |
| 102 | + let mut param_mapping: HashMap<&str, usize> = HashMap::new(); |
| 103 | + let mut param_index = 1; |
| 104 | + let mut position = 0; |
| 105 | + |
| 106 | + for token in tokenize(text) { |
| 107 | + let token_len = token.len as usize; |
| 108 | + let token_text = &text[position..position + token_len]; |
| 109 | + |
| 110 | + if matches!(token.kind, pgt_tokenizer::TokenKind::NamedParam { .. }) { |
| 111 | + let idx = match param_mapping.get(token_text) { |
| 112 | + Some(&index) => index, |
| 113 | + None => { |
| 114 | + let index = param_index; |
| 115 | + param_mapping.insert(token_text, index); |
| 116 | + param_index += 1; |
| 117 | + index |
| 118 | + } |
| 119 | + }; |
| 120 | + |
| 121 | + let replacement = format!("${}", idx); |
| 122 | + let original_len = token_text.len(); |
| 123 | + let replacement_len = replacement.len(); |
| 124 | + |
| 125 | + result.push_str(&replacement); |
| 126 | + |
| 127 | + // maintain original spacing |
| 128 | + if replacement_len < original_len { |
| 129 | + result.push_str(&" ".repeat(original_len - replacement_len)); |
| 130 | + } |
| 131 | + } else { |
| 132 | + result.push_str(token_text); |
| 133 | + } |
| 134 | + |
| 135 | + position += token_len; |
| 136 | + } |
| 137 | + |
| 138 | + result |
| 139 | +} |
| 140 | + |
90 | 141 | #[cfg(test)]
|
91 | 142 | mod tests {
|
92 | 143 | use super::*;
|
93 | 144 |
|
| 145 | + #[test] |
| 146 | + fn test_convert_to_positional_params() { |
| 147 | + let input = "select * from users where id = @one and name = :two and email = :'three';"; |
| 148 | + let result = convert_to_positional_params(input); |
| 149 | + assert_eq!( |
| 150 | + result, |
| 151 | + "select * from users where id = 1ドル and name = 2ドル and email = 3ドル ;" |
| 152 | + ); |
| 153 | + } |
| 154 | + |
| 155 | + #[test] |
| 156 | + fn test_convert_to_positional_params_with_duplicates() { |
| 157 | + let input = "select * from users where first_name = @one and starts_with(email, @one) and created_at > @two;"; |
| 158 | + let result = convert_to_positional_params(input); |
| 159 | + assert_eq!( |
| 160 | + result, |
| 161 | + "select * from users where first_name = 1ドル and starts_with(email, 1ドル ) and created_at > 2ドル ;" |
| 162 | + ); |
| 163 | + } |
| 164 | + |
94 | 165 | #[test]
|
95 | 166 | fn test_plpgsql_syntax_error() {
|
96 | 167 | let input = "
|
|
0 commit comments