Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 224d7fd

Browse files
feat(completions): complete in WITH CHECK and USING clauses (#422)
1 parent 4cb12df commit 224d7fd

File tree

8 files changed

+504
-74
lines changed

8 files changed

+504
-74
lines changed

‎crates/pgt_completions/src/context/base_parser.rs‎

Lines changed: 121 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
use std::iter::Peekable;
2-
31
use pgt_text_size::{TextRange, TextSize};
2+
use std::iter::Peekable;
43

54
pub(crate) struct TokenNavigator {
65
tokens: Peekable<std::vec::IntoIter<WordWithIndex>>,
@@ -101,73 +100,139 @@ impl WordWithIndex {
101100
}
102101
}
103102

104-
/// Note: A policy name within quotation marks will be considered a single word.
105-
pub(crate) fn sql_to_words(sql: &str) -> Result<Vec<WordWithIndex>, String> {
106-
let mut words = vec![];
107-
108-
let mut start_of_word: Option<usize> = None;
109-
let mut current_word = String::new();
110-
let mut in_quotation_marks = false;
111-
112-
for (current_position, current_char) in sql.char_indices() {
113-
if (current_char.is_ascii_whitespace() || current_char == ';')
114-
&& !current_word.is_empty()
115-
&& start_of_word.is_some()
116-
&& !in_quotation_marks
117-
{
118-
words.push(WordWithIndex {
119-
word: current_word,
120-
start: start_of_word.unwrap(),
121-
end: current_position,
122-
});
123-
124-
current_word = String::new();
125-
start_of_word = None;
126-
} else if (current_char.is_ascii_whitespace() || current_char == ';')
127-
&& current_word.is_empty()
128-
{
129-
// do nothing
130-
} else if current_char == '"' && start_of_word.is_none() {
131-
in_quotation_marks = true;
132-
current_word.push(current_char);
133-
start_of_word = Some(current_position);
134-
} else if current_char == '"' && start_of_word.is_some() {
135-
current_word.push(current_char);
136-
in_quotation_marks = false;
137-
} else if start_of_word.is_some() {
138-
current_word.push(current_char)
103+
pub(crate) struct SubStatementParser {
104+
start_of_word: Option<usize>,
105+
current_word: String,
106+
in_quotation_marks: bool,
107+
is_fn_call: bool,
108+
words: Vec<WordWithIndex>,
109+
}
110+
111+
impl SubStatementParser {
112+
pub(crate) fn parse(sql: &str) -> Result<Vec<WordWithIndex>, String> {
113+
let mut parser = SubStatementParser {
114+
start_of_word: None,
115+
current_word: String::new(),
116+
in_quotation_marks: false,
117+
is_fn_call: false,
118+
words: vec![],
119+
};
120+
121+
parser.collect_words(sql);
122+
123+
if parser.in_quotation_marks {
124+
Err("String was not closed properly.".into())
139125
} else {
140-
start_of_word = Some(current_position);
141-
current_word.push(current_char);
126+
Ok(parser.words)
142127
}
143128
}
144129

145-
if let Some(start_of_word) = start_of_word {
146-
if !current_word.is_empty() {
147-
words.push(WordWithIndex {
148-
word: current_word,
149-
start: start_of_word,
150-
end: sql.len(),
151-
});
130+
pub fn collect_words(&mut self, sql: &str) {
131+
for (pos, c) in sql.char_indices() {
132+
match c {
133+
'"' => {
134+
if !self.has_started_word() {
135+
self.in_quotation_marks = true;
136+
self.add_char(c);
137+
self.start_word(pos);
138+
} else {
139+
self.in_quotation_marks = false;
140+
self.add_char(c);
141+
}
142+
}
143+
144+
'(' => {
145+
if !self.has_started_word() {
146+
self.push_char_as_word(c, pos);
147+
} else {
148+
self.add_char(c);
149+
self.is_fn_call = true;
150+
}
151+
}
152+
153+
')' => {
154+
if self.is_fn_call {
155+
self.add_char(c);
156+
self.is_fn_call = false;
157+
} else {
158+
if self.has_started_word() {
159+
self.push_word(pos);
160+
}
161+
self.push_char_as_word(c, pos);
162+
}
163+
}
164+
165+
_ => {
166+
if c.is_ascii_whitespace() || c == ';' {
167+
if self.in_quotation_marks {
168+
self.add_char(c);
169+
} else if !self.is_empty() && self.has_started_word() {
170+
self.push_word(pos);
171+
}
172+
} else if self.has_started_word() {
173+
self.add_char(c);
174+
} else {
175+
self.start_word(pos);
176+
self.add_char(c)
177+
}
178+
}
179+
}
180+
}
181+
182+
if self.has_started_word() && !self.is_empty() {
183+
self.push_word(sql.len())
152184
}
153185
}
154186

155-
if in_quotation_marks {
156-
Err("String was not closed properly.".into())
157-
} else {
158-
Ok(words)
187+
fn is_empty(&self) -> bool {
188+
self.current_word.is_empty()
189+
}
190+
191+
fn add_char(&mut self, c: char) {
192+
self.current_word.push(c)
193+
}
194+
195+
fn start_word(&mut self, pos: usize) {
196+
self.start_of_word = Some(pos);
197+
}
198+
199+
fn has_started_word(&self) -> bool {
200+
self.start_of_word.is_some()
201+
}
202+
203+
fn push_char_as_word(&mut self, c: char, pos: usize) {
204+
self.words.push(WordWithIndex {
205+
word: String::from(c),
206+
start: pos,
207+
end: pos + 1,
208+
});
209+
}
210+
211+
fn push_word(&mut self, current_position: usize) {
212+
self.words.push(WordWithIndex {
213+
word: self.current_word.clone(),
214+
start: self.start_of_word.unwrap(),
215+
end: current_position,
216+
});
217+
self.current_word = String::new();
218+
self.start_of_word = None;
159219
}
160220
}
161221

222+
/// Note: A policy name within quotation marks will be considered a single word.
223+
pub(crate) fn sql_to_words(sql: &str) -> Result<Vec<WordWithIndex>, String> {
224+
SubStatementParser::parse(sql)
225+
}
226+
162227
#[cfg(test)]
163228
mod tests {
164-
use crate::context::base_parser::{WordWithIndex, sql_to_words};
229+
use crate::context::base_parser::{SubStatementParser,WordWithIndex, sql_to_words};
165230

166231
#[test]
167232
fn determines_positions_correctly() {
168-
let query = "\ncreate policy \"my cool pol\"\n\ton auth.users\n\tas permissive\n\tfor select\n\t\tto public\n\t\tusing (true);".to_string();
233+
let query = "\ncreate policy \"my cool pol\"\n\ton auth.users\n\tas permissive\n\tfor select\n\t\tto public\n\t\tusing (auth.uid());".to_string();
169234

170-
let words = sql_to_words(query.as_str()).unwrap();
235+
let words = SubStatementParser::parse(query.as_str()).unwrap();
171236

172237
assert_eq!(words[0], to_word("create", 1, 7));
173238
assert_eq!(words[1], to_word("policy", 8, 14));
@@ -181,7 +246,9 @@ mod tests {
181246
assert_eq!(words[9], to_word("to", 73, 75));
182247
assert_eq!(words[10], to_word("public", 78, 84));
183248
assert_eq!(words[11], to_word("using", 87, 92));
184-
assert_eq!(words[12], to_word("(true)", 93, 99));
249+
assert_eq!(words[12], to_word("(", 93, 94));
250+
assert_eq!(words[13], to_word("auth.uid()", 94, 104));
251+
assert_eq!(words[14], to_word(")", 104, 105));
185252
}
186253

187254
#[test]

‎crates/pgt_completions/src/context/mod.rs‎

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,15 @@ pub enum WrappingClause<'a> {
4747
SetStatement,
4848
AlterRole,
4949
DropRole,
50+
51+
/// `PolicyCheck` refers to either the `WITH CHECK` or the `USING` clause
52+
/// in a policy statement.
53+
/// ```sql
54+
/// CREATE POLICY "my pol" ON PUBLIC.USERS
55+
/// FOR SELECT
56+
/// USING (...) -- this one!
57+
/// ```
58+
PolicyCheck,
5059
}
5160

5261
#[derive(PartialEq, Eq, Hash, Debug, Clone)]
@@ -78,6 +87,7 @@ pub(crate) enum NodeUnderCursor<'a> {
7887
text: NodeText,
7988
range: TextRange,
8089
kind: String,
90+
previous_node_kind: Option<String>,
8191
},
8292
}
8393

@@ -222,6 +232,7 @@ impl<'a> CompletionContext<'a> {
222232
text: revoke_context.node_text.into(),
223233
range: revoke_context.node_range,
224234
kind: revoke_context.node_kind.clone(),
235+
previous_node_kind: None,
225236
});
226237

227238
if revoke_context.node_kind == "revoke_table" {
@@ -249,6 +260,7 @@ impl<'a> CompletionContext<'a> {
249260
text: grant_context.node_text.into(),
250261
range: grant_context.node_range,
251262
kind: grant_context.node_kind.clone(),
263+
previous_node_kind: None,
252264
});
253265

254266
if grant_context.node_kind == "grant_table" {
@@ -276,6 +288,7 @@ impl<'a> CompletionContext<'a> {
276288
text: policy_context.node_text.into(),
277289
range: policy_context.node_range,
278290
kind: policy_context.node_kind.clone(),
291+
previous_node_kind: Some(policy_context.previous_node_kind),
279292
});
280293

281294
if policy_context.node_kind == "policy_table" {
@@ -295,7 +308,13 @@ impl<'a> CompletionContext<'a> {
295308
}
296309
"policy_role" => Some(WrappingClause::ToRoleAssignment),
297310
"policy_table" => Some(WrappingClause::From),
298-
_ => None,
311+
_ => {
312+
if policy_context.in_check_or_using_clause {
313+
Some(WrappingClause::PolicyCheck)
314+
} else {
315+
None
316+
}
317+
}
299318
};
300319
}
301320

@@ -785,7 +804,11 @@ impl<'a> CompletionContext<'a> {
785804
.is_some_and(|sib| kinds.contains(&sib.kind()))
786805
}
787806

788-
NodeUnderCursor::CustomNode { .. } => false,
807+
NodeUnderCursor::CustomNode {
808+
previous_node_kind, ..
809+
} => previous_node_kind
810+
.as_ref()
811+
.is_some_and(|k| kinds.contains(&k.as_str())),
789812
}
790813
})
791814
}

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /