-
Notifications
You must be signed in to change notification settings - Fork 52
[CG-8871] fix: removing dead code removes code that should stay #355
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
a72b816
1d7c4f8
4b88f58
b7670f8
4976c09
b357b8c
8b8b7ad
f2a30b8
9402e81
8f96ee6
9fffb72
d99d6b8
e2f310c
98e439c
376041a
bd89175
d04c670
21ffb79
99516f6
47d7fcf
87e6925
c00323b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,8 +2,11 @@ | |
|
||
from typing import TYPE_CHECKING | ||
|
||
from codegen.sdk.codebase.transactions import RemoveTransaction, TransactionPriority | ||
from codegen.sdk.core.assignment import Assignment | ||
from codegen.sdk.core.autocommit.decorators import remover | ||
from codegen.sdk.core.expressions.multi_expression import MultiExpression | ||
from codegen.sdk.core.symbol_groups.collection import Collection | ||
from codegen.sdk.extensions.autocommit import reader | ||
from codegen.sdk.python.symbol import PySymbol | ||
from codegen.sdk.python.symbol_groups.comment_group import PyCommentGroup | ||
|
@@ -33,8 +36,8 @@ | |
|
||
left_node = ts_node.child_by_field_name("left") | ||
right_node = ts_node.child_by_field_name("right") | ||
assignments = cls._from_left_and_right_nodes(ts_node, file_node_id, ctx, parent, left_node, right_node) | ||
Check failure on line 39 in src/codegen/sdk/python/assignment.py
|
||
return MultiExpression(ts_node, file_node_id, ctx, parent, assignments) | ||
|
||
@classmethod | ||
def from_named_expression(cls, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: PyAssignmentStatement) -> MultiExpression[PyAssignmentStatement, PyAssignment]: | ||
|
@@ -60,8 +63,8 @@ | |
|
||
left_node = ts_node.child_by_field_name("name") | ||
right_node = ts_node.child_by_field_name("value") | ||
assignments = cls._from_left_and_right_nodes(ts_node, file_node_id, ctx, parent, left_node, right_node) | ||
Check failure on line 66 in src/codegen/sdk/python/assignment.py
|
||
return MultiExpression(ts_node, file_node_id, ctx, parent, assignments) | ||
|
||
@property | ||
@reader | ||
|
@@ -96,3 +99,70 @@ | |
""" | ||
# HACK: This is a temporary solution until comments are fixed | ||
return PyCommentGroup.from_symbol_inline_comments(self, self.ts_node.parent) | ||
|
||
@remover | ||
def remove(self, delete_formatting: bool = True, priority: int = 0, dedupe: bool = True) -> None: | ||
"""Deletes this assignment and its related extended nodes (e.g. decorators, comments). | ||
|
||
|
||
Removes the current node and its extended nodes (e.g. decorators, comments) from the codebase. | ||
After removing the node, it handles cleanup of any surrounding formatting based on the context. | ||
|
||
Args: | ||
delete_formatting (bool): Whether to delete surrounding whitespace and formatting. Defaults to True. | ||
priority (int): Priority of the removal transaction. Higher priority transactions are executed first. Defaults to 0. | ||
dedupe (bool): Whether to deduplicate removal transactions at the same location. Defaults to True. | ||
|
||
Returns: | ||
None | ||
""" | ||
if getattr(self.parent, "assignments", None) and len(self.parent.assignments) > 1: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we do isinstance here? |
||
# Unpacking assignments | ||
name = self.get_name() | ||
if isinstance(self.value, Collection): | ||
# Tuples | ||
transaction_count = [ | ||
any( | ||
self.transaction_manager.get_transactions_at_range( | ||
self.file.path, start_byte=asgnmt.get_name().start_byte, end_byte=asgnmt.get_name().end_byte, transaction_order=TransactionPriority.Remove | ||
) | ||
) | ||
for asgnmt in self.parent.assignments | ||
].count(True) | ||
# Check for existing transactions | ||
if transaction_count < len(self.parent.assignments) - 1: | ||
idx = self.parent.left.index(name) | ||
value = self.value[idx] | ||
removal_queue_values = getattr(self.parent, "removal_queue", []) | ||
self.parent.removal_queue = removal_queue_values | ||
removal_queue_values.append(str(value)) | ||
if len(self.value) - transaction_count == 2: | ||
remainder = str(next(x for x in self.value if x not in removal_queue_values)) | ||
r_t = RemoveTransaction(self.value.start_byte, self.value.end_byte, self.file, priority=priority) | ||
self.transaction_manager.add_transaction(r_t) | ||
self.value.insert_at(self.value.start_byte, remainder, priority=priority) | ||
else: | ||
value.remove(delete_formatting=delete_formatting, priority=priority, dedupe=dedupe) | ||
name.remove(delete_formatting=delete_formatting, priority=priority, dedupe=dedupe) | ||
return | ||
else: | ||
transaction_count = [ | ||
any( | ||
self.transaction_manager.get_transactions_at_range( | ||
self.file.path, start_byte=asgnmt.get_name().start_byte, end_byte=asgnmt.get_name().end_byte, transaction_order=TransactionPriority.Edit | ||
) | ||
) | ||
for asgnmt in self.parent.assignments | ||
].count(True) | ||
throwaway = [asgnmt.name == "_" for asgnmt in self.parent.assignments].count(True) | ||
if transaction_count + throwaway < len(self.parent.assignments) - 1: | ||
name.edit("_", priority=priority, dedupe=dedupe) | ||
return | ||
if getattr(self.parent, "removal_queue", None): | ||
for node in self.extended_nodes: | ||
transactions = self.transaction_manager.get_transactions_at_range(self.file.path, start_byte=node.start_byte, end_byte=node.end_byte) | ||
for transaction in transactions: | ||
self.transaction_manager.queued_transactions[self.file.path].remove(transaction) | ||
|
||
for node in self.extended_nodes: | ||
node._remove(delete_formatting=delete_formatting, priority=priority, dedupe=dedupe) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,19 @@ | ||
from __future__ import annotations | ||
|
||
import os | ||
from typing import TYPE_CHECKING | ||
from collections.abc import Generator | ||
from typing import TYPE_CHECKING, Self, override | ||
|
||
from codegen.sdk.core.autocommit import reader | ||
from codegen.sdk.core.expressions import Name | ||
from codegen.sdk.core.import_resolution import ExternalImportResolver, Import, ImportResolution | ||
from codegen.sdk.enums import ImportType, NodeType | ||
from codegen.sdk.extensions.resolution import ResolutionStack | ||
from codegen.shared.decorators.docs import noapidoc, py_apidoc | ||
|
||
if TYPE_CHECKING: | ||
from collections.abc import Generator | ||
|
||
from tree_sitter import Node as TSNode | ||
|
||
from codegen.sdk.codebase.codebase_context import CodebaseContext | ||
|
@@ -28,6 +32,10 @@ | |
class PyImport(Import["PyFile"]): | ||
"""Extends Import for Python codebases.""" | ||
|
||
def __init__(self, ts_node, file_node_id, G, parent, module_node, name_node, alias_node, import_type=ImportType.UNKNOWN): | ||
super().__init__(ts_node, file_node_id, G, parent, module_node, name_node, alias_node, import_type) | ||
self.requesting_names = set() | ||
|
||
@reader | ||
def is_module_import(self) -> bool: | ||
"""Determines if the import is a module-level or wildcard import. | ||
|
@@ -117,13 +125,13 @@ | |
filepath = module_source.replace(".", "/") + ".py" | ||
filepath = os.path.join(base_path, filepath) | ||
if file := self.ctx.get_file(filepath): | ||
symbol = file.get_node_by_name(symbol_name) | ||
symbol = file.get_node_wildcard_resolves_for(symbol_name) | ||
return ImportResolution(from_file=file, symbol=symbol) | ||
|
||
# =====[ Check if `module/__init__.py` file exists in the graph ]===== | ||
filepath = filepath.replace(".py", "/__init__.py") | ||
if from_file := self.ctx.get_file(filepath): | ||
symbol = from_file.get_node_by_name(symbol_name) | ||
symbol = from_file.get_node_wildcard_resolves_for(symbol_name) | ||
return ImportResolution(from_file=from_file, symbol=symbol) | ||
|
||
# =====[ Case: Can't resolve the import ]===== | ||
|
@@ -133,6 +141,11 @@ | |
if base_path == "src": | ||
# Try "test" next | ||
return self.resolve_import(base_path="test", add_module_name=add_module_name) | ||
if base_path == "test" and module_source: | ||
# Try to resolve assuming package nested in repo | ||
possible_package_base_path = module_source.split(".")[0] | ||
if possible_package_base_path not in ("test", "src"): | ||
return self.resolve_import(base_path=possible_package_base_path, add_module_name=add_module_name) | ||
|
||
# if not G_override: | ||
# for resolver in ctx.import_resolvers: | ||
|
@@ -232,6 +245,33 @@ | |
imports.append(imp) | ||
return imports | ||
|
||
@reader | ||
@noapidoc | ||
@override | ||
def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think this is the right approach. It will create duplicated edges. For example def foo():
...
def bar():
... from filea import *
foo() # will be a usage of foo and bar For some context, we regular imports are resolved as follows.
Since it's impossible to know which symbol you're importing in a wildcard, wildcards follow the following pattern
For example, in the above case, we'd call _resolved_types on the wildcard for foo. But it'll create edges from the import to foo and from the call site to foo There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seemed to work as expected, I ve added additional UTs, have I missed anything? |
||
"""Resolve the types used by this import.""" | ||
ix_seen = set() | ||
|
||
aliased = self.is_aliased_import() | ||
if imported := self._imported_symbol(resolve_exports=True): | ||
if getattr(imported, "is_wildcard_import", False): | ||
imported.set_requesting_names(self) | ||
yield from self.with_resolution_frame(imported, direct=False, aliased=aliased) | ||
else: | ||
yield ResolutionStack(self, aliased=aliased) | ||
|
||
if self.is_wildcard_import(): | ||
for name, wildcard_import in self.names: | ||
if name in self.requesting_names: | ||
yield from [frame.parent_frame for frame in wildcard_import.resolved_type_frames] | ||
|
||
@noapidoc | ||
def set_requesting_names(self, requester: PyImport): | ||
if requester.is_wildcard_import(): | ||
self.requesting_names.update(requester.requesting_names) | ||
else: | ||
self.requesting_names.add(requester.name) | ||
|
||
@property | ||
@reader | ||
def import_specifier(self) -> Editable: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
from codegen.sdk.codebase.factory.get_session import get_codebase_session | ||
|
||
|
||
def test_remove_unpacking_assignment(tmpdir) -> None: | ||
# language=python | ||
content = """foo,bar,buzz = (a, b, c)""" | ||
|
||
with get_codebase_session(tmpdir=tmpdir, files={"test1.py": content, "test2.py": content, "test3.py": content}) as codebase: | ||
file1 = codebase.get_file("test1.py") | ||
file2 = codebase.get_file("test2.py") | ||
file3 = codebase.get_file("test3.py") | ||
|
||
foo = file1.get_symbol("foo") | ||
foo.remove() | ||
codebase.commit() | ||
|
||
assert len(file1.symbols) == 2 | ||
statement = file1.symbols[0].parent | ||
assert len(statement.assignments) == 2 | ||
assert len(statement.value) == 2 | ||
assert file1.source == """bar,buzz = (b, c)""" | ||
bar = file2.get_symbol("bar") | ||
bar.remove() | ||
codebase.commit() | ||
assert len(file2.symbols) == 2 | ||
statement = file2.symbols[0].parent | ||
assert len(statement.assignments) == 2 | ||
assert len(statement.value) == 2 | ||
assert file2.source == """foo,buzz = (a, c)""" | ||
|
||
buzz = file3.get_symbol("buzz") | ||
buzz.remove() | ||
codebase.commit() | ||
|
||
assert len(file3.symbols) == 2 | ||
statement = file3.symbols[0].parent | ||
assert len(statement.assignments) == 2 | ||
assert len(statement.value) == 2 | ||
assert file3.source == """foo,bar = (a, b)""" | ||
|
||
file1_bar = file1.get_symbol("bar") | ||
|
||
file1_bar.remove() | ||
codebase.commit() | ||
assert file1.source == """buzz = c""" | ||
|
||
file1_buzz = file1.get_symbol("buzz") | ||
file1_buzz.remove() | ||
|
||
codebase.commit() | ||
assert len(file1.symbols) == 0 | ||
assert file1.source == """""" | ||
|
||
|
||
def test_remove_unpacking_assignment_funct(tmpdir) -> None: | ||
# language=python | ||
content = """foo,bar,buzz = f()""" | ||
|
||
with get_codebase_session(tmpdir=tmpdir, files={"test1.py": content, "test2.py": content, "test3.py": content}) as codebase: | ||
file1 = codebase.get_file("test1.py") | ||
file2 = codebase.get_file("test2.py") | ||
file3 = codebase.get_file("test3.py") | ||
|
||
foo = file1.get_symbol("foo") | ||
foo.remove() | ||
codebase.commit() | ||
|
||
assert len(file1.symbols) == 3 | ||
statement = file1.symbols[0].parent | ||
assert len(statement.assignments) == 3 | ||
assert file1.source == """_,bar,buzz = f()""" | ||
bar = file2.get_symbol("bar") | ||
bar.remove() | ||
codebase.commit() | ||
assert len(file2.symbols) == 3 | ||
statement = file2.symbols[0].parent | ||
assert len(statement.assignments) == 3 | ||
assert file2.source == """foo,_,buzz = f()""" | ||
|
||
buzz = file3.get_symbol("buzz") | ||
buzz.remove() | ||
codebase.commit() | ||
|
||
assert len(file3.symbols) == 3 | ||
statement = file3.symbols[0].parent | ||
assert len(statement.assignments) == 3 | ||
assert file3.source == """foo,bar,_ = f()""" | ||
|
||
file1_bar = file1.get_symbol("bar") | ||
file1_buzz = file1.get_symbol("buzz") | ||
|
||
file1_bar.remove() | ||
file1_buzz.remove() | ||
codebase.commit() | ||
assert len(file1.symbols) == 0 | ||
assert file1.source == """""" | ||
|
||
|
||
def test_remove_unpacking_assignment_num(tmpdir) -> None: | ||
# language=python | ||
content = """foo,bar,buzz = (1, 2, 3)""" | ||
|
||
with get_codebase_session(tmpdir=tmpdir, files={"test1.py": content}) as codebase: | ||
file1 = codebase.get_file("test1.py") | ||
|
||
foo = file1.get_symbol("foo") | ||
buzz = file1.get_symbol("buzz") | ||
|
||
foo.remove() | ||
buzz.remove() | ||
codebase.commit() | ||
|
||
assert len(file1.symbols) == 1 | ||
statement = file1.symbols[0].parent | ||
assert len(statement.assignments) == 1 | ||
assert file1.source == """bar = 2""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe it'll be better to implement _smart_remove rather than overwriting remove?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I need to know what Assignment it's being called on and remove dissolves it into Value which does not preserve the distinction, if there is a better idea than override let me know