Skip to content

Commit ad541e3

Browse files
committed
adding attribute_chain
1 parent d77f368 commit ad541e3

File tree

6 files changed

+385
-0
lines changed

6 files changed

+385
-0
lines changed

src/codegen/sdk/core/detached_symbols/argument.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,13 @@ def __init__(self, node: TSNode, positional_idx: int, parent: FunctionCall) -> N
5252
self._name_node = self._parse_expression(name_node, default=Name)
5353
self._value_node = self._parse_expression(_value_node)
5454

55+
def __repr__(self) -> str:
56+
keyword = f"keyword={self.name}, " if self.name else ""
57+
value = f"value='{self.value}', " if self.value else ""
58+
type = f"type={self.type}" if self.type else ""
59+
60+
return f"Argument({keyword}{value}{type})"
61+
5562
@noapidoc
5663
@classmethod
5764
def from_argument_list(cls, node: TSNode, file_node_id: NodeId, G: CodebaseGraph, parent: FunctionCall) -> MultiExpression[Parent, Argument]:

src/codegen/sdk/core/detached_symbols/function_call.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -628,6 +628,18 @@ def function_calls(self) -> list[FunctionCall]:
628628
# calls.append(call)
629629
return sort_editables(calls, dedupe=False)
630630

631+
@property
632+
@reader
633+
def attribute_chain(self) -> list[FunctionCall | Name]:
634+
if isinstance(self.get_name(), ChainedAttribute): # child is chainedAttribute. MEANING that this is likely in the middle or the last function call of a chained function call chain.
635+
return self.get_name().attribute_chain
636+
elif isinstance(
637+
self.parent, ChainedAttribute
638+
): # does not have child chainedAttribute, but parent is chainedAttribute. MEANING that this is likely the TOP function call of a chained function call chain.
639+
return self.parent.attribute_chain
640+
else: # this is a standalone function call
641+
return [self]
642+
631643
@property
632644
@noapidoc
633645
def descendant_symbols(self) -> list[Importable]:

src/codegen/sdk/core/expressions/chained_attribute.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
from codegen.shared.decorators.docs import apidoc, noapidoc
1616

1717
if TYPE_CHECKING:
18+
from codegen.sdk.core.detached_symbols.function_call import FunctionCall
1819
from codegen.sdk.core.interfaces.has_name import HasName
1920
from codegen.sdk.core.interfaces.importable import Importable
2021

22+
2123
Object = TypeVar("Object", bound="Chainable")
2224
Attribute = TypeVar("Attribute", bound="Resolvable")
2325
Parent = TypeVar("Parent", bound="Expression")
@@ -74,6 +76,41 @@ def attribute(self) -> Attribute:
7476
"""
7577
return self._attribute
7678

79+
@property
80+
@reader
81+
def attribute_chain(self) -> list["FunctionCall | Name"]:
82+
from codegen.sdk.core.detached_symbols.function_call import FunctionCall
83+
84+
ret = []
85+
curr = self
86+
87+
# Traverse backwards in code (children of tree node)
88+
while isinstance(curr, ChainedAttribute):
89+
curr = curr.object
90+
91+
if isinstance(curr, FunctionCall):
92+
ret.insert(0, curr)
93+
curr = curr.get_name()
94+
elif isinstance(curr, ChainedAttribute):
95+
ret.insert(0, curr.attribute)
96+
97+
# This means that we have reached the base of the chain and the first item was an attribute (i.e a.b.c.func())
98+
if isinstance(curr, Name) and not isinstance(curr.parent, FunctionCall):
99+
ret.insert(0, curr)
100+
101+
curr = self
102+
103+
# Traversing forward in code (parents of tree node). Will add the current node as well
104+
while isinstance(curr, ChainedAttribute) or isinstance(curr, FunctionCall):
105+
if isinstance(curr, FunctionCall):
106+
ret.append(curr)
107+
elif isinstance(curr, ChainedAttribute) and not isinstance(curr.parent, FunctionCall):
108+
ret.append(curr.attribute)
109+
110+
curr = curr.parent
111+
112+
return ret
113+
77114
@property
78115
def object(self) -> Object:
79116
"""Returns the object that contains the attribute being looked up.

