Skip to content

Commit 2f31a92

Browse files
tawsifkamalcodegen-bot
and
codegen-bot
authored
Import Parsing (#664)
- parses imports now during tree-sitter code_block parsing in `parser.py` - Deals with top-level + dynamic imports - for typescript dynamic imports, it will search for imports by using a mixture of `get_all_descendants` and `get_first_descendant` calls to search for the appropriate call_expression for imports - is_wrapped_in functionality now works! - added parent_of_types, changed is_dynamic to now use the correct strategy instead of traversing up tree sitter nodes CG-10504 --------- Co-authored-by: codegen-bot <[email protected]>
1 parent a7e9325 commit 2f31a92

File tree

19 files changed

+288
-136
lines changed

19 files changed

+288
-136
lines changed

src/codegen/sdk/codebase/node_classes/node_classes.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from codegen.sdk.core.file import SourceFile
1717
from codegen.sdk.core.function import Function
1818
from codegen.sdk.core.import_resolution import Import
19+
from codegen.sdk.core.interfaces.editable import Editable
1920
from codegen.sdk.core.statements.comment import Comment
2021
from codegen.sdk.core.symbol import Symbol
2122

@@ -33,7 +34,7 @@ class NodeClasses:
3334
function_call_cls: type[FunctionCall]
3435
comment_cls: type[Comment]
3536
bool_conversion: dict[bool, str]
36-
dynamic_import_parent_types: set[str]
37+
dynamic_import_parent_types: set[type[Editable]]
3738
symbol_map: dict[str, type[Symbol]] = field(default_factory=dict)
3839
expression_map: dict[str, type[Expression]] = field(default_factory=dict)
3940
type_map: dict[str, type[Type] | dict[str, type[Type]]] = field(default_factory=dict)

src/codegen/sdk/codebase/node_classes/py_node_classes.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,13 @@
1414
from codegen.sdk.core.expressions.subscript_expression import SubscriptExpression
1515
from codegen.sdk.core.expressions.unary_expression import UnaryExpression
1616
from codegen.sdk.core.expressions.unpack import Unpack
17+
from codegen.sdk.core.function import Function
1718
from codegen.sdk.core.statements.comment import Comment
19+
from codegen.sdk.core.statements.for_loop_statement import ForLoopStatement
20+
from codegen.sdk.core.statements.if_block_statement import IfBlockStatement
21+
from codegen.sdk.core.statements.switch_statement import SwitchStatement
22+
from codegen.sdk.core.statements.try_catch_statement import TryCatchStatement
23+
from codegen.sdk.core.statements.while_statement import WhileStatement
1824
from codegen.sdk.core.symbol_groups.dict import Dict
1925
from codegen.sdk.core.symbol_groups.list import List
2026
from codegen.sdk.core.symbol_groups.tuple import Tuple
@@ -29,6 +35,8 @@
2935
from codegen.sdk.python.expressions.string import PyString
3036
from codegen.sdk.python.expressions.union_type import PyUnionType
3137
from codegen.sdk.python.statements.import_statement import PyImportStatement
38+
from codegen.sdk.python.statements.match_case import PyMatchCase
39+
from codegen.sdk.python.statements.with_statement import WithStatement
3240

3341

3442
def parse_subscript(node: TSNode, file_node_id, ctx, parent):
@@ -110,16 +118,13 @@ def parse_subscript(node: TSNode, file_node_id, ctx, parent):
110118
False: "False",
111119
},
112120
dynamic_import_parent_types={
113-
"function_definition",
114-
"if_statement",
115-
"try_statement",
116-
"with_statement",
117-
"else_clause",
118-
"for_statement",
119-
"except_clause",
120-
"while_statement",
121-
"match_statement",
122-
"case_clause",
123-
"finally_clause",
121+
Function,
122+
IfBlockStatement,
123+
TryCatchStatement,
124+
WithStatement,
125+
ForLoopStatement,
126+
WhileStatement,
127+
SwitchStatement,
128+
PyMatchCase,
124129
},
125130
)

