@@ -18,7 +18,7 @@ impl TryFrom<&str> for ClauseType {
18
18
match value {
19
19
"select" => Ok ( Self :: Select ) ,
20
20
"where" => Ok ( Self :: Where ) ,
21
- "from" => Ok ( Self :: From ) ,
21
+ "from" | "keyword_from" => Ok ( Self :: From ) ,
22
22
"update" => Ok ( Self :: Update ) ,
23
23
"delete" => Ok ( Self :: Delete ) ,
24
24
_ => {
@@ -88,13 +88,22 @@ impl<'a> CompletionContext<'a> {
88
88
89
89
let mut cursor = self . tree . as_ref ( ) . unwrap ( ) . root_node ( ) . walk ( ) ;
90
90
91
- // go to the statement node that matches the position
91
+ /*
92
+ * The head node of any treesitter tree is always the "PROGRAM" node.
93
+ *
94
+ * We want to enter the next layer and focus on the child node that matches the user's cursor position.
95
+ * If there is no node under the users position, however, the cursor won't enter the next level – it
96
+ * will stay on the Program node.
97
+ *
98
+ * This might lead to an unexpected context or infinite recursion.
99
+ *
100
+ * We'll therefore adjust the cursor position such that it meets the last node of the AST.
101
+ * `select * from use {}` becomes `select * from use{}`.
102
+ */
92
103
let current_node_kind = cursor. node ( ) . kind ( ) ;
93
-
94
- dbg ! ( current_node_kind) ;
95
- dbg ! ( self . position) ;
96
-
97
- cursor. goto_first_child_for_byte ( self . position ) ;
104
+ while cursor. goto_first_child_for_byte ( self . position ) . is_none ( ) && self . position > 0 {
105
+ self . position -= 1 ;
106
+ }
98
107
99
108
self . gather_context_from_node ( cursor, current_node_kind) ;
100
109
}
@@ -107,9 +116,7 @@ impl<'a> CompletionContext<'a> {
107
116
let current_node = cursor. node ( ) ;
108
117
let current_node_kind = current_node. kind ( ) ;
109
118
110
- println ! ( "inside.." ) ;
111
- dbg ! ( current_node_kind) ;
112
-
119
+ // prevent infinite recursion – this can happen if we only have a PROGRAM node
113
120
if current_node_kind == previous_node_kind {
114
121
self . ts_node = Some ( current_node) ;
115
122
return ;
@@ -138,9 +145,14 @@ impl<'a> CompletionContext<'a> {
138
145
self . wrapping_clause_type = "where" . try_into ( ) . ok ( ) ;
139
146
}
140
147
148
+ "keyword_from" => {
149
+ self . wrapping_clause_type = "keyword_from" . try_into ( ) . ok ( ) ;
150
+ }
151
+
141
152
_ => { }
142
153
}
143
154
155
+ // We have arrived at the leaf node
144
156
if current_node. child_count ( ) == 0 {
145
157
self . ts_node = Some ( current_node) ;
146
158
return ;
@@ -153,7 +165,10 @@ impl<'a> CompletionContext<'a> {
153
165
154
166
#[ cfg( test) ]
155
167
mod tests {
156
- use crate :: { context:: CompletionContext , test_helper:: CURSOR_POS } ;
168
+ use crate :: {
169
+ context:: { ClauseType , CompletionContext } ,
170
+ test_helper:: { get_text_and_position, CURSOR_POS } ,
171
+ } ;
157
172
158
173
fn get_tree ( input : & str ) -> tree_sitter:: Tree {
159
174
let mut parser = tree_sitter:: Parser :: new ( ) ;
@@ -193,11 +208,11 @@ mod tests {
193
208
) ,
194
209
] ;
195
210
196
- for ( text, expected_clause) in test_cases {
197
- let position = text. find ( CURSOR_POS ) . unwrap ( ) ;
198
- let text = text. replace ( CURSOR_POS , "" ) ;
211
+ for ( query, expected_clause) in test_cases {
212
+ let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
199
213
200
214
let tree = get_tree ( text. as_str ( ) ) ;
215
+
201
216
let params = crate :: CompletionParams {
202
217
position : ( position as u32 ) . into ( ) ,
203
218
text : text,
@@ -226,9 +241,8 @@ mod tests {
226
241
( format!( "Select * from u{}sers()" , CURSOR_POS ) , None ) ,
227
242
] ;
228
243
229
- for ( text, expected_schema) in test_cases {
230
- let position = text. find ( CURSOR_POS ) . unwrap ( ) ;
231
- let text = text. replace ( CURSOR_POS , "" ) ;
244
+ for ( query, expected_schema) in test_cases {
245
+ let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
232
246
233
247
let tree = get_tree ( text. as_str ( ) ) ;
234
248
let params = crate :: CompletionParams {
@@ -261,9 +275,8 @@ mod tests {
261
275
) ,
262
276
] ;
263
277
264
- for ( text, is_invocation) in test_cases {
265
- let position = text. find ( CURSOR_POS ) . unwrap ( ) ;
266
- let text = text. replace ( CURSOR_POS , "" ) ;
278
+ for ( query, is_invocation) in test_cases {
279
+ let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
267
280
268
281
let tree = get_tree ( text. as_str ( ) ) ;
269
282
let params = crate :: CompletionParams {
@@ -280,11 +293,42 @@ mod tests {
280
293
}
281
294
282
295
#[ test]
283
- fn get_ts_node_content_does_not_fail_on_error_nodes ( ) {
284
- let query = format ! ( "select * from {}" , CURSOR_POS ) ;
296
+ fn does_not_fail_on_leading_whitespace ( ) {
297
+ let cases = vec ! [
298
+ format!( "{} select * from" , CURSOR_POS ) ,
299
+ format!( " {} select * from" , CURSOR_POS ) ,
300
+ ] ;
301
+
302
+ for query in cases {
303
+ let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
285
304
286
- let position = query. find ( CURSOR_POS ) . unwrap ( ) ;
287
- let text = query. replace ( CURSOR_POS , "" ) ;
305
+ let tree = get_tree ( text. as_str ( ) ) ;
306
+
307
+ let params = crate :: CompletionParams {
308
+ position : ( position as u32 ) . into ( ) ,
309
+ text : text,
310
+ tree : Some ( & tree) ,
311
+ schema : & pg_schema_cache:: SchemaCache :: new ( ) ,
312
+ } ;
313
+
314
+ let ctx = CompletionContext :: new ( & params) ;
315
+
316
+ let node = ctx. ts_node . map ( |n| n. clone ( ) ) . unwrap ( ) ;
317
+
318
+ assert_eq ! ( ctx. get_ts_node_content( node) , Some ( "select" ) ) ;
319
+
320
+ assert_eq ! (
321
+ ctx. wrapping_clause_type,
322
+ Some ( crate :: context:: ClauseType :: Select )
323
+ ) ;
324
+ }
325
+ }
326
+
327
+ #[ test]
328
+ fn does_not_fail_on_trailing_whitespace ( ) {
329
+ let query = format ! ( "select * from {}" , CURSOR_POS ) ;
330
+
331
+ let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
288
332
289
333
let tree = get_tree ( text. as_str ( ) ) ;
290
334
@@ -297,10 +341,60 @@ mod tests {
297
341
298
342
let ctx = CompletionContext :: new ( & params) ;
299
343
300
- let node = ctx. ts_node . map ( |n| n. clone ( ) ) ;
344
+ let node = ctx. ts_node . map ( |n| n. clone ( ) ) . unwrap ( ) ;
345
+
346
+ assert_eq ! ( ctx. get_ts_node_content( node) , Some ( "from" ) ) ;
347
+ assert_eq ! (
348
+ ctx. wrapping_clause_type,
349
+ Some ( crate :: context:: ClauseType :: From )
350
+ ) ;
351
+ }
352
+
353
+ #[ test]
354
+ fn does_not_fail_with_empty_statements ( ) {
355
+ let query = format ! ( "{}" , CURSOR_POS ) ;
356
+
357
+ let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
358
+
359
+ let tree = get_tree ( text. as_str ( ) ) ;
360
+
361
+ let params = crate :: CompletionParams {
362
+ position : ( position as u32 ) . into ( ) ,
363
+ text : text,
364
+ tree : Some ( & tree) ,
365
+ schema : & pg_schema_cache:: SchemaCache :: new ( ) ,
366
+ } ;
367
+
368
+ let ctx = CompletionContext :: new ( & params) ;
369
+
370
+ let node = ctx. ts_node . map ( |n| n. clone ( ) ) . unwrap ( ) ;
371
+
372
+ assert_eq ! ( ctx. get_ts_node_content( node) , Some ( "" ) ) ;
373
+ assert_eq ! ( ctx. wrapping_clause_type, None ) ;
374
+ }
375
+
376
+ #[ test]
377
+ fn does_not_fail_on_incomplete_keywords ( ) {
378
+ // Instead of autocompleting "FROM", we'll assume that the user
379
+ // is selecting a certain column name, such as `frozen_account`.
380
+ let query = format ! ( "select * fro{}" , CURSOR_POS ) ;
381
+
382
+ let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
383
+
384
+ let tree = get_tree ( text. as_str ( ) ) ;
385
+
386
+ let params = crate :: CompletionParams {
387
+ position : ( position as u32 ) . into ( ) ,
388
+ text : text,
389
+ tree : Some ( & tree) ,
390
+ schema : & pg_schema_cache:: SchemaCache :: new ( ) ,
391
+ } ;
392
+
393
+ let ctx = CompletionContext :: new ( & params) ;
301
394
302
- println ! ( " node kind: {}" , node . as_ref ( ) . unwrap( ) . kind ( ) ) ;
395
+ let node = ctx . ts_node . map ( |n| n . clone ( ) ) . unwrap ( ) ;
303
396
304
- ctx. get_ts_node_content ( node. unwrap ( ) ) ;
397
+ assert_eq ! ( ctx. get_ts_node_content( node) , Some ( "fro" ) ) ;
398
+ assert_eq ! ( ctx. wrapping_clause_type, Some ( ClauseType :: Select ) ) ;
305
399
}
306
400
}
0 commit comments