1
+ use std:: ops:: Range ;
2
+
1
3
use pg_schema_cache:: SchemaCache ;
4
+ use pg_treesitter_queries:: { queries, TreeSitterQueriesExecutor } ;
2
5
3
6
use crate :: CompletionParams ;
4
7
@@ -52,10 +55,13 @@ pub(crate) struct CompletionContext<'a> {
52
55
pub schema_name : Option < String > ,
53
56
pub wrapping_clause_type : Option < ClauseType > ,
54
57
pub is_invocation : bool ,
58
+ pub wrapping_statement_range : Option < Range < usize > > ,
59
+
60
+ pub ts_query_executor : Option < TreeSitterQueriesExecutor < ' a > > ,
55
61
}
56
62
57
63
impl < ' a > CompletionContext < ' a > {
58
- pub fn new ( params : & ' a CompletionParams ) -> Self {
64
+ pub async fn new ( params : & ' a CompletionParams < ' a > ) -> Self {
59
65
let mut ctx = Self {
60
66
tree : params. tree ,
61
67
text : & params. text ,
@@ -65,14 +71,30 @@ impl<'a> CompletionContext<'a> {
65
71
ts_node : None ,
66
72
schema_name : None ,
67
73
wrapping_clause_type : None ,
74
+ wrapping_statement_range : None ,
68
75
is_invocation : false ,
76
+ ts_query_executor : None ,
69
77
} ;
70
78
71
79
ctx. gather_tree_context ( ) ;
80
+ ctx. dispatch_ts_queries ( ) . await ;
72
81
73
82
ctx
74
83
}
75
84
85
+ async fn dispatch_ts_queries ( & mut self ) {
86
+ let tree = match self . tree . as_ref ( ) {
87
+ None => return ,
88
+ Some ( t) => t,
89
+ } ;
90
+
91
+ let mut executor = TreeSitterQueriesExecutor :: new ( tree. root_node ( ) , self . text ) ;
92
+
93
+ executor. add_query_results :: < queries:: RelationMatch > ( ) . await ;
94
+
95
+ self . ts_query_executor = Some ( executor) ;
96
+ }
97
+
76
98
pub fn get_ts_node_content ( & self , ts_node : tree_sitter:: Node < ' a > ) -> Option < & ' a str > {
77
99
let source = self . text ;
78
100
match ts_node. utf8_text ( source. as_bytes ( ) ) {
@@ -100,36 +122,38 @@ impl<'a> CompletionContext<'a> {
100
122
* We'll therefore adjust the cursor position such that it meets the last node of the AST.
101
123
* `select * from use {}` becomes `select * from use{}`.
102
124
*/
103
- let current_node_kind = cursor. node ( ) . kind ( ) ;
125
+ let current_node = cursor. node ( ) ;
104
126
while cursor. goto_first_child_for_byte ( self . position ) . is_none ( ) && self . position > 0 {
105
127
self . position -= 1 ;
106
128
}
107
129
108
- self . gather_context_from_node ( cursor, current_node_kind ) ;
130
+ self . gather_context_from_node ( cursor, current_node ) ;
109
131
}
110
132
111
133
fn gather_context_from_node (
112
134
& mut self ,
113
135
mut cursor : tree_sitter:: TreeCursor < ' a > ,
114
- previous_node_kind : & str ,
136
+ previous_node : tree_sitter :: Node < ' a > ,
115
137
) {
116
138
let current_node = cursor. node ( ) ;
117
- let current_node_kind = current_node. kind ( ) ;
118
139
119
140
// prevent infinite recursion – this can happen if we only have a PROGRAM node
120
- if current_node_kind == previous_node_kind {
141
+ if current_node . kind ( ) == previous_node . kind ( ) {
121
142
self . ts_node = Some ( current_node) ;
122
143
return ;
123
144
}
124
145
125
- match previous_node_kind {
126
- "statement" => self . wrapping_clause_type = current_node_kind. try_into ( ) . ok ( ) ,
146
+ match previous_node. kind ( ) {
147
+ "statement" => {
148
+ self . wrapping_clause_type = current_node. kind ( ) . try_into ( ) . ok ( ) ;
149
+ self . wrapping_statement_range = Some ( previous_node. byte_range ( ) ) ;
150
+ }
127
151
"invocation" => self . is_invocation = true ,
128
152
129
153
_ => { }
130
154
}
131
155
132
- match current_node_kind {
156
+ match current_node . kind ( ) {
133
157
"object_reference" => {
134
158
let txt = self . get_ts_node_content ( current_node) ;
135
159
if let Some ( txt) = txt {
@@ -159,7 +183,7 @@ impl<'a> CompletionContext<'a> {
159
183
}
160
184
161
185
cursor. goto_first_child_for_byte ( self . position ) ;
162
- self . gather_context_from_node ( cursor, current_node_kind ) ;
186
+ self . gather_context_from_node ( cursor, current_node ) ;
163
187
}
164
188
}
165
189
@@ -179,8 +203,8 @@ mod tests {
179
203
parser. parse ( input, None ) . expect ( "Unable to parse tree" )
180
204
}
181
205
182
- #[ test]
183
- fn identifies_clauses ( ) {
206
+ #[ tokio :: test]
207
+ async fn identifies_clauses ( ) {
184
208
let test_cases = vec ! [
185
209
( format!( "Select {}* from users;" , CURSOR_POS ) , "select" ) ,
186
210
( format!( "Select * from u{};" , CURSOR_POS ) , "from" ) ,
@@ -220,14 +244,14 @@ mod tests {
220
244
schema : & pg_schema_cache:: SchemaCache :: new ( ) ,
221
245
} ;
222
246
223
- let ctx = CompletionContext :: new ( & params) ;
247
+ let ctx = CompletionContext :: new ( & params) . await ;
224
248
225
249
assert_eq ! ( ctx. wrapping_clause_type, expected_clause. try_into( ) . ok( ) ) ;
226
250
}
227
251
}
228
252
229
- #[ test]
230
- fn identifies_schema ( ) {
253
+ #[ tokio :: test]
254
+ async fn identifies_schema ( ) {
231
255
let test_cases = vec ! [
232
256
(
233
257
format!( "Select * from private.u{}" , CURSOR_POS ) ,
@@ -252,14 +276,14 @@ mod tests {
252
276
schema : & pg_schema_cache:: SchemaCache :: new ( ) ,
253
277
} ;
254
278
255
- let ctx = CompletionContext :: new ( & params) ;
279
+ let ctx = CompletionContext :: new ( & params) . await ;
256
280
257
281
assert_eq ! ( ctx. schema_name, expected_schema. map( |f| f. to_string( ) ) ) ;
258
282
}
259
283
}
260
284
261
- #[ test]
262
- fn identifies_invocation ( ) {
285
+ #[ tokio :: test]
286
+ async fn identifies_invocation ( ) {
263
287
let test_cases = vec ! [
264
288
( format!( "Select * from u{}sers" , CURSOR_POS ) , false ) ,
265
289
( format!( "Select * from u{}sers()" , CURSOR_POS ) , true ) ,
@@ -286,14 +310,14 @@ mod tests {
286
310
schema : & pg_schema_cache:: SchemaCache :: new ( ) ,
287
311
} ;
288
312
289
- let ctx = CompletionContext :: new ( & params) ;
313
+ let ctx = CompletionContext :: new ( & params) . await ;
290
314
291
315
assert_eq ! ( ctx. is_invocation, is_invocation) ;
292
316
}
293
317
}
294
318
295
- #[ test]
296
- fn does_not_fail_on_leading_whitespace ( ) {
319
+ #[ tokio :: test]
320
+ async fn does_not_fail_on_leading_whitespace ( ) {
297
321
let cases = vec ! [
298
322
format!( "{} select * from" , CURSOR_POS ) ,
299
323
format!( " {} select * from" , CURSOR_POS ) ,
@@ -311,7 +335,7 @@ mod tests {
311
335
schema : & pg_schema_cache:: SchemaCache :: new ( ) ,
312
336
} ;
313
337
314
- let ctx = CompletionContext :: new ( & params) ;
338
+ let ctx = CompletionContext :: new ( & params) . await ;
315
339
316
340
let node = ctx. ts_node . map ( |n| n. clone ( ) ) . unwrap ( ) ;
317
341
@@ -324,8 +348,8 @@ mod tests {
324
348
}
325
349
}
326
350
327
- #[ test]
328
- fn does_not_fail_on_trailing_whitespace ( ) {
351
+ #[ tokio :: test]
352
+ async fn does_not_fail_on_trailing_whitespace ( ) {
329
353
let query = format ! ( "select * from {}" , CURSOR_POS ) ;
330
354
331
355
let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
@@ -339,7 +363,7 @@ mod tests {
339
363
schema : & pg_schema_cache:: SchemaCache :: new ( ) ,
340
364
} ;
341
365
342
- let ctx = CompletionContext :: new ( & params) ;
366
+ let ctx = CompletionContext :: new ( & params) . await ;
343
367
344
368
let node = ctx. ts_node . map ( |n| n. clone ( ) ) . unwrap ( ) ;
345
369
@@ -350,8 +374,8 @@ mod tests {
350
374
) ;
351
375
}
352
376
353
- #[ test]
354
- fn does_not_fail_with_empty_statements ( ) {
377
+ #[ tokio :: test]
378
+ async fn does_not_fail_with_empty_statements ( ) {
355
379
let query = format ! ( "{}" , CURSOR_POS ) ;
356
380
357
381
let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
@@ -365,16 +389,16 @@ mod tests {
365
389
schema : & pg_schema_cache:: SchemaCache :: new ( ) ,
366
390
} ;
367
391
368
- let ctx = CompletionContext :: new ( & params) ;
392
+ let ctx = CompletionContext :: new ( & params) . await ;
369
393
370
394
let node = ctx. ts_node . map ( |n| n. clone ( ) ) . unwrap ( ) ;
371
395
372
396
assert_eq ! ( ctx. get_ts_node_content( node) , Some ( "" ) ) ;
373
397
assert_eq ! ( ctx. wrapping_clause_type, None ) ;
374
398
}
375
399
376
- #[ test]
377
- fn does_not_fail_on_incomplete_keywords ( ) {
400
+ #[ tokio :: test]
401
+ async fn does_not_fail_on_incomplete_keywords ( ) {
378
402
// Instead of autocompleting "FROM", we'll assume that the user
379
403
// is selecting a certain column name, such as `frozen_account`.
380
404
let query = format ! ( "select * fro{}" , CURSOR_POS ) ;
@@ -390,7 +414,7 @@ mod tests {
390
414
schema : & pg_schema_cache:: SchemaCache :: new ( ) ,
391
415
} ;
392
416
393
- let ctx = CompletionContext :: new ( & params) ;
417
+ let ctx = CompletionContext :: new ( & params) . await ;
394
418
395
419
let node = ctx. ts_node . map ( |n| n. clone ( ) ) . unwrap ( ) ;
396
420
0 commit comments