src/codegen/sdk/codebase/node_classes/ts_node_classes.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,14 @@
1515
from codegen.sdk.core.expressions.unary_expression import UnaryExpression
1616
from codegen.sdk.core.expressions.unpack import Unpack
1717
from codegen.sdk.core.expressions.value import Value
18+
from codegen.sdk.core.function import Function
1819
from codegen.sdk.core.statements.comment import Comment
20+
from codegen.sdk.core.statements.for_loop_statement import ForLoopStatement
21+
from codegen.sdk.core.statements.if_block_statement import IfBlockStatement
22+
from codegen.sdk.core.statements.switch_case import SwitchCase
23+
from codegen.sdk.core.statements.switch_statement import SwitchStatement
24+
from codegen.sdk.core.statements.try_catch_statement import TryCatchStatement
25+
from codegen.sdk.core.statements.while_statement import WhileStatement
1926
from codegen.sdk.core.symbol_groups.list import List
2027
from codegen.sdk.core.symbol_groups.type_parameters import TypeParameters
2128
from codegen.sdk.typescript.class_definition import TSClass
@@ -166,18 +173,12 @@ def parse_new(node: TSNode, *args):
166173
False: "false",
167174
},
168175
dynamic_import_parent_types={
169-
"function_declaration",
170-
"method_definition",
171-
"arrow_function",
172-
"if_statement",
173-
"try_statement",
174-
"else_clause",
175-
"catch_clause",
176-
"finally_clause",
177-
"while_statement",
178-
"for_statement",
179-
"do_statement",
180-
"switch_case",
181-
"switch_statement",
176+
Function,
177+
IfBlockStatement,
178+
TryCatchStatement,
179+
ForLoopStatement,
180+
WhileStatement,
181+
SwitchStatement,
182+
SwitchCase,
182183
},
183184
)

src/codegen/sdk/core/file.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -459,15 +459,10 @@ def parse(self, ctx: CodebaseContext) -> None:
459459
self.code_block = self._parse_code_block(self.ts_node)
460460

461461
self.code_block.parse()
462-
self._parse_imports()
463462
# We need to clear the valid symbol/import names before we start resolving exports since these can be outdated.
464463
self.invalidate()
465464
sort_editables(self._nodes)
466465

467-
@abstractmethod
468-
@commiter
469-
def _parse_imports(self) -> None: ...
470-
471466
@noapidoc
472467
@commiter
473468
def remove_internal_edges(self) -> None:

