Skip to content

Commit d377327

Browse files
author
codegen-bot
committed
done
1 parent ccf7a26 commit d377327

File tree

14 files changed

+289
-146
lines changed

14 files changed

+289
-146
lines changed

src/codegen/sdk/codebase/node_classes/node_classes.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from codegen.sdk.core.file import SourceFile
1717
from codegen.sdk.core.function import Function
1818
from codegen.sdk.core.import_resolution import Import
19+
from codegen.sdk.core.interfaces.editable import Editable
1920
from codegen.sdk.core.statements.comment import Comment
2021
from codegen.sdk.core.symbol import Symbol
2122

@@ -33,7 +34,7 @@ class NodeClasses:
3334
function_call_cls: type[FunctionCall]
3435
comment_cls: type[Comment]
3536
bool_conversion: dict[bool, str]
36-
dynamic_import_parent_types: set[str]
37+
dynamic_import_parent_types: set[type[Editable]]
3738
symbol_map: dict[str, type[Symbol]] = field(default_factory=dict)
3839
expression_map: dict[str, type[Expression]] = field(default_factory=dict)
3940
type_map: dict[str, type[Type] | dict[str, type[Type]]] = field(default_factory=dict)

src/codegen/sdk/codebase/node_classes/py_node_classes.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,13 @@
1414
from codegen.sdk.core.expressions.subscript_expression import SubscriptExpression
1515
from codegen.sdk.core.expressions.unary_expression import UnaryExpression
1616
from codegen.sdk.core.expressions.unpack import Unpack
17+
from codegen.sdk.core.function import Function
1718
from codegen.sdk.core.statements.comment import Comment
19+
from codegen.sdk.core.statements.for_loop_statement import ForLoopStatement
20+
from codegen.sdk.core.statements.if_block_statement import IfBlockStatement
21+
from codegen.sdk.core.statements.switch_statement import SwitchStatement
22+
from codegen.sdk.core.statements.try_catch_statement import TryCatchStatement
23+
from codegen.sdk.core.statements.while_statement import WhileStatement
1824
from codegen.sdk.core.symbol_groups.dict import Dict
1925
from codegen.sdk.core.symbol_groups.list import List
2026
from codegen.sdk.core.symbol_groups.tuple import Tuple
@@ -29,6 +35,8 @@
2935
from codegen.sdk.python.expressions.string import PyString
3036
from codegen.sdk.python.expressions.union_type import PyUnionType
3137
from codegen.sdk.python.statements.import_statement import PyImportStatement
38+
from codegen.sdk.python.statements.match_case import PyMatchCase
39+
from codegen.sdk.python.statements.with_statement import WithStatement
3240

3341

3442
def parse_subscript(node: TSNode, file_node_id, ctx, parent):
@@ -110,16 +118,13 @@ def parse_subscript(node: TSNode, file_node_id, ctx, parent):
110118
False: "False",
111119
},
112120
dynamic_import_parent_types={
113-
"function_definition",
114-
"if_statement",
115-
"try_statement",
116-
"with_statement",
117-
"else_clause",
118-
"for_statement",
119-
"except_clause",
120-
"while_statement",
121-
"match_statement",
122-
"case_clause",
123-
"finally_clause",
121+
Function,
122+
IfBlockStatement,
123+
TryCatchStatement,
124+
WithStatement,
125+
ForLoopStatement,
126+
WhileStatement,
127+
SwitchStatement,
128+
PyMatchCase,
124129
},
125130
)

src/codegen/sdk/codebase/node_classes/ts_node_classes.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,14 @@
1515
from codegen.sdk.core.expressions.unary_expression import UnaryExpression
1616
from codegen.sdk.core.expressions.unpack import Unpack
1717
from codegen.sdk.core.expressions.value import Value
18+
from codegen.sdk.core.function import Function
1819
from codegen.sdk.core.statements.comment import Comment
20+
from codegen.sdk.core.statements.for_loop_statement import ForLoopStatement
21+
from codegen.sdk.core.statements.if_block_statement import IfBlockStatement
22+
from codegen.sdk.core.statements.switch_case import SwitchCase
23+
from codegen.sdk.core.statements.switch_statement import SwitchStatement
24+
from codegen.sdk.core.statements.try_catch_statement import TryCatchStatement
25+
from codegen.sdk.core.statements.while_statement import WhileStatement
1926
from codegen.sdk.core.symbol_groups.list import List
2027
from codegen.sdk.core.symbol_groups.type_parameters import TypeParameters
2128
from codegen.sdk.typescript.class_definition import TSClass
@@ -166,18 +173,12 @@ def parse_new(node: TSNode, *args):
166173
False: "false",
167174
},
168175
dynamic_import_parent_types={
169-
"function_declaration",
170-
"method_definition",
171-
"arrow_function",
172-
"if_statement",
173-
"try_statement",
174-
"else_clause",
175-
"catch_clause",
176-
"finally_clause",
177-
"while_statement",
178-
"for_statement",
179-
"do_statement",
180-
"switch_case",
181-
"switch_statement",
176+
Function,
177+
IfBlockStatement,
178+
TryCatchStatement,
179+
ForLoopStatement,
180+
WhileStatement,
181+
SwitchStatement,
182+
SwitchCase,
182183
},
183184
)

