1
+ mod policy_parser;
2
+
1
3
use std:: collections:: { HashMap , HashSet } ;
2
4
3
5
use pgt_schema_cache:: SchemaCache ;
6
+ use pgt_text_size:: TextRange ;
4
7
use pgt_treesitter_queries:: {
5
8
TreeSitterQueriesExecutor ,
6
9
queries:: { self , QueryResult } ,
7
10
} ;
8
11
9
- use crate :: sanitization:: SanitizedCompletionParams ;
12
+ use crate :: {
13
+ NodeText ,
14
+ context:: policy_parser:: { PolicyParser , PolicyStmtKind } ,
15
+ sanitization:: SanitizedCompletionParams ,
16
+ } ;
10
17
11
18
#[ derive( Debug , PartialEq , Eq , Hash ) ]
12
19
pub enum WrappingClause < ' a > {
@@ -18,12 +25,8 @@ pub enum WrappingClause<'a> {
18
25
} ,
19
26
Update ,
20
27
Delete ,
21
- }
22
-
23
- #[ derive( PartialEq , Eq , Debug ) ]
24
- pub ( crate ) enum NodeText < ' a > {
25
- Replaced ,
26
- Original ( & ' a str ) ,
28
+ PolicyName ,
29
+ ToRoleAssignment ,
27
30
}
28
31
29
32
#[ derive( PartialEq , Eq , Hash , Debug ) ]
@@ -47,6 +50,45 @@ pub enum WrappingNode {
47
50
Assignment ,
48
51
}
49
52
53
+ #[ derive( Debug ) ]
54
+ pub ( crate ) enum NodeUnderCursor < ' a > {
55
+ TsNode ( tree_sitter:: Node < ' a > ) ,
56
+ CustomNode {
57
+ text : NodeText ,
58
+ range : TextRange ,
59
+ kind : String ,
60
+ } ,
61
+ }
62
+
63
+ impl NodeUnderCursor < ' _ > {
64
+ pub fn start_byte ( & self ) -> usize {
65
+ match self {
66
+ NodeUnderCursor :: TsNode ( node) => node. start_byte ( ) ,
67
+ NodeUnderCursor :: CustomNode { range, .. } => range. start ( ) . into ( ) ,
68
+ }
69
+ }
70
+
71
+ pub fn end_byte ( & self ) -> usize {
72
+ match self {
73
+ NodeUnderCursor :: TsNode ( node) => node. end_byte ( ) ,
74
+ NodeUnderCursor :: CustomNode { range, .. } => range. end ( ) . into ( ) ,
75
+ }
76
+ }
77
+
78
+ pub fn kind ( & self ) -> & str {
79
+ match self {
80
+ NodeUnderCursor :: TsNode ( node) => node. kind ( ) ,
81
+ NodeUnderCursor :: CustomNode { kind, .. } => kind. as_str ( ) ,
82
+ }
83
+ }
84
+ }
85
+
86
+ impl < ' a > From < tree_sitter:: Node < ' a > > for NodeUnderCursor < ' a > {
87
+ fn from ( node : tree_sitter:: Node < ' a > ) -> Self {
88
+ NodeUnderCursor :: TsNode ( node)
89
+ }
90
+ }
91
+
50
92
impl TryFrom < & str > for WrappingNode {
51
93
type Error = String ;
52
94
@@ -77,7 +119,7 @@ impl TryFrom<String> for WrappingNode {
77
119
}
78
120
79
121
pub ( crate ) struct CompletionContext < ' a > {
80
- pub node_under_cursor : Option < tree_sitter :: Node < ' a > > ,
122
+ pub node_under_cursor : Option < NodeUnderCursor < ' a > > ,
81
123
82
124
pub tree : & ' a tree_sitter:: Tree ,
83
125
pub text : & ' a str ,
@@ -137,12 +179,49 @@ impl<'a> CompletionContext<'a> {
137
179
is_in_error_node : false ,
138
180
} ;
139
181
140
- ctx. gather_tree_context ( ) ;
141
- ctx. gather_info_from_ts_queries ( ) ;
182
+ // policy handling is important to Supabase, but they are a PostgreSQL specific extension,
183
+ // so the tree_sitter_sql language does not support it.
184
+ // We infer the context manually.
185
+ if PolicyParser :: looks_like_policy_stmt ( & params. text ) {
186
+ ctx. gather_policy_context ( ) ;
187
+ } else {
188
+ ctx. gather_tree_context ( ) ;
189
+ ctx. gather_info_from_ts_queries ( ) ;
190
+ }
142
191
143
192
ctx
144
193
}
145
194
195
+ fn gather_policy_context ( & mut self ) {
196
+ let policy_context = PolicyParser :: get_context ( self . text , self . position ) ;
197
+
198
+ self . node_under_cursor = Some ( NodeUnderCursor :: CustomNode {
199
+ text : policy_context. node_text . into ( ) ,
200
+ range : policy_context. node_range ,
201
+ kind : policy_context. node_kind . clone ( ) ,
202
+ } ) ;
203
+
204
+ if policy_context. node_kind == "policy_table" {
205
+ self . schema_or_alias_name = policy_context. schema_name . clone ( ) ;
206
+ }
207
+
208
+ if policy_context. table_name . is_some ( ) {
209
+ let mut new = HashSet :: new ( ) ;
210
+ new. insert ( policy_context. table_name . unwrap ( ) ) ;
211
+ self . mentioned_relations
212
+ . insert ( policy_context. schema_name , new) ;
213
+ }
214
+
215
+ self . wrapping_clause_type = match policy_context. node_kind . as_str ( ) {
216
+ "policy_name" if policy_context. statement_kind != PolicyStmtKind :: Create => {
217
+ Some ( WrappingClause :: PolicyName )
218
+ }
219
+ "policy_role" => Some ( WrappingClause :: ToRoleAssignment ) ,
220
+ "policy_table" => Some ( WrappingClause :: From ) ,
221
+ _ => None ,
222
+ } ;
223
+ }
224
+
146
225
fn gather_info_from_ts_queries ( & mut self ) {
147
226
let stmt_range = self . wrapping_statement_range . as_ref ( ) ;
148
227
let sql = self . text ;
@@ -196,24 +275,30 @@ impl<'a> CompletionContext<'a> {
196
275
}
197
276
}
198
277
199
- pub fn get_ts_node_content ( & self , ts_node : tree_sitter:: Node < ' a > ) -> Option < NodeText < ' a > > {
278
+ fn get_ts_node_content ( & self , ts_node : & tree_sitter:: Node < ' a > ) -> Option < NodeText > {
200
279
let source = self . text ;
201
280
ts_node. utf8_text ( source. as_bytes ( ) ) . ok ( ) . map ( |txt| {
202
281
if SanitizedCompletionParams :: is_sanitized_token ( txt) {
203
282
NodeText :: Replaced
204
283
} else {
205
- NodeText :: Original ( txt)
284
+ NodeText :: Original ( txt. into ( ) )
206
285
}
207
286
} )
208
287
}
209
288
210
289
pub fn get_node_under_cursor_content ( & self ) -> Option < String > {
211
- self . node_under_cursor
212
- . and_then ( |n| self . get_ts_node_content ( n) )
213
- . and_then ( |txt| match txt {
290
+ match self . node_under_cursor . as_ref ( ) ? {
291
+ NodeUnderCursor :: TsNode ( node) => {
292
+ self . get_ts_node_content ( node) . and_then ( |nt| match nt {
293
+ NodeText :: Replaced => None ,
294
+ NodeText :: Original ( c) => Some ( c. to_string ( ) ) ,
295
+ } )
296
+ }
297
+ NodeUnderCursor :: CustomNode { text, .. } => match text {
214
298
NodeText :: Replaced => None ,
215
299
NodeText :: Original ( c) => Some ( c. to_string ( ) ) ,
216
- } )
300
+ } ,
301
+ }
217
302
}
218
303
219
304
fn gather_tree_context ( & mut self ) {
@@ -251,7 +336,7 @@ impl<'a> CompletionContext<'a> {
251
336
252
337
// prevent infinite recursion – this can happen if we only have a PROGRAM node
253
338
if current_node_kind == parent_node_kind {
254
- self . node_under_cursor = Some ( current_node) ;
339
+ self . node_under_cursor = Some ( NodeUnderCursor :: from ( current_node) ) ;
255
340
return ;
256
341
}
257
342
@@ -290,7 +375,7 @@ impl<'a> CompletionContext<'a> {
290
375
291
376
match current_node_kind {
292
377
"object_reference" | "field" => {
293
- let content = self . get_ts_node_content ( current_node) ;
378
+ let content = self . get_ts_node_content ( & current_node) ;
294
379
if let Some ( node_txt) = content {
295
380
match node_txt {
296
381
NodeText :: Original ( txt) => {
@@ -322,7 +407,7 @@ impl<'a> CompletionContext<'a> {
322
407
323
408
// We have arrived at the leaf node
324
409
if current_node. child_count ( ) == 0 {
325
- self . node_under_cursor = Some ( current_node) ;
410
+ self . node_under_cursor = Some ( NodeUnderCursor :: from ( current_node) ) ;
326
411
return ;
327
412
}
328
413
@@ -335,11 +420,11 @@ impl<'a> CompletionContext<'a> {
335
420
node : tree_sitter:: Node < ' a > ,
336
421
) -> Option < WrappingClause < ' a > > {
337
422
if node. kind ( ) . starts_with ( "keyword_" ) {
338
- if let Some ( txt) = self . get_ts_node_content ( node) . and_then ( |txt| match txt {
423
+ if let Some ( txt) = self . get_ts_node_content ( & node) . and_then ( |txt| match txt {
339
424
NodeText :: Original ( txt) => Some ( txt) ,
340
425
NodeText :: Replaced => None ,
341
426
} ) {
342
- match txt {
427
+ match txt. as_str ( ) {
343
428
"where" => return Some ( WrappingClause :: Where ) ,
344
429
"update" => return Some ( WrappingClause :: Update ) ,
345
430
"select" => return Some ( WrappingClause :: Select ) ,
@@ -389,11 +474,14 @@ impl<'a> CompletionContext<'a> {
389
474
#[ cfg( test) ]
390
475
mod tests {
391
476
use crate :: {
392
- context:: { CompletionContext , NodeText , WrappingClause } ,
477
+ NodeText ,
478
+ context:: { CompletionContext , WrappingClause } ,
393
479
sanitization:: SanitizedCompletionParams ,
394
480
test_helper:: { CURSOR_POS , get_text_and_position} ,
395
481
} ;
396
482
483
+ use super :: NodeUnderCursor ;
484
+
397
485
fn get_tree ( input : & str ) -> tree_sitter:: Tree {
398
486
let mut parser = tree_sitter:: Parser :: new ( ) ;
399
487
parser
@@ -552,17 +640,22 @@ mod tests {
552
640
553
641
let ctx = CompletionContext :: new ( & params) ;
554
642
555
- let node = ctx. node_under_cursor . unwrap ( ) ;
643
+ let node = ctx. node_under_cursor . as_ref ( ) . unwrap ( ) ;
556
644
557
- assert_eq ! (
558
- ctx. get_ts_node_content( node) ,
559
- Some ( NodeText :: Original ( "select" ) )
560
- ) ;
645
+ match node {
646
+ NodeUnderCursor :: TsNode ( node) => {
647
+ assert_eq ! (
648
+ ctx. get_ts_node_content( node) ,
649
+ Some ( NodeText :: Original ( "select" . into( ) ) )
650
+ ) ;
561
651
562
- assert_eq ! (
563
- ctx. wrapping_clause_type,
564
- Some ( crate :: context:: WrappingClause :: Select )
565
- ) ;
652
+ assert_eq ! (
653
+ ctx. wrapping_clause_type,
654
+ Some ( crate :: context:: WrappingClause :: Select )
655
+ ) ;
656
+ }
657
+ _ => unreachable ! ( ) ,
658
+ }
566
659
}
567
660
}
568
661
@@ -583,12 +676,17 @@ mod tests {
583
676
584
677
let ctx = CompletionContext :: new ( & params) ;
585
678
586
- let node = ctx. node_under_cursor . unwrap ( ) ;
679
+ let node = ctx. node_under_cursor . as_ref ( ) . unwrap ( ) ;
587
680
588
- assert_eq ! (
589
- ctx. get_ts_node_content( node) ,
590
- Some ( NodeText :: Original ( "from" ) )
591
- ) ;
681
+ match node {
682
+ NodeUnderCursor :: TsNode ( node) => {
683
+ assert_eq ! (
684
+ ctx. get_ts_node_content( node) ,
685
+ Some ( NodeText :: Original ( "from" . into( ) ) )
686
+ ) ;
687
+ }
688
+ _ => unreachable ! ( ) ,
689
+ }
592
690
}
593
691
594
692
#[ test]
@@ -608,10 +706,18 @@ mod tests {
608
706
609
707
let ctx = CompletionContext :: new ( & params) ;
610
708
611
- let node = ctx. node_under_cursor . unwrap ( ) ;
709
+ let node = ctx. node_under_cursor . as_ref ( ) . unwrap ( ) ;
612
710
613
- assert_eq ! ( ctx. get_ts_node_content( node) , Some ( NodeText :: Original ( "" ) ) ) ;
614
- assert_eq ! ( ctx. wrapping_clause_type, None ) ;
711
+ match node {
712
+ NodeUnderCursor :: TsNode ( node) => {
713
+ assert_eq ! (
714
+ ctx. get_ts_node_content( node) ,
715
+ Some ( NodeText :: Original ( "" . into( ) ) )
716
+ ) ;
717
+ assert_eq ! ( ctx. wrapping_clause_type, None ) ;
718
+ }
719
+ _ => unreachable ! ( ) ,
720
+ }
615
721
}
616
722
617
723
#[ test]
@@ -633,12 +739,17 @@ mod tests {
633
739
634
740
let ctx = CompletionContext :: new ( & params) ;
635
741
636
- let node = ctx. node_under_cursor . unwrap ( ) ;
742
+ let node = ctx. node_under_cursor . as_ref ( ) . unwrap ( ) ;
637
743
638
- assert_eq ! (
639
- ctx. get_ts_node_content( node) ,
640
- Some ( NodeText :: Original ( "fro" ) )
641
- ) ;
642
- assert_eq ! ( ctx. wrapping_clause_type, Some ( WrappingClause :: Select ) ) ;
744
+ match node {
745
+ NodeUnderCursor :: TsNode ( node) => {
746
+ assert_eq ! (
747
+ ctx. get_ts_node_content( node) ,
748
+ Some ( NodeText :: Original ( "fro" . into( ) ) )
749
+ ) ;
750
+ assert_eq ! ( ctx. wrapping_clause_type, Some ( WrappingClause :: Select ) ) ;
751
+ }
752
+ _ => unreachable ! ( ) ,
753
+ }
643
754
}
644
755
}
0 commit comments