src/codegen/sdk/core/import_resolution.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -428,15 +428,7 @@ def my_function():
428428
bool: True if the import is dynamic (within a control flow or scope block),
429429
False if it's a top-level import.
430430
"""
431-
curr = self.ts_node
432-
433-
# always traverses upto the module level
434-
while curr:
435-
if curr.type in self.ctx.node_classes.dynamic_import_parent_types:
436-
return True
437-
curr = curr.parent
438-
439-
return False
431+
return self.parent_of_types(self.ctx.node_classes.dynamic_import_parent_types) is not None
440432

441433
####################################################################################################################
442434
# MANIPULATIONS

src/codegen/sdk/core/interfaces/editable.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -823,7 +823,7 @@ def children_by_field_types(self, field_types: str | Iterable[str]) -> Generator
823823
@reader
824824
@noapidoc
825825
def child_by_field_types(self, field_types: str | Iterable[str]) -> Expression[Self] | None:
826-
"""Get child by field types."""
826+
"""Get child by fiexld types."""
827827
return next(self.children_by_field_types(field_types), None)
828828

829829
@property
@@ -1097,6 +1097,14 @@ def parent_of_type(self, type: type[T]) -> T | None:
10971097
return self.parent.parent_of_type(type)
10981098
return None
10991099

1100+
def parent_of_types(self, types: set[type[T]]) -> T | None:
1101+
"""Find the first ancestor of the node of the given type. Does not return itself"""
1102+
if self.parent and any(isinstance(self.parent, t) for t in types):
1103+
return self.parent
1104+
if self.parent is not self and self.parent is not None:
1105+
return self.parent.parent_of_types(types)
1106+
return None
1107+
11001108
@reader
11011109
def ancestors(self, type: type[T]) -> list[T]:
11021110
"""Find all ancestors of the node of the given type. Does not return itself"""

src/codegen/sdk/core/parser.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from codegen.sdk.core.expressions.placeholder_type import PlaceholderType
99
from codegen.sdk.core.expressions.value import Value
1010
from codegen.sdk.core.statements.symbol_statement import SymbolStatement
11-
from codegen.sdk.utils import find_first_function_descendant
11+
from codegen.sdk.utils import find_first_function_descendant, find_import_node
1212

1313
if TYPE_CHECKING:
1414
from tree_sitter import Node as TSNode
@@ -108,6 +108,7 @@ def parse_ts_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC
108108
from codegen.sdk.typescript.statements.comment import TSComment
109109
from codegen.sdk.typescript.statements.for_loop_statement import TSForLoopStatement
110110
from codegen.sdk.typescript.statements.if_block_statement import TSIfBlockStatement
111+
from codegen.sdk.typescript.statements.import_statement import TSImportStatement
111112
from codegen.sdk.typescript.statements.labeled_statement import TSLabeledStatement
112113
from codegen.sdk.typescript.statements.switch_statement import TSSwitchStatement
113114
from codegen.sdk.typescript.statements.try_catch_statement import TSTryCatchStatement
@@ -117,11 +118,13 @@ def parse_ts_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC
117118

118119
if node.type in self.expressions or node.type == "expression_statement":
119120
return [ExpressionStatement(node, file_node_id, ctx, parent, 0, expression_node=node)]
121+
120122
for child in node.named_children:
121123
# =====[ Functions + Methods ]=====
122124
if child.type in _VALID_TYPE_NAMES:
123125
statements.append(SymbolStatement(child, file_node_id, ctx, parent, len(statements)))
124-
126+
elif child.type == "import_statement":
127+
statements.append(TSImportStatement(child, file_node_id, ctx, parent, len(statements)))
125128
# =====[ Classes ]=====
126129
elif child.type in ("class_declaration", "abstract_class_declaration"):
127130
statements.append(SymbolStatement(child, file_node_id, ctx, parent, len(statements)))
@@ -132,7 +135,10 @@ def parse_ts_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC
132135

133136
# =====[ Type Alias Declarations ]=====
134137
elif child.type == "type_alias_declaration":
135-
statements.append(SymbolStatement(child, file_node_id, ctx, parent, len(statements)))
138+
if import_node := find_import_node(child):
139+
statements.append(TSImportStatement(child, file_node_id, ctx, parent, len(statements), source_node=import_node))
140+
else:
141+
statements.append(SymbolStatement(child, file_node_id, ctx, parent, len(statements)))
136142

137143
# =====[ Enum Declarations ]=====
138144
elif child.type == "enum_declaration":
@@ -142,11 +148,6 @@ def parse_ts_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC
142148
elif child.type == "export_statement" or child.text.decode("utf-8") == "export *;":
143149
statements.append(ExportStatement(child, file_node_id, ctx, parent, len(statements)))
144150

145-
# =====[ Imports ] =====
146-
elif child.type == "import_statement":
147-
# statements.append(TSImportStatement(child, file_node_id, ctx, parent, len(statements)))
148-
pass # Temporarily opting to identify all imports using find_all_descendants
149-
150151
# =====[ Non-symbol statements ] =====
151152
elif child.type == "comment":
152153
statements.append(TSComment.from_code_block(child, parent, pos=len(statements)))
@@ -167,6 +168,8 @@ def parse_ts_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC
167168
elif child.type in ["lexical_declaration", "variable_declaration"]:
168169
if function_node := find_first_function_descendant(child):
169170
statements.append(SymbolStatement(child, file_node_id, ctx, parent, len(statements), function_node))
171+
elif import_node := find_import_node(child):
172+
statements.append(TSImportStatement(child, file_node_id, ctx, parent, len(statements), source_node=import_node))
170173
else:
171174
statements.append(
172175
TSAssignmentStatement.from_assignment(
@@ -176,6 +179,10 @@ def parse_ts_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC
176179
elif child.type in ["public_field_definition", "property_signature", "enum_assignment"]:
177180
statements.append(TSAttribute(child, file_node_id, ctx, parent, pos=len(statements)))
178181
elif child.type == "expression_statement":
182+
if import_node := find_import_node(child):
183+
statements.append(TSImportStatement(child, file_node_id, ctx, parent, pos=len(statements), source_node=import_node))
184+
continue
185+
179186
for var in child.named_children:
180187
if var.type == "string":
181188
statements.append(TSComment.from_code_block(var, parent, pos=len(statements)))
@@ -185,7 +192,6 @@ def parse_ts_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC
185192
statements.append(ExpressionStatement(child, file_node_id, ctx, parent, pos=len(statements), expression_node=var))
186193
elif child.type in self.expressions:
187194
statements.append(ExpressionStatement(child, file_node_id, ctx, parent, len(statements), expression_node=child))
188-
189195
else:
190196
self.log("Couldn't parse statement with type: %s", child.type)
191197
statements.append(Statement.from_code_block(child, parent, pos=len(statements)))
@@ -204,6 +210,7 @@ def parse_py_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC
204210
from codegen.sdk.python.statements.comment import PyComment
205211
from codegen.sdk.python.statements.for_loop_statement import PyForLoopStatement
206212
from codegen.sdk.python.statements.if_block_statement import PyIfBlockStatement
213+
from codegen.sdk.python.statements.import_statement import PyImportStatement
207214
from codegen.sdk.python.statements.match_statement import PyMatchStatement
208215
from codegen.sdk.python.statements.pass_statement import PyPassStatement
209216
from codegen.sdk.python.statements.try_catch_statement import PyTryCatchStatement
@@ -237,9 +244,7 @@ def parse_py_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC
237244

238245
# =====[ Imports ] =====
239246
elif child.type in ["import_statement", "import_from_statement", "future_import_statement"]:
240-
# statements.append(PyImportStatement(child, file_node_id, ctx, parent, len(statements)))
241-
pass # Temporarily opting to identify all imports using find_all_descendants
242-
247+
statements.append(PyImportStatement(child, file_node_id, ctx, parent, len(statements)))
243248
# =====[ Non-symbol statements ] =====
244249
elif child.type == "comment":
245250
statements.append(PyComment.from_code_block(child, parent, pos=len(statements)))

src/codegen/sdk/extensions/utils.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def find_all_descendants(
1313
type_names: Iterable[str] | str,
1414
max_depth: int | None = None,
1515
nested: bool = True,
16+
stop_at_first: str | None = None,
1617
) -> list[TSNode]: ...
1718
def find_line_start_and_end_nodes(node: TSNode) -> list[tuple[TSNode, TSNode]]:
1819
"""Returns a list of tuples of the start and end nodes of each line in the node"""

src/codegen/sdk/extensions/utils.pyx

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def get_all_identifiers(node: TSNode) -> list[TSNode]:
3131
return sorted(dict.fromkeys(identifiers), key=lambda x: x.start_byte)
3232

3333

34-
def find_all_descendants(node: TSNode, type_names: Iterable[str] | str, max_depth: int | None = None, nested: bool = True) -> list[TSNode]:
34+
def find_all_descendants(node: TSNode, type_names: Iterable[str] | str, max_depth: int | None = None, nested: bool = True, stop_at_first: str | None = None) -> list[TSNode]:
3535
if isinstance(type_names, str):
3636
type_names = [type_names]
3737
descendants = []
@@ -45,6 +45,9 @@ def find_all_descendants(node: TSNode, type_names: Iterable[str] | str, max_dept
4545
if not nested and current_node != node:
4646
return
4747

48+
if stop_at_first and current_node.type == stop_at_first:
49+
return
50+
4851
for child in current_node.children:
4952
traverse(child, depth + 1)
5053

src/codegen/sdk/python/file.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
from typing import TYPE_CHECKING
44

5-
from codegen.sdk.core.autocommit import commiter, reader, writer
5+
from codegen.sdk.core.autocommit import reader, writer
66
from codegen.sdk.core.file import SourceFile
77
from codegen.sdk.core.interface import Interface
88
from codegen.sdk.enums import ImportType
9-
from codegen.sdk.extensions.utils import cached_property, iter_all_descendants
9+
from codegen.sdk.extensions.utils import cached_property
1010
from codegen.sdk.python import PyAssignment
1111
from codegen.sdk.python.class_definition import PyClass
1212
from codegen.sdk.python.detached_symbols.code_block import PyCodeBlock
@@ -15,7 +15,6 @@
1515
from codegen.sdk.python.import_resolution import PyImport
1616
from codegen.sdk.python.interfaces.has_block import PyHasBlock
1717
from codegen.sdk.python.statements.attribute import PyAttribute
18-
from codegen.sdk.python.statements.import_statement import PyImportStatement
1918
from codegen.shared.decorators.docs import noapidoc, py_apidoc
2019
from codegen.shared.enums.programming_language import ProgrammingLanguage
2120

@@ -59,12 +58,6 @@ def symbol_can_be_added(self, symbol: PySymbol) -> bool:
5958
"""
6059
return True
6160