src/codegen/sdk/core/codebase.py

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import logging
66
import os
77
import re
8+
import tempfile
89
from collections.abc import Generator
910
from contextlib import contextmanager
1011
from functools import cached_property
@@ -1333,19 +1334,16 @@ def from_string(
13331334
prog_lang = ProgrammingLanguage(language.upper()) if isinstance(language, str) else language
13341335
filename = "test.ts" if prog_lang == ProgrammingLanguage.TYPESCRIPT else "test.py"
13351336

1336-
# Create temporary directory
1337-
import tempfile
1337+
with tempfile.TemporaryDirectory(prefix="codegen_") as tmp_dir:
1338+
logger.info(f"Using directory: {tmp_dir}")
13381339

1339-
tmp_dir = tempfile.mkdtemp(prefix="codegen_")
1340-
logger.info(f"Using directory: {tmp_dir}")
1341-
1342-
# Create codebase using factory
1343-
from codegen.sdk.codebase.factory.codebase_factory import CodebaseFactory
1340+
# Create codebase using factory
1341+
from codegen.sdk.codebase.factory.codebase_factory import CodebaseFactory
13441342

1345-
files = {filename: code}
1346-
codebase = CodebaseFactory.get_codebase_from_files(repo_path=tmp_dir, files=files, programming_language=prog_lang)
1347-
logger.info("Codebase initialization complete")
1348-
return codebase
1343+
files = {filename: code}
1344+
codebase = CodebaseFactory.get_codebase_from_files(repo_path=tmp_dir, files=files, programming_language=prog_lang)
1345+
logger.info("Codebase initialization complete")
1346+
return codebase
13491347

13501348
@classmethod
13511349
def from_files(
@@ -1411,18 +1409,15 @@ def from_files(
14111409
prog_lang = inferred_lang
14121410
logger.info(f"Using language: {prog_lang} ({'inferred' if language is None else 'explicit'})")
14131411

1414-
# Create temporary directory
1415-
import tempfile
1412+
with tempfile.TemporaryDirectory(prefix="codegen_") as tmp_dir:
1413+
logger.info(f"Using directory: {tmp_dir}")
14161414

1417-
tmp_dir = tempfile.mkdtemp(prefix="codegen_")
1418-
logger.info(f"Using directory: {tmp_dir}")
1419-
1420-
# Create codebase using factory
1421-
from codegen.sdk.codebase.factory.codebase_factory import CodebaseFactory
1415+
# Create codebase using factory
1416+
from codegen.sdk.codebase.factory.codebase_factory import CodebaseFactory
14221417

1423-
codebase = CodebaseFactory.get_codebase_from_files(repo_path=tmp_dir, files=files, programming_language=prog_lang)
1424-
logger.info("Codebase initialization complete")
1425-
return codebase
1418+
codebase = CodebaseFactory.get_codebase_from_files(repo_path=tmp_dir, files=files, programming_language=prog_lang)
1419+
logger.info("Codebase initialization complete")
1420+
return codebase
14261421

14271422
def get_modified_symbols_in_pr(self, pr_id: int) -> tuple[str, dict[str, str], list[str]]:
14281423
"""Get all modified symbols in a pull request"""

src/codegen/sdk/core/file.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -467,15 +467,10 @@ def parse(self, ctx: CodebaseContext) -> None:
467467
self.code_block = self._parse_code_block(self.ts_node)
468468

469469
self.code_block.parse()
470-
self._parse_imports()
471470
# We need to clear the valid symbol/import names before we start resolving exports since these can be outdated.
472471
self.invalidate()
473472
sort_editables(self._nodes)
474473

475-
@abstractmethod
476-
@commiter
477-
def _parse_imports(self) -> None: ...
478-
479474
@noapidoc
480475
@commiter
481476
def remove_internal_edges(self) -> None:

src/codegen/sdk/core/import_resolution.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -419,15 +419,7 @@ def my_function():
419419
bool: True if the import is dynamic (within a control flow or scope block),
420420
False if it's a top-level import.
421421
"""
422-
curr = self.ts_node
423-
424-
# always traverses up to the module level
425-
while curr:
426-
if curr.type in self.ctx.node_classes.dynamic_import_parent_types:
427-
return True
428-
curr = curr.parent
429-
430-
return False
422+
return self.parent_of_types(self.ctx.node_classes.dynamic_import_parent_types) is not None
431423

432424
####################################################################################################################
433425
# MANIPULATIONS

src/codegen/sdk/core/interfaces/editable.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -823,7 +823,7 @@ def children_by_field_types(self, field_types: str | Iterable[str]) -> Generator
823823
@reader
824824
@noapidoc
825825
def child_by_field_types(self, field_types: str | Iterable[str]) -> Expression[Self] | None:
826-
"""Get child by field types."""
826+
"""Get child by field types."""
827827
return next(self.children_by_field_types(field_types), None)
828828

829829
@property
@@ -1097,6 +1097,14 @@ def parent_of_type(self, type: type[T]) -> T | None:
10971097
return self.parent.parent_of_type(type)
10981098
return None
10991099

1100+
def parent_of_types(self, types: set[type[T]]) -> T | None:
1101+
"""Find the first ancestor of the node that is an instance of any of the given types. Does not return itself."""
1102+
if self.parent and any(isinstance(self.parent, t) for t in types):
1103+
return self.parent
1104+
if self.parent is not self and self.parent is not None:
1105+
return self.parent.parent_of_types(types)
1106+
return None
1107+
11001108
@reader
11011109
def ancestors(self, type: type[T]) -> list[T]:
11021110
"""Find all ancestors of the node of the given type. Does not return itself"""

src/codegen/sdk/core/parser.py

Lines changed: 57 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from codegen.sdk.core.expressions.placeholder_type import PlaceholderType
99
from codegen.sdk.core.expressions.value import Value
1010
from codegen.sdk.core.statements.symbol_statement import SymbolStatement
11+
from codegen.sdk.extensions.utils import find_all_descendants, find_first_descendant
1112
from codegen.sdk.utils import find_first_function_descendant
1213

1314
if TYPE_CHECKING:
@@ -80,6 +81,42 @@ def parse_expression(self, node: TSNode | None, file_node_id: NodeId, ctx: Codeb
8081
ret.children
8182
return ret
8283

84+
def get_import_node(self, node: TSNode) -> TSNode | None:
85+
"""Get the import node from a node that may contain an import.
86+
Returns None if the node does not contain an import.
87+
88+
Returns:
89+
TSNode | None: The import_statement or call_expression node if it's an import, None otherwise
90+
"""
91+
# Static imports
92+
if node.type == "import_statement":
93+
return node
94+
95+
# Dynamic imports and requires can be either:
96+
# 1. Inside expression_statement -> call_expression
97+
# 2. Direct call_expression
98+
99+
# we only parse imports inside expressions and variable declarations
100+
call_expression = find_first_descendant(node, ["call_expression"])
101+
if member_expression := find_first_descendant(node, ["member_expression"]):
102+
# there may be multiple call expressions (for cases such as import(a).then(module => module).then(module => module))
103+
descendants = find_all_descendants(member_expression, ["call_expression"])
104+
if descendants:
105+
import_node = descendants[-1]
106+
else:
107+
# this means this is NOT a dynamic import()
108+
return None
109+
else:
110+
import_node = call_expression
111+
112+
# thus we only consider the deepest one
113+
if import_node:
114+
function = import_node.child_by_field_name("function")
115+
if function and (function.type == "import" or (function.type == "identifier" and function.text.decode("utf-8") == "require")):
116+
return import_node
117+
118+
return None
119+
83120
def log_unparsed(self, node: TSNode) -> None:
84121
if self._should_log and node.is_named and node.type not in self._uncovered_nodes:
85122
self._uncovered_nodes.add(node.type)
@@ -108,6 +145,7 @@ def parse_ts_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC
108145
from codegen.sdk.typescript.statements.comment import TSComment
109146
from codegen.sdk.typescript.statements.for_loop_statement import TSForLoopStatement
110147
from codegen.sdk.typescript.statements.if_block_statement import TSIfBlockStatement
148+
from codegen.sdk.typescript.statements.import_statement import TSImportStatement
111149
from codegen.sdk.typescript.statements.labeled_statement import TSLabeledStatement
112150
from codegen.sdk.typescript.statements.switch_statement import TSSwitchStatement
113151
from codegen.sdk.typescript.statements.try_catch_statement import TSTryCatchStatement
@@ -117,11 +155,13 @@ def parse_ts_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC
117155

118156
if node.type in self.expressions or node.type == "expression_statement":
119157
return [ExpressionStatement(node, file_node_id, ctx, parent, 0, expression_node=node)]
158+
120159
for child in node.named_children:
121160
# =====[ Functions + Methods ]=====
122161
if child.type in _VALID_TYPE_NAMES:
123162
statements.append(SymbolStatement(child, file_node_id, ctx, parent, len(statements)))
124-
163+
elif child.type == "import_statement":
164+
statements.append(TSImportStatement(child, file_node_id, ctx, parent, len(statements)))
125165
# =====[ Classes ]=====
126166
elif child.type in ("class_declaration", "abstract_class_declaration"):
127167
statements.append(SymbolStatement(child, file_node_id, ctx, parent, len(statements)))
@@ -132,7 +172,10 @@ def parse_ts_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC
132172

133173
# =====[ Type Alias Declarations ]=====
134174
elif child.type == "type_alias_declaration":
135-
statements.append(SymbolStatement(child, file_node_id, ctx, parent, len(statements)))
175+
if import_node := self.get_import_node(child):
176+
statements.append(TSImportStatement(import_node, file_node_id, ctx, parent, len(statements)))
177+
else:
178+
statements.append(SymbolStatement(child, file_node_id, ctx, parent, len(statements)))
136179

137180
# =====[ Enum Declarations ]=====
138181
elif child.type == "enum_declaration":
@@ -142,10 +185,10 @@ def parse_ts_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC
142185
elif child.type == "export_statement" or child.text.decode("utf-8") == "export *;":
143186
statements.append(ExportStatement(child, file_node_id, ctx, parent, len(statements)))
144187

145-
# =====[ Imports ] =====
146-
elif child.type == "import_statement":
147-
# statements.append(TSImportStatement(child, file_node_id, ctx, parent, len(statements)))
148-
pass # Temporarily opting to identify all imports using find_all_descendants
188+
# # =====[ Imports ] =====
189+
# elif child.type == "import_statement":
190+
# # statements.append(TSImportStatement(child, file_node_id, ctx, parent, len(statements)))
191+
# pass # Temporarily opting to identify all imports using find_all_descendants
149192

150193
# =====[ Non-symbol statements ] =====
151194
elif child.type == "comment":
@@ -167,6 +210,8 @@ def parse_ts_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC
167210
elif child.type in ["lexical_declaration", "variable_declaration"]:
168211
if function_node := find_first_function_descendant(child):
169212
statements.append(SymbolStatement(child, file_node_id, ctx, parent, len(statements), function_node))
213+
elif import_node := self.get_import_node(child):
214+
statements.append(TSImportStatement(import_node, file_node_id, ctx, parent, len(statements)))
170215
else:
171216
statements.append(
172217
TSAssignmentStatement.from_assignment(
@@ -176,6 +221,10 @@ def parse_ts_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC
176221
elif child.type in ["public_field_definition", "property_signature", "enum_assignment"]:
177222
statements.append(TSAttribute(child, file_node_id, ctx, parent, pos=len(statements)))
178223
elif child.type == "expression_statement":
224+
if import_node := self.get_import_node(child):
225+
statements.append(TSImportStatement(import_node, file_node_id, ctx, parent, pos=len(statements)))
226+
continue
227+
179228
for var in child.named_children:
180229
if var.type == "string":
181230
statements.append(TSComment.from_code_block(var, parent, pos=len(statements)))
@@ -185,7 +234,6 @@ def parse_ts_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC
185234
statements.append(ExpressionStatement(child, file_node_id, ctx, parent, pos=len(statements), expression_node=var))
186235
elif child.type in self.expressions:
187236
statements.append(ExpressionStatement(child, file_node_id, ctx, parent, len(statements), expression_node=child))
188-
189237
else:
190238
self.log("Couldn't parse statement with type: %s", child.type)
191239
statements.append(Statement.from_code_block(child, parent, pos=len(statements)))
@@ -204,6 +252,7 @@ def parse_py_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC
204252
from codegen.sdk.python.statements.comment import PyComment
205253
from codegen.sdk.python.statements.for_loop_statement import PyForLoopStatement
206254
from codegen.sdk.python.statements.if_block_statement import PyIfBlockStatement
255+
from codegen.sdk.python.statements.import_statement import PyImportStatement
207256
from codegen.sdk.python.statements.match_statement import PyMatchStatement
208257
from codegen.sdk.python.statements.pass_statement import PyPassStatement
209258
from codegen.sdk.python.statements.try_catch_statement import PyTryCatchStatement
@@ -237,9 +286,7 @@ def parse_py_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC
237286

238287
# =====[ Imports ] =====
239288
elif child.type in ["import_statement", "import_from_statement", "future_import_statement"]:
240-
# statements.append(PyImportStatement(child, file_node_id, ctx, parent, len(statements)))
241-
pass # Temporarily opting to identify all imports using find_all_descendants
242-
289+
statements.append(PyImportStatement(child, file_node_id, ctx, parent, len(statements)))
243290
# =====[ Non-symbol statements ] =====
244291
elif child.type == "comment":
245292
statements.append(PyComment.from_code_block(child, parent, pos=len(statements)))

0 commit comments

Comments
 (0)