Skip to content

Commit 0b4d9a6

Browse files
committed
done
1 parent 1945ee6 commit 0b4d9a6

File tree

3 files changed

+8
-2
lines changed

3 files changed

+8
-2
lines changed

src/codegen/sdk/extensions/utils.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def find_all_descendants(
1313
type_names: Iterable[str] | str,
1414
max_depth: int | None = None,
1515
nested: bool = True,
16+
stop_at_first: str | None = None,
1617
) -> list[TSNode]: ...
1718
def find_line_start_and_end_nodes(node: TSNode) -> list[tuple[TSNode, TSNode]]:
1819
"""Returns a list of tuples of the start and end nodes of each line in the node"""

src/codegen/sdk/extensions/utils.pyx

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def get_all_identifiers(node: TSNode) -> list[TSNode]:
3131
return sorted(dict.fromkeys(identifiers), key=lambda x: x.start_byte)
3232

3333

34-
def find_all_descendants(node: TSNode, type_names: Iterable[str] | str, max_depth: int | None = None, nested: bool = True) -> list[TSNode]:
34+
def find_all_descendants(node: TSNode, type_names: Iterable[str] | str, max_depth: int | None = None, nested: bool = True, stop_at_first: str | None = None) -> list[TSNode]:
3535
if isinstance(type_names, str):
3636
type_names = [type_names]
3737
descendants = []
@@ -45,6 +45,9 @@ def find_all_descendants(node: TSNode, type_names: Iterable[str] | str, max_dept
4545
if not nested and current_node != node:
4646
return
4747

48+
if stop_at_first and current_node.type == stop_at_first:
49+
return
50+
4851
for child in current_node.children:
4952
traverse(child, depth + 1)
5053

src/codegen/sdk/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,11 @@ def find_import_node(node: TSNode) -> TSNode | None:
104104

105105
# we only parse imports inside expressions and variable declarations
106106

107+
# import_nodes = [_node for _node in find_all_descendants(node, ["call_expression", "statement_block"], nested=False) if _node.type == "call_expression"]
108+
107109
if member_expression := find_first_descendant(node, ["member_expression"]):
108110
# there may be multiple call expressions (for cases such as import(a).then(module => module).then(module => module)
109-
descendants = find_all_descendants(member_expression, ["call_expression"])
111+
descendants = find_all_descendants(member_expression, ["call_expression"], stop_at_first="statement_block")
110112
if descendants:
111113
import_node = descendants[-1]
112114
else:

0 commit comments

Comments
 (0)