Skip to content

Commit 68abcbc

Browse files
tomcodgenfrainfreezetkfoss
authored
[CG-10833] feat: remove unpacking assignment (#505)
# Motivation <!-- Why is this change necessary? --> # Content <!-- Please include a summary of the change --> # Testing <!-- How was the change tested? --> # Please check the following before marking your PR as ready for review - [ ] I have added tests for my changes - [ ] I have updated the documentation or added new documentation as needed --------- Co-authored-by: tomcodegen <[email protected]> Co-authored-by: tomcodgen <[email protected]>
1 parent 3f8fdad commit 68abcbc

File tree

4 files changed

+226
-0
lines changed

4 files changed

+226
-0
lines changed

src/codegen/sdk/python/assignment.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
from __future__ import annotations
22

3+
from collections.abc import Collection
34
from typing import TYPE_CHECKING
45

6+
from codegen.sdk.codebase.transactions import RemoveTransaction, TransactionPriority
57
from codegen.sdk.core.assignment import Assignment
8+
from codegen.sdk.core.autocommit.decorators import remover
69
from codegen.sdk.core.expressions.multi_expression import MultiExpression
10+
from codegen.sdk.core.statements.assignment_statement import AssignmentStatement
711
from codegen.sdk.extensions.autocommit import reader
812
from codegen.sdk.python.symbol import PySymbol
913
from codegen.sdk.python.symbol_groups.comment_group import PyCommentGroup
@@ -96,3 +100,63 @@ def inline_comment(self) -> PyCommentGroup | None:
96100
"""
97101
# HACK: This is a temporary solution until comments are fixed
98102
return PyCommentGroup.from_symbol_inline_comments(self, self.ts_node.parent)
103+
104+
@noapidoc
105+
def _partial_remove_when_tuple(self, name, delete_formatting: bool = True, priority: int = 0, dedupe: bool = True):
106+
idx = self.parent.left.index(name)
107+
value = self.value[idx]
108+
self.parent._values_scheduled_for_removal.append(value)
109+
# Special case for removing brackets of value
110+
if len(self.value) - len(self.parent._values_scheduled_for_removal) == 1:
111+
remainder = str(next(x for x in self.value if x not in self.parent._values_scheduled_for_removal and x != value))
112+
r_t = RemoveTransaction(self.value.start_byte, self.value.end_byte, self.file, priority=priority)
113+
self.transaction_manager.add_transaction(r_t)
114+
self.value.insert_at(self.value.start_byte, remainder, priority=priority)
115+
else:
116+
# Normal just remove one value
117+
value.remove(delete_formatting=delete_formatting, priority=priority, dedupe=dedupe)
118+
# Remove assignment name
119+
name.remove(delete_formatting=delete_formatting, priority=priority, dedupe=dedupe)
120+
121+
@noapidoc
122+
def _active_transactions_on_assignment_names(self, transaction_order: TransactionPriority) -> int:
123+
return [
124+
any(self.transaction_manager.get_transactions_at_range(self.file.path, start_byte=asgnmt.get_name().start_byte, end_byte=asgnmt.get_name().end_byte, transaction_order=transaction_order))
125+
for asgnmt in self.parent.assignments
126+
].count(True)
127+
128+
@remover
129+
def remove(self, delete_formatting: bool = True, priority: int = 0, dedupe: bool = True) -> None:
130+
"""Deletes this assignment and its related extended nodes (e.g. decorators, comments).
131+
132+
133+
Removes the current node and its extended nodes (e.g. decorators, comments) from the codebase.
134+
After removing the node, it handles cleanup of any surrounding formatting based on the context.
135+
136+
Args:
137+
delete_formatting (bool): Whether to delete surrounding whitespace and formatting. Defaults to True.
138+
priority (int): Priority of the removal transaction. Higher priority transactions are executed first. Defaults to 0.
139+
dedupe (bool): Whether to deduplicate removal transactions at the same location. Defaults to True.
140+
141+
Returns:
142+
None
143+
"""
144+
if self.ctx.config.feature_flags.unpacking_assignment_partial_removal:
145+
if isinstance(self.parent, AssignmentStatement) and len(self.parent.assignments) > 1:
146+
# Unpacking assignments
147+
name = self.get_name()
148+
if isinstance(self.value, Collection):
149+
if len(self.parent._values_scheduled_for_removal) < len(self.parent.assignments) - 1:
150+
self._partial_remove_when_tuple(name, delete_formatting, priority, dedupe)
151+
return
152+
else:
153+
self.parent._values_scheduled_for_removal = []
154+
else:
155+
transaction_count = self._active_transactions_on_assignment_names(TransactionPriority.Edit)
156+
throwaway = [asgnmt.name == "_" for asgnmt in self.parent.assignments].count(True)
157+
# Only edit if we didn't already omit all the other assignments, otherwise just remove the whole thing
158+
if transaction_count + throwaway < len(self.parent.assignments) - 1:
159+
name.edit("_", priority=priority, dedupe=dedupe)
160+
return
161+
162+
super().remove(delete_formatting=delete_formatting, priority=priority, dedupe=dedupe)

src/codegen/sdk/python/statements/assignment_statement.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ class PyAssignmentStatement(AssignmentStatement["PyCodeBlock", PyAssignment]):
3030

3131
assignment_types = {"assignment", "augmented_assignment", "named_expression"}
3232

33+
def __init__(self, ts_node, file_node_id, ctx, parent, pos, assignment_node):
34+
super().__init__(ts_node, file_node_id, ctx, parent, pos, assignment_node)
35+
self._values_scheduled_for_removal = []
36+
3337
@classmethod
3438
def from_assignment(cls, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: PyCodeBlock, pos: int, assignment_node: TSNode) -> PyAssignmentStatement:
3539
"""Creates a PyAssignmentStatement instance from a TreeSitter assignment node.

src/codegen/shared/configs/models/feature_flags.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ class CodebaseFeatureFlags(BaseSettings):
2828
generics: bool = True
2929
import_resolution_overrides: dict[str, str] = Field(default_factory=lambda: {})
3030
typescript: TypescriptConfig = Field(default_factory=TypescriptConfig)
31+
unpacking_assignment_partial_removal: bool = True
3132

3233

3334
class FeatureFlagsConfig(BaseModel):
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
from codegen.sdk.codebase.factory.get_session import get_codebase_session
2+
3+
4+
def test_remove_unpacking_assignment(tmpdir) -> None:
5+
# language=python
6+
content = """foo,bar,buzz = (a, b, c)"""
7+
8+
with get_codebase_session(tmpdir=tmpdir, files={"test1.py": content, "test2.py": content, "test3.py": content}) as codebase:
9+
file1 = codebase.get_file("test1.py")
10+
file2 = codebase.get_file("test2.py")
11+
file3 = codebase.get_file("test3.py")
12+
13+
foo = file1.get_symbol("foo")
14+
foo.remove()
15+
codebase.commit()
16+
17+
assert len(file1.symbols) == 2
18+
statement = file1.symbols[0].parent
19+
assert len(statement.assignments) == 2
20+
assert len(statement.value) == 2
21+
assert file1.source == """bar,buzz = (b, c)"""
22+
bar = file2.get_symbol("bar")
23+
bar.remove()
24+
codebase.commit()
25+
assert len(file2.symbols) == 2
26+
statement = file2.symbols[0].parent
27+
assert len(statement.assignments) == 2
28+
assert len(statement.value) == 2
29+
assert file2.source == """foo,buzz = (a, c)"""
30+
31+
buzz = file3.get_symbol("buzz")
32+
buzz.remove()
33+
codebase.commit()
34+
35+
assert len(file3.symbols) == 2
36+
statement = file3.symbols[0].parent
37+
assert len(statement.assignments) == 2
38+
assert len(statement.value) == 2
39+
assert file3.source == """foo,bar = (a, b)"""
40+
41+
file1_bar = file1.get_symbol("bar")
42+
43+
file1_bar.remove()
44+
codebase.commit()
45+
assert file1.source == """buzz = c"""
46+
47+
file1_buzz = file1.get_symbol("buzz")
48+
file1_buzz.remove()
49+
50+
codebase.commit()
51+
assert len(file1.symbols) == 0
52+
assert file1.source == """"""
53+
54+
55+
def test_remove_unpacking_assignment_funct(tmpdir) -> None:
56+
# language=python
57+
content = """foo,bar,buzz = f()"""
58+
59+
with get_codebase_session(tmpdir=tmpdir, files={"test1.py": content, "test2.py": content, "test3.py": content}) as codebase:
60+
file1 = codebase.get_file("test1.py")
61+
file2 = codebase.get_file("test2.py")
62+
file3 = codebase.get_file("test3.py")
63+
64+
foo = file1.get_symbol("foo")
65+
foo.remove()
66+
codebase.commit()
67+
68+
assert len(file1.symbols) == 3
69+
statement = file1.symbols[0].parent
70+
assert len(statement.assignments) == 3
71+
assert file1.source == """_,bar,buzz = f()"""
72+
bar = file2.get_symbol("bar")
73+
bar.remove()
74+
codebase.commit()
75+
assert len(file2.symbols) == 3
76+
statement = file2.symbols[0].parent
77+
assert len(statement.assignments) == 3
78+
assert file2.source == """foo,_,buzz = f()"""
79+
80+
buzz = file3.get_symbol("buzz")
81+
buzz.remove()
82+
codebase.commit()
83+
84+
assert len(file3.symbols) == 3
85+
statement = file3.symbols[0].parent
86+
assert len(statement.assignments) == 3
87+
assert file3.source == """foo,bar,_ = f()"""
88+
89+
file1_bar = file1.get_symbol("bar")
90+
file1_buzz = file1.get_symbol("buzz")
91+
92+
file1_bar.remove()
93+
file1_buzz.remove()
94+
codebase.commit()
95+
assert len(file1.symbols) == 0
96+
assert file1.source == """"""
97+
98+
99+
def test_remove_unpacking_assignment_num(tmpdir) -> None:
100+
# language=python
101+
content = """a,b,c,d,e,f = (1, 2, 2, 4, 5, 3)"""
102+
103+
with get_codebase_session(tmpdir=tmpdir, files={"test1.py": content, "test2.py": content}) as codebase:
104+
file1 = codebase.get_file("test1.py")
105+
106+
a = file1.get_symbol("a")
107+
d = file1.get_symbol("d")
108+
109+
a.remove()
110+
d.remove()
111+
codebase.commit()
112+
113+
assert len(file1.symbols) == 4
114+
statement = file1.symbols[0].parent
115+
assert len(statement.assignments) == 4
116+
assert file1.source == """b,c,e,f = (2, 2, 5, 3)"""
117+
118+
e = file1.get_symbol("e")
119+
c = file1.get_symbol("c")
120+
121+
e.remove()
122+
c.remove()
123+
codebase.commit()
124+
125+
assert len(file1.symbols) == 2
126+
statement = file1.symbols[0].parent
127+
assert len(statement.assignments) == 2
128+
assert file1.source == """b,f = (2, 3)"""
129+
130+
f = file1.get_symbol("f")
131+
132+
f.remove()
133+
codebase.commit()
134+
135+
assert len(file1.symbols) == 1
136+
statement = file1.symbols[0].parent
137+
assert len(statement.assignments) == 1
138+
assert file1.source == """b = 2"""
139+
file2 = codebase.get_file("test2.py")
140+
a = file2.get_symbol("a")
141+
d = file2.get_symbol("d")
142+
e = file2.get_symbol("e")
143+
c = file2.get_symbol("c")
144+
f = file2.get_symbol("f")
145+
b = file2.get_symbol("b")
146+
147+
a.remove()
148+
b.remove()
149+
c.remove()
150+
d.remove()
151+
e.remove()
152+
f.remove()
153+
154+
codebase.commit()
155+
156+
assert len(file2.symbols) == 0
157+
assert file2.source == """"""

0 commit comments

Comments
 (0)