Skip to content

[CG-8871] fix: removing dead code removes code that should stay #355

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 22 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions src/codegen/sdk/python/assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@

from typing import TYPE_CHECKING

from codegen.sdk.codebase.transactions import RemoveTransaction, TransactionPriority
from codegen.sdk.core.assignment import Assignment
from codegen.sdk.core.autocommit.decorators import remover
from codegen.sdk.core.expressions.multi_expression import MultiExpression
from codegen.sdk.core.symbol_groups.collection import Collection
from codegen.sdk.extensions.autocommit import reader
from codegen.sdk.python.symbol import PySymbol
from codegen.sdk.python.symbol_groups.comment_group import PyCommentGroup
Expand Down Expand Up @@ -33,8 +36,8 @@

left_node = ts_node.child_by_field_name("left")
right_node = ts_node.child_by_field_name("right")
assignments = cls._from_left_and_right_nodes(ts_node, file_node_id, ctx, parent, left_node, right_node)

Check failure on line 39 in src/codegen/sdk/python/assignment.py

View workflow job for this annotation

GitHub Actions / mypy

error: Argument 5 to "_from_left_and_right_nodes" of "Assignment" has incompatible type "Node | None"; expected "Node" [arg-type]

Check failure on line 39 in src/codegen/sdk/python/assignment.py

View workflow job for this annotation

GitHub Actions / mypy

error: Argument 6 to "_from_left_and_right_nodes" of "Assignment" has incompatible type "Node | None"; expected "Node" [arg-type]
return MultiExpression(ts_node, file_node_id, ctx, parent, assignments)

Check failure on line 40 in src/codegen/sdk/python/assignment.py

View workflow job for this annotation

GitHub Actions / mypy

error: Argument 5 to "MultiExpression" has incompatible type "list[Assignment[Any]]"; expected "list[PyAssignment]" [arg-type]

@classmethod
def from_named_expression(cls, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: PyAssignmentStatement) -> MultiExpression[PyAssignmentStatement, PyAssignment]:
Expand All @@ -60,8 +63,8 @@

left_node = ts_node.child_by_field_name("name")
right_node = ts_node.child_by_field_name("value")
assignments = cls._from_left_and_right_nodes(ts_node, file_node_id, ctx, parent, left_node, right_node)

Check failure on line 66 in src/codegen/sdk/python/assignment.py

View workflow job for this annotation

GitHub Actions / mypy

error: Argument 5 to "_from_left_and_right_nodes" of "Assignment" has incompatible type "Node | None"; expected "Node" [arg-type]

Check failure on line 66 in src/codegen/sdk/python/assignment.py

View workflow job for this annotation

GitHub Actions / mypy

error: Argument 6 to "_from_left_and_right_nodes" of "Assignment" has incompatible type "Node | None"; expected "Node" [arg-type]
return MultiExpression(ts_node, file_node_id, ctx, parent, assignments)

Check failure on line 67 in src/codegen/sdk/python/assignment.py

View workflow job for this annotation

GitHub Actions / mypy

error: Argument 5 to "MultiExpression" has incompatible type "list[Assignment[Any]]"; expected "list[PyAssignment]" [arg-type]

@property
@reader
Expand Down Expand Up @@ -96,3 +99,70 @@
"""
# HACK: This is a temporary solution until comments are fixed
return PyCommentGroup.from_symbol_inline_comments(self, self.ts_node.parent)

@remover
def remove(self, delete_formatting: bool = True, priority: int = 0, dedupe: bool = True) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe it'll be better to implement _smart_remove rather than overwriting remove?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need to know what Assignment it's being called on and remove dissolves it into Value which does not preserve the distinction, if there is a better idea than override let me know

"""Deletes this assignment and its related extended nodes (e.g. decorators, comments).


Removes the current node and its extended nodes (e.g. decorators, comments) from the codebase.
After removing the node, it handles cleanup of any surrounding formatting based on the context.

Args:
delete_formatting (bool): Whether to delete surrounding whitespace and formatting. Defaults to True.
priority (int): Priority of the removal transaction. Higher priority transactions are executed first. Defaults to 0.
dedupe (bool): Whether to deduplicate removal transactions at the same location. Defaults to True.