62-
@noapidoc
63-
@commiter
64-
def _parse_imports(self) -> None:
65-
for import_node in iter_all_descendants(self.ts_node, frozenset({"import_statement", "import_from_statement", "future_import_statement"})):
66-
PyImportStatement(import_node, self.node_id, self.ctx, self.code_block, 0)
67-
6861
####################################################################################################################
6962
# GETTERS
7063
####################################################################################################################

src/codegen/sdk/typescript/file.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import os
44
from typing import TYPE_CHECKING
55

6-
from codegen.sdk.core.autocommit import commiter, mover, reader, writer
6+
from codegen.sdk.core.autocommit import mover, reader, writer
77
from codegen.sdk.core.file import SourceFile
88
from codegen.sdk.core.interfaces.exportable import Exportable
99
from codegen.sdk.enums import ImportType, NodeType, SymbolType
@@ -18,8 +18,7 @@
1818
from codegen.sdk.typescript.interface import TSInterface
1919
from codegen.sdk.typescript.interfaces.has_block import TSHasBlock
2020
from codegen.sdk.typescript.namespace import TSNamespace
21-
from codegen.sdk.typescript.statements.import_statement import TSImportStatement
22-
from codegen.sdk.utils import calculate_base_path, find_all_descendants
21+
from codegen.sdk.utils import calculate_base_path
2322
from codegen.shared.decorators.docs import noapidoc, ts_apidoc
2423
from codegen.shared.enums.programming_language import ProgrammingLanguage
2524

