8
8
from codegen .sdk .core .expressions .placeholder_type import PlaceholderType
9
9
from codegen .sdk .core .expressions .value import Value
10
10
from codegen .sdk .core .statements .symbol_statement import SymbolStatement
11
+ from codegen .sdk .extensions .utils import find_all_descendants , find_first_descendant
11
12
from codegen .sdk .utils import find_first_function_descendant
12
13
13
14
if TYPE_CHECKING :
@@ -80,6 +81,42 @@ def parse_expression(self, node: TSNode | None, file_node_id: NodeId, ctx: Codeb
80
81
ret .children
81
82
return ret
82
83
84
+ def get_import_node (self , node : TSNode ) -> TSNode | None :
85
+ """Get the import node from a node that may contain an import.
86
+ Returns None if the node does not contain an import.
87
+
88
+ Returns:
89
+ TSNode | None: The import_statement or call_expression node if it's an import, None otherwise
90
+ """
91
+ # Static imports
92
+ if node .type == "import_statement" :
93
+ return node
94
+
95
+ # Dynamic imports and requires can be either:
96
+ # 1. Inside expression_statement -> call_expression
97
+ # 2. Direct call_expression
98
+
99
+ # we only parse imports inside expressions and variable declarations
100
+ call_expression = find_first_descendant (node , ["call_expression" ])
101
+ if member_expression := find_first_descendant (node , ["member_expression" ]):
102
+ # there may be multiple call expressions (for cases such as import(a).then(module => module).then(module => module)
103
+ descendants = find_all_descendants (member_expression , ["call_expression" ])
104
+ if descendants :
105
+ import_node = descendants [- 1 ]
106
+ else :
107
+ # this means this is NOT a dynamic import()
108
+ return None
109
+ else :
110
+ import_node = call_expression
111
+
112
+ # thus we only consider the deepest one
113
+ if import_node :
114
+ function = import_node .child_by_field_name ("function" )
115
+ if function and (function .type == "import" or (function .type == "identifier" and function .text .decode ("utf-8" ) == "require" )):
116
+ return import_node
117
+
118
+ return None
119
+
83
120
def log_unparsed (self , node : TSNode ) -> None :
84
121
if self ._should_log and node .is_named and node .type not in self ._uncovered_nodes :
85
122
self ._uncovered_nodes .add (node .type )
@@ -108,6 +145,7 @@ def parse_ts_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC
108
145
from codegen .sdk .typescript .statements .comment import TSComment
109
146
from codegen .sdk .typescript .statements .for_loop_statement import TSForLoopStatement
110
147
from codegen .sdk .typescript .statements .if_block_statement import TSIfBlockStatement
148
+ from codegen .sdk .typescript .statements .import_statement import TSImportStatement
111
149
from codegen .sdk .typescript .statements .labeled_statement import TSLabeledStatement
112
150
from codegen .sdk .typescript .statements .switch_statement import TSSwitchStatement
113
151
from codegen .sdk .typescript .statements .try_catch_statement import TSTryCatchStatement
@@ -117,11 +155,13 @@ def parse_ts_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC
117
155
118
156
if node .type in self .expressions or node .type == "expression_statement" :
119
157
return [ExpressionStatement (node , file_node_id , ctx , parent , 0 , expression_node = node )]
158
+
120
159
for child in node .named_children :
121
160
# =====[ Functions + Methods ]=====
122
161
if child .type in _VALID_TYPE_NAMES :
123
162
statements .append (SymbolStatement (child , file_node_id , ctx , parent , len (statements )))
124
-
163
+ elif child .type == "import_statement" :
164
+ statements .append (TSImportStatement (child , file_node_id , ctx , parent , len (statements )))
125
165
# =====[ Classes ]=====
126
166
elif child .type in ("class_declaration" , "abstract_class_declaration" ):
127
167
statements .append (SymbolStatement (child , file_node_id , ctx , parent , len (statements )))
@@ -132,7 +172,10 @@ def parse_ts_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC
132
172
133
173
# =====[ Type Alias Declarations ]=====
134
174
elif child .type == "type_alias_declaration" :
135
- statements .append (SymbolStatement (child , file_node_id , ctx , parent , len (statements )))
175
+ if import_node := self .get_import_node (child ):
176
+ statements .append (TSImportStatement (import_node , file_node_id , ctx , parent , len (statements )))
177
+ else :
178
+ statements .append (SymbolStatement (child , file_node_id , ctx , parent , len (statements )))
136
179
137
180
# =====[ Enum Declarations ]=====
138
181
elif child .type == "enum_declaration" :
@@ -142,10 +185,10 @@ def parse_ts_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC
142
185
elif child .type == "export_statement" or child .text .decode ("utf-8" ) == "export *;" :
143
186
statements .append (ExportStatement (child , file_node_id , ctx , parent , len (statements )))
144
187
145
- # =====[ Imports ] =====
146
- elif child .type == "import_statement" :
147
- # statements.append(TSImportStatement(child, file_node_id, ctx, parent, len(statements)))
148
- pass # Temporarily opting to identify all imports using find_all_descendants
188
+ # # =====[ Imports ] =====
189
+ # elif child.type == "import_statement":
190
+ # # statements.append(TSImportStatement(child, file_node_id, ctx, parent, len(statements)))
191
+ # pass # Temporarily opting to identify all imports using find_all_descendants
149
192
150
193
# =====[ Non-symbol statements ] =====
151
194
elif child .type == "comment" :
@@ -167,6 +210,8 @@ def parse_ts_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC
167
210
elif child .type in ["lexical_declaration" , "variable_declaration" ]:
168
211
if function_node := find_first_function_descendant (child ):
169
212
statements .append (SymbolStatement (child , file_node_id , ctx , parent , len (statements ), function_node ))
213
+ elif import_node := self .get_import_node (child ):
214
+ statements .append (TSImportStatement (import_node , file_node_id , ctx , parent , len (statements )))
170
215
else :
171
216
statements .append (
172
217
TSAssignmentStatement .from_assignment (
@@ -176,6 +221,10 @@ def parse_ts_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC
176
221
elif child .type in ["public_field_definition" , "property_signature" , "enum_assignment" ]:
177
222
statements .append (TSAttribute (child , file_node_id , ctx , parent , pos = len (statements )))
178
223
elif child .type == "expression_statement" :
224
+ if import_node := self .get_import_node (child ):
225
+ statements .append (TSImportStatement (import_node , file_node_id , ctx , parent , pos = len (statements )))
226
+ continue
227
+
179
228
for var in child .named_children :
180
229
if var .type == "string" :
181
230
statements .append (TSComment .from_code_block (var , parent , pos = len (statements )))
@@ -185,7 +234,6 @@ def parse_ts_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC
185
234
statements .append (ExpressionStatement (child , file_node_id , ctx , parent , pos = len (statements ), expression_node = var ))
186
235
elif child .type in self .expressions :
187
236
statements .append (ExpressionStatement (child , file_node_id , ctx , parent , len (statements ), expression_node = child ))
188
-
189
237
else :
190
238
self .log ("Couldn't parse statement with type: %s" , child .type )
191
239
statements .append (Statement .from_code_block (child , parent , pos = len (statements )))
@@ -204,6 +252,7 @@ def parse_py_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC
204
252
from codegen .sdk .python .statements .comment import PyComment
205
253
from codegen .sdk .python .statements .for_loop_statement import PyForLoopStatement
206
254
from codegen .sdk .python .statements .if_block_statement import PyIfBlockStatement
255
+ from codegen .sdk .python .statements .import_statement import PyImportStatement
207
256
from codegen .sdk .python .statements .match_statement import PyMatchStatement
208
257
from codegen .sdk .python .statements .pass_statement import PyPassStatement
209
258
from codegen .sdk .python .statements .try_catch_statement import PyTryCatchStatement
@@ -237,9 +286,7 @@ def parse_py_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC
237
286
238
287
# =====[ Imports ] =====
239
288
elif child .type in ["import_statement" , "import_from_statement" , "future_import_statement" ]:
240
- # statements.append(PyImportStatement(child, file_node_id, ctx, parent, len(statements)))
241
- pass # Temporarily opting to identify all imports using find_all_descendants
242
-
289
+ statements .append (PyImportStatement (child , file_node_id , ctx , parent , len (statements )))
243
290
# =====[ Non-symbol statements ] =====
244
291
elif child .type == "comment" :
245
292
statements .append (PyComment .from_code_block (child , parent , pos = len (statements )))
0 commit comments