Returns:
None
"""
if getattr(self.parent, "assignments", None) and len(self.parent.assignments) > 1:

Check failure on line 119 in src/codegen/sdk/python/assignment.py

View workflow job for this annotation

GitHub Actions / mypy

error: "Statement[CodeBlock[PyAssignmentStatement, Any]]" has no attribute "assignments" [attr-defined]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we do isinstance here?

# Unpacking assignments
name = self.get_name()
if isinstance(self.value, Collection):
# Tuples
transaction_count = [
any(
self.transaction_manager.get_transactions_at_range(
self.file.path, start_byte=asgnmt.get_name().start_byte, end_byte=asgnmt.get_name().end_byte, transaction_order=TransactionPriority.Remove
)
)
for asgnmt in self.parent.assignments
].count(True)
# Check for existing transactions
if transaction_count < len(self.parent.assignments) - 1:
idx = self.parent.left.index(name)
value = self.value[idx]
removal_queue_values = getattr(self.parent, "removal_queue", [])
self.parent.removal_queue = removal_queue_values
removal_queue_values.append(str(value))
if len(self.value) - transaction_count == 2:
remainder = str(next(x for x in self.value if x not in removal_queue_values))
r_t = RemoveTransaction(self.value.start_byte, self.value.end_byte, self.file, priority=priority)
self.transaction_manager.add_transaction(r_t)
self.value.insert_at(self.value.start_byte, remainder, priority=priority)
else:
value.remove(delete_formatting=delete_formatting, priority=priority, dedupe=dedupe)
name.remove(delete_formatting=delete_formatting, priority=priority, dedupe=dedupe)
return
else:
transaction_count = [
any(
self.transaction_manager.get_transactions_at_range(
self.file.path, start_byte=asgnmt.get_name().start_byte, end_byte=asgnmt.get_name().end_byte, transaction_order=TransactionPriority.Edit
)
)
for asgnmt in self.parent.assignments

Check failure on line 155 in src/codegen/sdk/python/assignment.py

View workflow job for this annotation

GitHub Actions / mypy

error: "Statement[CodeBlock[PyAssignmentStatement, Any]]" has no attribute "assignments" [attr-defined]
].count(True)
throwaway = [asgnmt.name == "_" for asgnmt in self.parent.assignments].count(True)

Check failure on line 157 in src/codegen/sdk/python/assignment.py

View workflow job for this annotation

GitHub Actions / mypy

error: "Statement[CodeBlock[PyAssignmentStatement, Any]]" has no attribute "assignments" [attr-defined]
if transaction_count + throwaway < len(self.parent.assignments) - 1:

Check failure on line 158 in src/codegen/sdk/python/assignment.py

View workflow job for this annotation

GitHub Actions / mypy

error: "Statement[CodeBlock[PyAssignmentStatement, Any]]" has no attribute "assignments" [attr-defined]
name.edit("_", priority=priority, dedupe=dedupe)
return
if getattr(self.parent, "removal_queue", None):
for node in self.extended_nodes:
transactions = self.transaction_manager.get_transactions_at_range(self.file.path, start_byte=node.start_byte, end_byte=node.end_byte)
for transaction in transactions:
self.transaction_manager.queued_transactions[self.file.path].remove(transaction)

Check warning on line 165 in src/codegen/sdk/python/assignment.py

View check run for this annotation

Codecov / codecov/patch

src/codegen/sdk/python/assignment.py#L162-L165

Added lines #L162 - L165 were not covered by tests

for node in self.extended_nodes:
node._remove(delete_formatting=delete_formatting, priority=priority, dedupe=dedupe)
25 changes: 25 additions & 0 deletions src/codegen/sdk/python/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,28 @@ def valid_import_names(self) -> dict[str, PySymbol | PyImport | WildcardImport[P
ret[file.name] = file
return ret
return super().valid_import_names

def get_node_from_wildcard_chain(self, symbol_name: str):
node = None
if node := self.get_node_by_name(symbol_name):
return node

if wildcard_imports := {imp for imp in self.imports if imp.is_wildcard_import()}:
for wildcard_import in wildcard_imports:
if imp_resolution := wildcard_import.resolve_import():
node = imp_resolution.from_file.get_node_from_wildcard_chain(symbol_name=symbol_name)

return node

def get_node_wildcard_resolves_for(self, symbol_name: str):
node = None
if node := self.get_node_by_name(symbol_name):
return node

if wildcard_imports := {imp for imp in self.imports if imp.is_wildcard_import()}:
for wildcard_import in wildcard_imports:
if imp_resolution := wildcard_import.resolve_import():
if imp_resolution.from_file.get_node_from_wildcard_chain(symbol_name=symbol_name):
node = wildcard_import

return node
46 changes: 43 additions & 3 deletions src/codegen/sdk/python/import_resolution.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING
from collections.abc import Generator
from typing import TYPE_CHECKING, Self, override

from codegen.sdk.core.autocommit import reader
from codegen.sdk.core.expressions import Name
from codegen.sdk.core.import_resolution import ExternalImportResolver, Import, ImportResolution
from codegen.sdk.enums import ImportType, NodeType
from codegen.sdk.extensions.resolution import ResolutionStack
from codegen.shared.decorators.docs import noapidoc, py_apidoc

if TYPE_CHECKING:
from collections.abc import Generator

from tree_sitter import Node as TSNode

from codegen.sdk.codebase.codebase_context import CodebaseContext
Expand All @@ -28,6 +32,10 @@
class PyImport(Import["PyFile"]):
"""Extends Import for Python codebases."""

def __init__(self, ts_node, file_node_id, G, parent, module_node, name_node, alias_node, import_type=ImportType.UNKNOWN):
super().__init__(ts_node, file_node_id, G, parent, module_node, name_node, alias_node, import_type)
self.requesting_names = set()

@reader
def is_module_import(self) -> bool:
"""Determines if the import is a module-level or wildcard import.
Expand Down Expand Up @@ -117,13 +125,13 @@
filepath = module_source.replace(".", "/") + ".py"
filepath = os.path.join(base_path, filepath)
if file := self.ctx.get_file(filepath):
symbol = file.get_node_by_name(symbol_name)
symbol = file.get_node_wildcard_resolves_for(symbol_name)
return ImportResolution(from_file=file, symbol=symbol)

