Skip to content

refactor: Split out file IO #404

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Feb 11, 2025
33 changes: 14 additions & 19 deletions src/codegen/sdk/codebase/codebase_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import os
from collections import Counter, defaultdict
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from enum import IntEnum, auto, unique
from functools import lru_cache
Expand All @@ -16,14 +15,14 @@
from codegen.sdk.codebase.config_parser import ConfigParser, get_config_parser_for_language
from codegen.sdk.codebase.diff_lite import ChangeType, DiffLite
from codegen.sdk.codebase.flagging.flags import Flags
from codegen.sdk.codebase.io.file_io import FileIO
from codegen.sdk.codebase.transaction_manager import TransactionManager
from codegen.sdk.codebase.validation import get_edges, post_reset_validation
from codegen.sdk.core.autocommit import AutoCommit, commiter
from codegen.sdk.core.directory import Directory
from codegen.sdk.core.external.dependency_manager import DependencyManager, get_dependency_manager
from codegen.sdk.core.external.language_engine import LanguageEngine, get_language_engine
from codegen.sdk.enums import Edge, EdgeType, NodeType, ProgrammingLanguage
from codegen.sdk.extensions.io import write_changes
from codegen.sdk.extensions.sort import sort_editables
from codegen.sdk.extensions.utils import uncache_all
from codegen.sdk.typescript.external.ts_declassify.ts_declassify import TSDeclassify
Expand All @@ -37,6 +36,7 @@
from git import Commit as GitCommit

from codegen.git.repo_operator.repo_operator import RepoOperator
from codegen.sdk.codebase.io.io import IO
from codegen.sdk.codebase.node_classes.node_classes import NodeClasses
from codegen.sdk.core.dataclasses.usage import Usage
from codegen.sdk.core.expressions import Expression
Expand Down Expand Up @@ -92,7 +92,6 @@
pending_syncs: list[DiffLite] # Diffs that have been applied to disk, but not the graph (to be used for sync graph)
all_syncs: list[DiffLite] # All diffs that have been applied to the graph (to be used for graph reset)
_autocommit: AutoCommit
pending_files: set[SourceFile]
generation: int
parser: Parser[Expression]
synced_commit: GitCommit | None
Expand All @@ -110,6 +109,7 @@
session_options: SessionOptions = SessionOptions()
projects: list[ProjectConfig]
unapplied_diffs: list[DiffLite]
io: IO

