Skip to content

Commit 974ccd5

Browse files
committed
fixing errors
1 parent 772ac1f commit 974ccd5

File tree

4 files changed

+47
-52
lines changed

4 files changed

+47
-52
lines changed

src/codegen/sdk/core/parser.py

Lines changed: 4 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@
88
from codegen.sdk.core.expressions.placeholder_type import PlaceholderType
99
from codegen.sdk.core.expressions.value import Value
1010
from codegen.sdk.core.statements.symbol_statement import SymbolStatement
11-
from codegen.sdk.extensions.utils import find_all_descendants, find_first_descendant
12-
from codegen.sdk.utils import find_first_function_descendant
11+
from codegen.sdk.utils import find_first_function_descendant, find_import_node
1312

1413
if TYPE_CHECKING:
1514
from tree_sitter import Node as TSNode
@@ -81,42 +80,6 @@ def parse_expression(self, node: TSNode | None, file_node_id: NodeId, ctx: Codeb
8180
ret.children
8281
return ret
8382

84-
def get_import_node(self, node: TSNode) -> TSNode | None:
85-
"""Get the import node from a node that may contain an import.
86-
Returns None if the node does not contain an import.
87-
88-
Returns:
89-
TSNode | None: The import_statement or call_expression node if it's an import, None otherwise
90-
"""
91-
# Static imports
92-
if node.type == "import_statement":
93-
return node
94-
95-
# Dynamic imports and requires can be either:
96-
# 1. Inside expression_statement -> call_expression
97-
# 2. Direct call_expression
98-
99-
# we only parse imports inside expressions and variable declarations
100-
call_expression = find_first_descendant(node, ["call_expression"])
101-
if member_expression := find_first_descendant(node, ["member_expression"]):
102-
# there may be multiple call expressions (for cases such as import(a).then(module => module).then(module => module)
103-
descendants = find_all_descendants(member_expression, ["call_expression"])
104-
if descendants:
105-
import_node = descendants[-1]
106-
else:
107-
# this means this is NOT a dynamic import()
108-
return None
109-
else:
110-
import_node = call_expression
111-
112-
# thus we only consider the deepest one
113-
if import_node:
114-
function = import_node.child_by_field_name("function")
115-
if function and (function.type == "import" or (function.type == "identifier" and function.text.decode("utf-8") == "require")):
116-
return import_node
117-
118-
return None
119-
12083
def log_unparsed(self, node: TSNode) -> None:
12184
if self._should_log and node.is_named and node.type not in self._uncovered_nodes:
12285
self._uncovered_nodes.add(node.type)
@@ -172,7 +135,7 @@ def parse_ts_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC
172135

173136
# =====[ Type Alias Declarations ]=====
174137
elif child.type == "type_alias_declaration":
175-
if import_node := self.get_import_node(child):
138+
if import_node := find_import_node(child):
176139
statements.append(TSImportStatement(import_node, file_node_id, ctx, parent, len(statements)))
177140
else:
178141
statements.append(SymbolStatement(child, file_node_id, ctx, parent, len(statements)))
@@ -185,11 +148,6 @@ def parse_ts_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC
185148
elif child.type == "export_statement" or child.text.decode("utf-8") == "export *;":
186149
statements.append(ExportStatement(child, file_node_id, ctx, parent, len(statements)))
187150

188-
# # =====[ Imports ] =====
189-
# elif child.type == "import_statement":
190-
# # statements.append(TSImportStatement(child, file_node_id, ctx, parent, len(statements)))
191-
# pass # Temporarily opting to identify all imports using find_all_descendants
192-
193151
# =====[ Non-symbol statements ] =====
194152
elif child.type == "comment":
195153
statements.append(TSComment.from_code_block(child, parent, pos=len(statements)))
@@ -210,7 +168,7 @@ def parse_ts_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC
210168
elif child.type in ["lexical_declaration", "variable_declaration"]:
211169
if function_node := find_first_function_descendant(child):
212170
statements.append(SymbolStatement(child, file_node_id, ctx, parent, len(statements), function_node))
213-
elif import_node := self.get_import_node(child):
171+
elif import_node := find_import_node(child):
214172
statements.append(TSImportStatement(import_node, file_node_id, ctx, parent, len(statements)))
215173
else:
216174
statements.append(
@@ -221,7 +179,7 @@ def parse_ts_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC
221179
elif child.type in ["public_field_definition", "property_signature", "enum_assignment"]:
222180
statements.append(TSAttribute(child, file_node_id, ctx, parent, pos=len(statements)))
223181
elif child.type == "expression_statement":
224-
if import_node := self.get_import_node(child):
182+
if import_node := find_import_node(child):
225183
statements.append(TSImportStatement(import_node, file_node_id, ctx, parent, pos=len(statements)))
226184
continue
227185

src/codegen/sdk/utils.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,43 @@ def find_first_function_descendant(node: TSNode) -> TSNode:
8787
return find_first_descendant(node=node, type_names=type_names, max_depth=2)
8888

8989

90+
def find_import_node(node: TSNode) -> TSNode | None:
91+
"""Get the import node from a node that may contain an import.
92+
Returns None if the node does not contain an import.
93+
94+
Returns:
95+
TSNode | None: The import_statement or call_expression node if it's an import, None otherwise
96+
"""
97+
# Static imports
98+
if node.type == "import_statement":
99+
return node
100+
101+
# Dynamic imports and requires can be either:
102+
# 1. Inside expression_statement -> call_expression
103+
# 2. Direct call_expression
104+
105+
# we only parse imports inside expressions and variable declarations
106+
107+
if member_expression := find_first_descendant(node, ["member_expression"]):
108+
# there may be multiple call expressions (for cases such as import(a).then(module => module).then(module => module)
109+
descendants = find_all_descendants(member_expression, ["call_expression"])
110+
if descendants:
111+
import_node = descendants[-1]
112+
else:
113+
# this means this is NOT a dynamic import()
114+
return None
115+
else:
116+
import_node = find_first_descendant(node, ["call_expression"])
117+
118+
# thus we only consider the deepest one
119+
if import_node:
120+
function = import_node.child_by_field_name("function")
121+
if function and (function.type == "import" or (function.type == "identifier" and function.text.decode("utf-8") == "require")):
122+
return import_node
123+
124+
return None
125+
126+
90127
def find_index(target: TSNode, siblings: list[TSNode]) -> int:
91128
"""Returns the index of the target node in the list of siblings, or -1 if not found. Recursive implementation."""
92129
if target in siblings:

tests/unit/codegen/sdk/codebase/session/test_codebase_from_files.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def test_from_files_python():
1111
assert len(codebase.files) == 2
1212
assert any(f.filepath.endswith("main.py") for f in codebase.files)
1313
assert any(f.filepath.endswith("utils.py") for f in codebase.files)
14-
assert any("from utils import add" in f.content for f in codebase.files)
14+
assert any("from utils import add" in f.source for f in codebase.files)
1515

1616

1717
def test_from_files_typescript():
@@ -22,7 +22,7 @@ def test_from_files_typescript():
2222
assert len(codebase.files) == 2
2323
assert any(f.filepath.endswith("index.ts") for f in codebase.files)
2424
assert any(f.filepath.endswith("utils.ts") for f in codebase.files)
25-
assert any("import { add }" in f.content for f in codebase.files)
25+
assert any("import { add }" in f.source for f in codebase.files)
2626

2727

2828
def test_from_files_empty():

tests/unit/codegen/sdk/codebase/session/test_codebase_from_string.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def hello():
1313
codebase = Codebase.from_string(code, language="python")
1414
assert len(codebase.files) == 1
1515
assert codebase.files[0].filepath.endswith("test.py")
16-
assert "def hello" in codebase.files[0].content
16+
assert "def hello" in codebase.files[0].source
1717

1818

1919
def test_from_string_typescript():
@@ -26,7 +26,7 @@ def test_from_string_typescript():
2626
codebase = Codebase.from_string(code, language="typescript")
2727
assert len(codebase.files) == 1
2828
assert codebase.files[0].filepath.endswith("test.ts")
29-
assert "function hello" in codebase.files[0].content
29+
assert "function hello" in codebase.files[0].source
3030

3131

3232
def test_from_string_with_enum():
@@ -42,14 +42,14 @@ def test_from_string_invalid_syntax():
4242
code = "this is not valid python"
4343
codebase = Codebase.from_string(code, language="python")
4444
assert len(codebase.files) == 1
45-
assert codebase.files[0].content == code
45+
assert codebase.files[0].source == code
4646

4747

4848
def test_from_string_empty():
4949
"""Test creating a codebase from empty string"""
5050
codebase = Codebase.from_string("", language="python")
5151
assert len(codebase.files) == 1
52-
assert codebase.files[0].content == ""
52+
assert codebase.files[0].source == ""
5353

5454

5555
def test_from_string_missing_language():

0 commit comments

Comments
 (0)