# =====[ Check if `module/__init__.py` file exists in the graph ]=====
filepath = filepath.replace(".py", "/__init__.py")
if from_file := self.ctx.get_file(filepath):
symbol = from_file.get_node_by_name(symbol_name)
symbol = from_file.get_node_wildcard_resolves_for(symbol_name)
return ImportResolution(from_file=from_file, symbol=symbol)

# =====[ Case: Can't resolve the import ]=====
Expand All @@ -133,6 +141,11 @@
if base_path == "src":
# Try "test" next
return self.resolve_import(base_path="test", add_module_name=add_module_name)
if base_path == "test" and module_source:
# Try to resolve assuming package nested in repo
possible_package_base_path = module_source.split(".")[0]
if possible_package_base_path not in ("test", "src"):
return self.resolve_import(base_path=possible_package_base_path, add_module_name=add_module_name)

# if not G_override:
# for resolver in ctx.import_resolvers:
Expand Down Expand Up @@ -232,6 +245,33 @@
imports.append(imp)
return imports

@reader
@noapidoc
@override
def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is the right approach. It will create duplicated edges. For example

def foo():
     ...
def bar():
     ...
from filea import *
foo() # will be a usage of foo and bar

For some context, we regular imports are resolved as follows.

  1. We add the import to valid_import_names
  2. When resolving, we check valid_symbol_names for the name
  3. We call _resolved_types on the import

Since it's impossible to know which symbol you're importing in a wildcard, wildcards follow the following pattern

  1. We check the imported file and get all valid_import_names (this is the same as valid_symbol_names in python) from that file
  2. For each imported symbol, we created a wilcardImport for that name and put it in the valid imports
  3. When resolving, we call _resolved_types on the wildcard import. We create an edge through the import but we point it at the specific imported symbol.

For example, in the above case, we'd call _resolved_types on the wildcard for foo. But it'll create edges from the import to foo and from the call site to foo

Copy link
Contributor Author

@tkfoss tkfoss Feb 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seemed to work as expected, I ve added additional UTs, have I missed anything?

"""Resolve the types used by this import."""
ix_seen = set()

aliased = self.is_aliased_import()
if imported := self._imported_symbol(resolve_exports=True):
if getattr(imported, "is_wildcard_import", False):
imported.set_requesting_names(self)
yield from self.with_resolution_frame(imported, direct=False, aliased=aliased)
else:
yield ResolutionStack(self, aliased=aliased)