def __init__(
self,
Expand All @@ -134,13 +134,14 @@

# =====[ __init__ attributes ]=====
self.projects = projects
self.io = FileIO()
context = projects[0]
self.node_classes = get_node_classes(context.programming_language)
self.config = config
self.repo_name = context.repo_operator.repo_name
self.repo_path = str(Path(context.repo_operator.repo_path).resolve())
self.codeowners_parser = context.repo_operator.codeowners_parser

Check failure on line 143 in src/codegen/sdk/codebase/codebase_context.py

View workflow job for this annotation

GitHub Actions / mypy

error: Incompatible types in assignment (expression has type "Callable[[], CodeOwners | None]", variable has type "CodeOwners | None") [assignment]
self.base_url = context.repo_operator.base_url

Check failure on line 144 in src/codegen/sdk/codebase/codebase_context.py

View workflow job for this annotation

GitHub Actions / mypy

error: Incompatible types in assignment (expression has type "Callable[[], str | None]", variable has type "str | None") [assignment]
# =====[ computed attributes ]=====
self.transaction_manager = TransactionManager()
self._autocommit = AutoCommit(self)
Expand All @@ -165,7 +166,6 @@
self.pending_syncs = []
self.all_syncs = []
self.unapplied_diffs = []
self.pending_files = set()
self.flags = Flags()

def __repr__(self):
Expand All @@ -183,7 +183,7 @@
syncs[SyncType.ADD].append(self.to_absolute(filepath))
logger.info(f"> Parsing {len(syncs[SyncType.ADD])} files in {self.projects[0].subdirectories or 'ALL'} subdirectories with {self.extensions} extensions")
self._process_diff_files(syncs, incremental=False)
files: list[SourceFile] = self.get_nodes(NodeType.FILE)

Check failure on line 186 in src/codegen/sdk/codebase/codebase_context.py

View workflow job for this annotation

GitHub Actions / mypy

error: Incompatible types in assignment (expression has type "list[Importable[Any]]", variable has type "list[SourceFile[Any, Any, Any, Any, Any, Any]]") [assignment]
logger.info(f"> Found {len(files)} files")
logger.info(f"> Found {len(self.nodes)} nodes and {len(self.edges)} edges")
if self.config.feature_flags.track_graph:
Expand Down Expand Up @@ -213,8 +213,8 @@
elif diff.change_type == ChangeType.Modified:
files_to_sync[filepath] = SyncType.REPARSE
elif diff.change_type == ChangeType.Renamed:
files_to_sync[diff.rename_from] = SyncType.DELETE

Check failure on line 216 in src/codegen/sdk/codebase/codebase_context.py

View workflow job for this annotation

GitHub Actions / mypy

error: Invalid index type "Path | None" for "dict[Path, SyncType]"; expected type "Path" [index]
files_to_sync[diff.rename_to] = SyncType.ADD

Check failure on line 217 in src/codegen/sdk/codebase/codebase_context.py

View workflow job for this annotation

GitHub Actions / mypy

error: Invalid index type "Path | None" for "dict[Path, SyncType]"; expected type "Path" [index]
elif diff.change_type == ChangeType.Removed:
files_to_sync[filepath] = SyncType.DELETE
else:
Expand Down Expand Up @@ -251,15 +251,21 @@
files_to_write.append((sync.path, sync.old_content))
modified_files.add(sync.path)
elif sync.change_type == ChangeType.Renamed:
files_to_write.append((sync.rename_from, sync.old_content))

Check failure on line 254 in src/codegen/sdk/codebase/codebase_context.py

View workflow job for this annotation

GitHub Actions / mypy

error: Argument 1 to "append" of "list" has incompatible type "tuple[Path | None, bytes | None]"; expected "tuple[Path, bytes | None]" [arg-type]
files_to_remove.append(sync.rename_to)
modified_files.add(sync.rename_from)

Check failure on line 256 in src/codegen/sdk/codebase/codebase_context.py

View workflow job for this annotation

GitHub Actions / mypy

error: Argument 1 to "add" of "set" has incompatible type "Path | None"; expected "Path" [arg-type]
modified_files.add(sync.rename_to)

Check failure on line 257 in src/codegen/sdk/codebase/codebase_context.py

View workflow job for this annotation

GitHub Actions / mypy

error: Argument 1 to "add" of "set" has incompatible type "Path | None"; expected "Path" [arg-type]
elif sync.change_type == ChangeType.Added:
files_to_remove.append(sync.path)
modified_files.add(sync.path)
logger.info(f"Writing {len(files_to_write)} files to disk and removing {len(files_to_remove)} files")
write_changes(files_to_remove, files_to_write)
for file in files_to_remove:
self.io.delete_file(file)
to_save = set()
for file, content in files_to_write:
self.io.write_file(file, content)
to_save.add(file)
self.io.save_files(to_save)

@stopwatch
def reset_codebase(self) -> None:
Expand All @@ -270,7 +276,7 @@
def undo_applied_diffs(self) -> None:
self.transaction_manager.clear_transactions()
self.reset_codebase()
self.check_changes()
self.io.check_changes()
self.pending_syncs.clear() # Discard pending changes
if len(self.all_syncs) > 0:
logger.info(f"Unapplying {len(self.all_syncs)} diffs to graph. Current graph commit: {self.synced_commit}")
Expand Down Expand Up @@ -432,7 +438,7 @@

# Step 5: Add new files as nodes to graph (does not yet add edges)
for filepath in files_to_sync[SyncType.ADD]:
content = filepath.read_text(errors="ignore")
content = self.io.read_text(filepath)
# TODO: this is wrong with context changes
if filepath.suffix in self.extensions:
file_cls = self.node_classes.file_cls
Expand Down Expand Up @@ -634,17 +640,6 @@
continue
self._graph.remove_edge_from_index(edge)

def check_changes(self) -> None:
for file in self.pending_files:
file.check_changes()
self.pending_files.clear()

def write_files(self, files: set[Path] | None = None) -> None:
to_write = set(filter(lambda f: f.filepath in files, self.pending_files)) if files is not None else self.pending_files
with ThreadPoolExecutor() as exec:
exec.map(lambda f: f.write_pending_content(), to_write)
self.pending_files.difference_update(to_write)

@lru_cache(maxsize=10000)
def to_absolute(self, filepath: PathLike | str) -> Path:
path = Path(filepath)
Expand Down Expand Up @@ -684,7 +679,7 @@

# Write files if requested
if sync_file:
self.write_files(files)
self.io.save_files(files)

# Sync the graph if requested
if sync_graph and len(self.pending_syncs) > 0:
Expand Down
51 changes: 51 additions & 0 deletions src/codegen/sdk/codebase/io/file_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import logging
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

from codegen.sdk.codebase.io.io import IO, BadWriteError

logger = logging.getLogger(__name__)


class FileIO(IO):
"""IO implementation that writes files to disk, and tracks pending changes."""

files: dict[Path, bytes]

def __init__(self):
self.files = {}

def write_bytes(self, path: Path, content: bytes) -> None:
self.files[path] = content

def read_bytes(self, path: Path) -> bytes:
if path in self.files:
return self.files[path]
else:
return path.read_bytes()

def save_files(self, files: set[Path] | None = None) -> None:
to_save = set(filter(lambda f: f in files, self.files)) if files is not None else self.files.keys()
with ThreadPoolExecutor() as exec:
exec.map(lambda path: path.write_bytes(self.files[path]), to_save)
if files is None:
self.files.clear()
else:
for path in to_save:
del self.files[path]

def check_changes(self) -> None:
if self.files:
logger.error(BadWriteError("Directly called file write without calling commit_transactions"))
self.files.clear()

def delete_file(self, path: Path) -> None:
self.untrack_file(path)
if path.exists():
path.unlink()

def untrack_file(self, path: Path) -> None:
self.files.pop(path, None)

def file_exists(self, path: Path) -> bool:
return path.exists()
50 changes: 50 additions & 0 deletions src/codegen/sdk/codebase/io/io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from abc import ABC, abstractmethod
from pathlib import Path


class BadWriteError(Exception):
pass


class IO(ABC):
def write_file(self, path: Path, content: str | bytes | None) -> None:
if content is None:
self.untrack_file(path)
elif isinstance(content, str):
self.write_text(path, content)
else:
self.write_bytes(path, content)

def write_text(self, path: Path, content: str) -> None:
self.write_bytes(path, content.encode("utf-8"))

@abstractmethod
def untrack_file(self, path: Path) -> None:
pass

@abstractmethod
def write_bytes(self, path: Path, content: bytes) -> None:
pass

@abstractmethod
def read_bytes(self, path: Path) -> bytes:
pass

def read_text(self, path: Path) -> str:
return self.read_bytes(path).decode("utf-8")

@abstractmethod
def save_files(self, files: set[Path] | None = None) -> None:
pass

@abstractmethod
def check_changes(self) -> None:
pass

@abstractmethod
def delete_file(self, path: Path) -> None:
pass

@abstractmethod
def file_exists(self, path: Path) -> bool:
pass
6 changes: 2 additions & 4 deletions src/codegen/sdk/codebase/transactions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
from collections.abc import Callable
from difflib import unified_diff
from enum import IntEnum
Expand Down Expand Up @@ -158,7 +157,7 @@
self.exec_func = exec_func

def _generate_new_content_bytes(self) -> bytes:
new_bytes = bytes(self.new_content, encoding="utf-8")

Check failure on line 160 in src/codegen/sdk/codebase/transactions.py

View workflow job for this annotation

GitHub Actions / mypy

error: Argument 1 to "bytes" has incompatible type "str | None"; expected "str" [arg-type]
content_bytes = self.file.content_bytes
head = content_bytes[: self.insert_byte]
tail = content_bytes[self.insert_byte :]
Expand All @@ -178,7 +177,7 @@
def diff_str(self) -> str:
"""Human-readable string representation of the change"""
diff = "".join(unified_diff(self.file.content.splitlines(True), self._generate_new_content_bytes().decode("utf-8").splitlines(True)))
return f"Insert {len(self.new_content)} bytes at bytes ({self.start_byte}, {self.end_byte})\n{diff}"

Check failure on line 180 in src/codegen/sdk/codebase/transactions.py

View workflow job for this annotation

GitHub Actions / mypy

error: Argument 1 to "len" has incompatible type "str | None"; expected "Sized" [arg-type]


class EditTransaction(Transaction):
Expand Down Expand Up @@ -267,7 +266,7 @@

def execute(self) -> None:
"""Renames the file"""
self.file.write_pending_content()
self.file.ctx.io.save_files({self.file.path})
self.file_path.rename(self.new_file_path)

def get_diff(self) -> DiffLite:
Expand All @@ -292,8 +291,7 @@

def execute(self) -> None:
"""Removes the file"""
os.remove(self.file_path)
self.file._pending_content_bytes = None
self.file.ctx.io.delete_file(self.file.path)

def get_diff(self) -> DiffLite:
"""Gets the diff produced by this transaction"""
Expand Down
6 changes: 3 additions & 3 deletions src/codegen/sdk/core/codebase.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,17 +481,17 @@ def get_file(self, filepath: str, *, optional: bool = False, ignore_case: bool =

def get_file_from_path(path: Path) -> File | None:
try:
return File.from_content(path, path.read_text(), self.ctx, sync=False)
return File.from_content(path, self.ctx.io.read_text(path), self.ctx, sync=False)
except UnicodeDecodeError:
# Handle when file is a binary file
return File.from_content(path, path.read_bytes(), self.ctx, sync=False, binary=True)
return File.from_content(path, self.ctx.io.read_bytes(path), self.ctx, sync=False, binary=True)

# Try to get the file from the graph first
file = self.ctx.get_file(filepath, ignore_case=ignore_case)
if file is not None:
return file
absolute_path = self.ctx.to_absolute(filepath)
if absolute_path.exists():
if self.ctx.io.file_exists(absolute_path):
return get_file_from_path(absolute_path)
elif ignore_case:
parent = absolute_path.parent
Expand Down
49 changes: 14 additions & 35 deletions src/codegen/sdk/core/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import TYPE_CHECKING, Generic, Literal, Self, TypeVar, override

from tree_sitter import Node as TSNode
from typing_extensions import deprecated

from codegen.sdk._proxy import proxy_property
from codegen.sdk.codebase.codebase_context import CodebaseContext
Expand Down Expand Up @@ -45,10 +46,6 @@
logger = logging.getLogger(__name__)


class BadWriteError(Exception):
pass


@apidoc
class File(Editable[None]):
"""Represents a generic file.
Expand All @@ -66,7 +63,6 @@
file_path: str
path: Path
node_type: Literal[NodeType.FILE] = NodeType.FILE
_pending_content_bytes: bytes | None = None
_directory: Directory | None
_pending_imports: set[str]
_binary: bool = False
Expand Down Expand Up @@ -117,10 +113,8 @@
if not path.exists():
update_graph = True
path.parent.mkdir(parents=True, exist_ok=True)
if not binary:
path.write_text(content)
else:
path.write_bytes(content)
ctx.io.write_file(path, content)
ctx.io.save_files({path})

Check warning on line 117 in src/codegen/sdk/core/file.py

View check run for this annotation

Codecov / codecov/patch

src/codegen/sdk/core/file.py#L116-L117

Added lines #L116 - L117 were not covered by tests

new_file = cls(filepath, ctx, ts_node=None, binary=binary)
return new_file
Expand All @@ -133,10 +127,7 @@

TODO: move rest of graph sitter to operate in bytes to prevent multi byte character issues?
"""
# Check against None due to possibility of empty byte
if self._pending_content_bytes is None:
return self.path.read_bytes()
return self._pending_content_bytes
return self.ctx.io.read_bytes(self.path)

@property
@reader
Expand All @@ -162,31 +153,18 @@

@noapidoc
def write(self, content: str | bytes, to_disk: bool = False) -> None:
"""Writes string contents to the file."""
self.write_bytes(content.encode("utf-8") if isinstance(content, str) else content, to_disk=to_disk)

@noapidoc
def write_bytes(self, content_bytes: bytes, to_disk: bool = False) -> None:
self._pending_content_bytes = content_bytes
self.ctx.pending_files.add(self)
"""Writes contents to the file."""
self.ctx.io.write_file(self.path, content)
if to_disk:
self.write_pending_content()
self.ctx.io.save_files({self.path})
if self.ts_node.start_byte == self.ts_node.end_byte:
# TS didn't parse anything, register a write to make sure the transaction manager can restore the file later.
self.edit("")

@noapidoc
def write_pending_content(self) -> None:
if self._pending_content_bytes is not None:
self.path.write_bytes(self._pending_content_bytes)
self._pending_content_bytes = None
logger.debug("Finished write_pending_content")

@noapidoc
@writer
def check_changes(self) -> None:
if self._pending_content_bytes is not None:
logger.error(BadWriteError("Directly called file write without calling commit_transactions"))
@deprecated("Use write instead")
def write_bytes(self, content_bytes: bytes, to_disk: bool = False) -> None:
self.write(content_bytes, to_disk=to_disk)

@property
@reader
Expand Down Expand Up @@ -272,7 +250,7 @@
None
"""
self.transaction_manager.add_file_remove_transaction(self)
self._pending_content_bytes = None
self.ctx.io.write_file(self.path, None)

@property
def filepath(self) -> str:
Expand Down Expand Up @@ -596,10 +574,11 @@
return None

update_graph = False
if not path.exists():
if not ctx.io.file_exists(path):
update_graph = True
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(content)
ctx.io.write_file(path, content)
ctx.io.save_files({path})

if update_graph and sync:
ctx.add_single_file(path)
Expand Down
12 changes: 0 additions & 12 deletions src/codegen/sdk/extensions/io.pyx

This file was deleted.

Loading