@@ -228,18 +227,6 @@ def add_export_to_symbol(self, symbol: TSSymbol) -> None:
228227
# TODO: this should be in symbol.py class. Rename as `add_export`
229228
symbol.add_keyword("export")
230229

231-
@noapidoc
232-
@commiter
233-
def _parse_imports(self) -> None:
234-
import_nodes = find_all_descendants(self.ts_node, {"import_statement", "call_expression"})
235-
for import_node in import_nodes:
236-
if import_node.type == "import_statement":
237-
TSImportStatement(import_node, self.node_id, self.ctx, self.code_block, 0)
238-
elif import_node.type == "call_expression":
239-
function = import_node.child_by_field_name("function")
240-
if function.type == "import" or (function.type == "identifier" and function.text.decode("utf-8") == "require"):
241-
TSImportStatement(import_node, self.node_id, self.ctx, self.code_block, 0)
242-
243230
@writer
244231
def remove_unused_exports(self) -> None:
245232
"""Removes unused exports from the file.

src/codegen/sdk/typescript/import_resolution.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -451,7 +451,10 @@ def from_dynamic_import_statement(cls, import_call_node: TSNode, module_node: TS
451451
return imports
452452

453453
# If import statement is a variable declaration, capture the variable scoping keyword (const, let, var, etc)
454-
statement_node = import_statement_node.parent if import_statement_node.type in ["variable_declarator", "assignment_expression"] else import_statement_node
454+
if import_statement_node.type == "lexical_declaration":
455+
statement_node = import_statement_node
456+
else:
457+
statement_node = import_statement_node.parent if import_statement_node.type in ["variable_declarator", "assignment_expression"] else import_statement_node
455458

456459
# ==== [ Named dynamic import ] ====
457460
if name_node.type == "property_identifier":

0 commit comments

Comments
 (0)