if self.is_wildcard_import():
for name, wildcard_import in self.names:
if name in self.requesting_names:
yield from [frame.parent_frame for frame in wildcard_import.resolved_type_frames]

@noapidoc
def set_requesting_names(self, requester: PyImport):
if requester.is_wildcard_import():
self.requesting_names.update(requester.requesting_names)

Check warning on line 271 in src/codegen/sdk/python/import_resolution.py

View check run for this annotation

Codecov / codecov/patch

src/codegen/sdk/python/import_resolution.py#L271

Added line #L271 was not covered by tests
else:
self.requesting_names.add(requester.name)

@property
@reader
def import_specifier(self) -> Editable:
Expand Down
116 changes: 116 additions & 0 deletions tests/unit/codegen/sdk/python/expressions/test_unpacking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
from codegen.sdk.codebase.factory.get_session import get_codebase_session


def test_remove_unpacking_assignment(tmpdir) -> None:
# language=python
content = """foo,bar,buzz = (a, b, c)"""

with get_codebase_session(tmpdir=tmpdir, files={"test1.py": content, "test2.py": content, "test3.py": content}) as codebase:
file1 = codebase.get_file("test1.py")
file2 = codebase.get_file("test2.py")
file3 = codebase.get_file("test3.py")

foo = file1.get_symbol("foo")
foo.remove()
codebase.commit()

assert len(file1.symbols) == 2
statement = file1.symbols[0].parent
assert len(statement.assignments) == 2
assert len(statement.value) == 2
assert file1.source == """bar,buzz = (b, c)"""
bar = file2.get_symbol("bar")
bar.remove()
codebase.commit()
assert len(file2.symbols) == 2
statement = file2.symbols[0].parent
assert len(statement.assignments) == 2
assert len(statement.value) == 2
assert file2.source == """foo,buzz = (a, c)"""

buzz = file3.get_symbol("buzz")
buzz.remove()
codebase.commit()

assert len(file3.symbols) == 2
statement = file3.symbols[0].parent
assert len(statement.assignments) == 2
assert len(statement.value) == 2
assert file3.source == """foo,bar = (a, b)"""

file1_bar = file1.get_symbol("bar")

file1_bar.remove()
codebase.commit()
assert file1.source == """buzz = c"""

file1_buzz = file1.get_symbol("buzz")
file1_buzz.remove()

codebase.commit()
assert len(file1.symbols) == 0
assert file1.source == """"""


def test_remove_unpacking_assignment_funct(tmpdir) -> None:
# language=python
content = """foo,bar,buzz = f()"""

with get_codebase_session(tmpdir=tmpdir, files={"test1.py": content, "test2.py": content, "test3.py": content}) as codebase:
file1 = codebase.get_file("test1.py")
file2 = codebase.get_file("test2.py")
file3 = codebase.get_file("test3.py")

foo = file1.get_symbol("foo")
foo.remove()
codebase.commit()

assert len(file1.symbols) == 3
statement = file1.symbols[0].parent
assert len(statement.assignments) == 3
assert file1.source == """_,bar,buzz = f()"""
bar = file2.get_symbol("bar")
bar.remove()
codebase.commit()
assert len(file2.symbols) == 3
statement = file2.symbols[0].parent
assert len(statement.assignments) == 3
assert file2.source == """foo,_,buzz = f()"""

buzz = file3.get_symbol("buzz")
buzz.remove()
codebase.commit()

assert len(file3.symbols) == 3
statement = file3.symbols[0].parent
assert len(statement.assignments) == 3
assert file3.source == """foo,bar,_ = f()"""

file1_bar = file1.get_symbol("bar")
file1_buzz = file1.get_symbol("buzz")

file1_bar.remove()
file1_buzz.remove()
codebase.commit()
assert len(file1.symbols) == 0
assert file1.source == """"""


def test_remove_unpacking_assignment_num(tmpdir) -> None:
# language=python
content = """foo,bar,buzz = (1, 2, 3)"""

with get_codebase_session(tmpdir=tmpdir, files={"test1.py": content}) as codebase:
file1 = codebase.get_file("test1.py")

foo = file1.get_symbol("foo")
buzz = file1.get_symbol("buzz")

foo.remove()
buzz.remove()
codebase.commit()

assert len(file1.symbols) == 1
statement = file1.symbols[0].parent
assert len(statement.assignments) == 1
assert file1.source == """bar = 2"""
Loading
Loading