src/codegen/sdk/core/symbol_group.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ def __init__(self, file_node_id: NodeId, G: CodebaseGraph, parent: Parent, node:
3737
node = children[0].ts_node
3838
super().__init__(node, file_node_id, G, parent)
3939

40+
def __repr__(self) -> str:
41+
return f"Collection({self.symbols})" if self.symbols is not None else super().__repr__()
42+
4043
def _init_children(self): ...
4144

4245
@repr_func # HACK
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
from codegen.sdk.codebase.factory.get_session import get_codebase_session
2+
3+
4+
def test_attribute_chain_query_builder(tmpdir) -> None:
5+
# language=python
6+
content = """
7+
def query():
8+
# Test chained method calls with function at start
9+
QueryBuilder().select("name", "age").from_table("users").where("age > 18").order_by("name")
10+
"""
11+
with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase:
12+
file = codebase.get_file("test.py")
13+
query = file.get_function("query")
14+
calls = query.function_calls
15+
assert len(calls) == 5
16+
order_by = calls[0] # Last call in chain
17+
where = calls[1]
18+
from_table = calls[2]
19+
select = calls[3]
20+
query_builder = calls[4] # First call in chain
21+
22+
# Test attribute chain from different positions
23+
# From first call (QueryBuilder())
24+
chain = query_builder.attribute_chain
25+
assert len(chain) == 5
26+
assert chain[0] == query_builder
27+
assert chain[1] == select
28+
assert chain[2] == from_table
29+
assert chain[3] == where
30+
assert chain[4] == order_by
31+
32+
# From middle call (from_table())
33+
chain = from_table.attribute_chain
34+
assert len(chain) == 5
35+
assert chain[0] == query_builder
36+
assert chain[1] == select
37+
assert chain[2] == from_table
38+
assert chain[3] == where
39+
assert chain[4] == order_by
40+
41+
# From last call (order_by())
42+
chain = order_by.attribute_chain
43+
assert len(chain) == 5
44+
assert chain[0] == query_builder
45+
assert chain[1] == select
46+
assert chain[2] == from_table
47+
assert chain[3] == where
48+
assert chain[4] == order_by
49+
50+
51+
def test_attribute_chain_mixed_properties(tmpdir) -> None:
52+
# language=python
53+
content = """
54+
def query():
55+
# Test mix of properties and function calls
56+
QueryBuilder().a.select("name", "age").from_table("users").where("age > 18").b.order_by("name").c
57+
"""
58+
with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase:
59+
file = codebase.get_file("test.py")
60+
query = file.get_function("query")
61+
calls = query.function_calls
62+
63+
# Get function calls in order
64+
order_by = calls[0] # Last function call
65+
where = calls[1]
66+
from_table = calls[2]
67+
select = calls[3]
68+
query_builder = calls[4] # First function call
69+
70+
# Test from first call
71+
chain = query_builder.attribute_chain
72+
assert len(chain) == 8 # 5 function calls + 3 properties (a, b, c)
73+
assert chain[0] == query_builder
74+
assert chain[1].source == "a" # Property
75+
assert chain[2] == select
76+
assert chain[3] == from_table
77+
assert chain[4] == where
78+
assert chain[5].source == "b" # Property
79+
assert chain[6] == order_by
80+
assert chain[7].source == "c" # Property
81+
82+
83+
def test_attribute_chain_only_properties(tmpdir) -> None:
84+
# language=python
85+
content = """
86+
def test():
87+
# Test chain with only properties
88+
a.b.c.func()
89+
"""
90+
with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase:
91+
file = codebase.get_file("test.py")
92+
test = file.get_function("test")
93+
calls = test.function_calls
94+
assert len(calls) == 1
95+
func = calls[0]
96+
97+
chain = func.attribute_chain
98+
assert len(chain) == 4
99+
assert chain[0].source == "a"
100+
assert chain[1].source == "b"
101+
assert chain[2].source == "c"
102+
assert chain[3] == func
103+
104+
105+
def test_attribute_chain_nested_calls(tmpdir) -> None:
106+
# language=python
107+
content = """
108+
def test():
109+
# Test nested function calls (not chained)
110+
a(b(c()))
111+
"""
112+
with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase:
113+
file = codebase.get_file("test.py")
114+
test = file.get_function("test")
115+
calls = test.function_calls
116+
assert len(calls) == 3
117+
a = calls[0]
118+
b = calls[1]
119+
c = calls[2]
120+
121+
# Each call should have its own single-element chain
122+
assert a.attribute_chain == [a]
123+
assert b.attribute_chain == [b]
124+
assert c.attribute_chain == [c]

0 commit comments

Comments
 (0)