Skip to content

Add runner module #53

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ dependencies = [
"PyGithub==2.5.0",
"GitPython==3.1.44",
"psutil>=5.8.0",
"fastapi[standard]<1.0.0,>=0.115.2",
"starlette<1.0.0,>=0.16.0",
]
license = {file = "LICENSE"}
classifiers = [
Expand Down
53 changes: 53 additions & 0 deletions src/codegen/git/utils/branch_sync.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import logging
from enum import StrEnum

from git.remote import Remote

from codegen.git.configs.constants import HIGHSIDE_REMOTE_NAME
from codegen.git.repo_operator.remote_repo_operator import RemoteRepoOperator
from codegen.git.schemas.enums import FetchResult
from codegen.git.schemas.github import GithubType
from codegen.git.utils.clone_url import get_authenticated_clone_url_for_repo_config
from codegen.utils.performance.stopwatch_utils import stopwatch

logger = logging.getLogger(__name__)


class BranchSyncResult(StrEnum):
    """Possible outcomes of syncing a branch from the highside remote.

    Values mirror the member names so they serialize/log as plain strings.
    """

    SUCCESS = "SUCCESS"
    BRANCH_NOT_FOUND = "BRANCH_NOT_FOUND"
    SKIP = "SKIP"


def get_highside_origin(op: RemoteRepoOperator) -> Remote:
    """Return the highside remote for *op*, creating it if it does not exist.

    If the remote already exists, its URL is refreshed to the current
    authenticated clone URL; otherwise a new remote is created with that URL.
    """
    url = get_authenticated_clone_url_for_repo_config(op.repo_config, github_type=GithubType.Github)

    if HIGHSIDE_REMOTE_NAME not in op.git_cli.remotes:
        return op.git_cli.create_remote(HIGHSIDE_REMOTE_NAME, url)

    # Remote already present: re-point it at the freshly-authenticated URL.
    origin = op.git_cli.remote(HIGHSIDE_REMOTE_NAME)
    origin.set_url(url)
    return origin


@stopwatch
def fetch_highside_branch(op: RemoteRepoOperator, branch_name: str) -> FetchResult:
    """Fetch *branch_name* from the highside remote and check it out locally.

    Ensures the highside remote exists, fetches the branch from it, then
    either hard-resets the currently checked-out branch or creates a local
    branch pointing at the fetched remote ref.

    Returns:
        FetchResult.REFSPEC_NOT_FOUND if the branch does not exist on the
        highside remote; FetchResult.SUCCESS otherwise.
    """
    # Step 1: ensure the highside remote exists with an authenticated URL
    remote_url = get_authenticated_clone_url_for_repo_config(repo=op.repo_config, github_type=GithubType.Github)
    op.create_remote(HIGHSIDE_REMOTE_NAME, remote_url)

    # Step 2: fetch the branch from highside
    res = op.fetch_remote(HIGHSIDE_REMOTE_NAME, refspec=branch_name)
    if res == FetchResult.REFSPEC_NOT_FOUND:
        logger.warning(f"Branch: {branch_name} not found in highside. Skipping fetch.")
        return FetchResult.REFSPEC_NOT_FOUND

    # Step 3: checkout (or update existing) local branch that tracks highside remote
    if op.is_branch_checked_out(branch_name):
        # update currently checked out branch to match the latest highside branch
        op.git_cli.git.reset("--hard", f"{HIGHSIDE_REMOTE_NAME}/{branch_name}")
    else:
        # create a new local branch that points at the remote highside branch
        op.git_cli.create_head(branch_name, commit=f"{HIGHSIDE_REMOTE_NAME}/{branch_name}", force=True)
    return FetchResult.SUCCESS
Empty file.
9 changes: 9 additions & 0 deletions src/codegen/runner/constants/envvars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""Environment variables used in the sandbox."""

# ==== [ Environment variable names ] ====
CUSTOMER_REPO_ID = "CUSTOMER_REPO_ID"
FEATURE_FLAGS_BASE64 = "FEATURE_FLAGS_BASE64"
REPO_CONFIG_BASE64 = "REPO_CONFIG_BASE64"
LOWSIDE_TOKEN = "LOWSIDE_TOKEN"
HIGHSIDE_TOKEN = "HIGHSIDE_TOKEN"
IS_SANDBOX = "IS_SANDBOX"
94 changes: 94 additions & 0 deletions src/codegen/runner/diff/get_raw_diff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import io
import logging

from unidiff import LINE_TYPE_CONTEXT, Hunk, PatchedFile, PatchSet
from unidiff.patch import Line

from codegen.sdk.core.codebase import Codebase

logger = logging.getLogger(__name__)


def append_flag(file: PatchedFile, append_at: int, line_no: int, codebase: Codebase) -> None:
    """Insert a one-line context hunk for source line *line_no* into *file* at index *append_at*.

    The hunk's single line is read from the codebase's current content of the
    patched file (1-based *line_no*).
    """
    context_hunk = Hunk(
        src_start=line_no,
        src_len=1,
        tgt_start=line_no,
        tgt_len=1,
    )
    # Pull the flagged line out of the file's current content (line_no is 1-based).
    source_line = codebase.get_file(file.path).content.split("\n")[line_no - 1]
    context_hunk.append(Line(f"{source_line}\n", line_type=LINE_TYPE_CONTEXT))
    file.insert(append_at, context_hunk)


def patch_to_limited_diff_string(patch, codebase: Codebase, max_lines: int = 10000) -> str:
    """Render *patch* as a unified-diff string truncated to roughly *max_lines* lines.

    Code flags tracked on the codebase that are not already visible in the
    diff are injected as single-line context hunks (or empty patched files for
    files the diff does not touch) so every flag remains visible downstream.

    Args:
        patch: the PatchSet to render (mutated in place by flag injection).
        codebase: codebase whose flags are merged into the diff.
        max_lines: soft cap on output lines; the last rendered file may overshoot.
    """
    diff_lines = []
    total_lines = 0

    # Add empty patched files for flags raised in files the diff does not touch
    filenames = [patched_file.path for patched_file in patch]
    flags_not_in_diff = list(filter(lambda flag: flag.symbol.filepath not in filenames, codebase.G.flags._flags))

    for flag in flags_not_in_diff:
        filename = flag.symbol.filepath
        patched_file = PatchedFile(
            patch_info=f"diff --git a/{filename} b/{filename}\n",
            source=f"a/{filename}",
            target=f"b/{filename}",
        )
        patch.append(patched_file)

    for patched_file in patch:
        # 1-based source line numbers of flags raised in this file, ascending
        filtered_flags = filter(lambda flag: flag.symbol.filepath == patched_file.path, codebase.G.flags._flags)
        sorted_flags = sorted(flag.symbol.start_point.row + 1 for flag in filtered_flags)

        for flag in sorted_flags:
            is_in_diff = False

            for i, hunk in enumerate(patched_file):
                # Flag already covered by this hunk's source range: nothing to add.
                if hunk.source_start <= flag <= hunk.source_start + hunk.source_length:
                    is_in_diff = True
                    break

                # Walked past the flag's position: inject a context hunk before this one.
                if hunk.source_start > flag:
                    is_in_diff = True
                    append_flag(patched_file, i, flag, codebase)
                    break

            if not is_in_diff:
                # Flag lies after every hunk: append a context hunk at the end.
                append_flag(patched_file, len(patched_file), flag, codebase)

        # Render this file (header + hunks) and count it against the budget.
        raw_diff = str(patched_file)
        diff_length = len(raw_diff.splitlines())

        total_lines += diff_length
        diff_lines.append(raw_diff)

        if total_lines >= max_lines:
            break

    return "\n".join(diff_lines)


def get_raw_diff(codebase: Codebase, base: str = "HEAD", max_lines: int = 10000) -> str:
raw_diff = codebase.get_diff(base)
patch_set = PatchSet(io.StringIO(raw_diff))

raw_diff_length = len(raw_diff.split("\n"))
logger.info(f"Truncating diff (total: {raw_diff_length}) to {max_lines} lines ...")
raw_diff_trunc = patch_to_limited_diff_string(patch=patch_set, max_lines=max_lines, codebase=codebase)

return raw_diff_trunc


def get_filenames_from_diff(diff: str) -> list[str]:
    """Return the paths of every file touched by the given unified diff string."""
    return [patched_file.path for patched_file in PatchSet(io.StringIO(diff))]
162 changes: 162 additions & 0 deletions src/codegen/runner/diff/syntax_highlight.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
import io
import json
import logging
import os
import select
import subprocess
import time

from unidiff import PatchedFile, PatchSet

from codegen.utils.performance.stopwatch_utils import stopwatch

logger = logging.getLogger(__name__)

HIGHLIGHTED_DIFF_FILENAME = "highlighted_diff.json"


@stopwatch
def syntax_highlight_modified_files(codebase, raw_diff: str, flags: list[dict]) -> str:
    """Syntax-highlight both sides of every file in *raw_diff*.

    Spawns a long-lived node highlighter subprocess, highlights the modified
    (post-change) version of each file, stashes working-tree changes so the
    original (pre-change) contents are on disk, highlights those, then builds
    per-hunk highlighted diffs. Returns a JSON string mapping file path ->
    list of highlighted hunk dicts.
    """
    modified_files = PatchSet(io.StringIO(raw_diff))
    highlighted_files = {}  # path -> highlighted post-change content
    highlighted_diff_files = {}  # path -> list of highlighted hunk dicts

    # TODO: refactor this
    with subprocess.Popen(
        ". ~/.bashrc > /dev/null && nvm use > /dev/null && yarn run --silent highlight",
        shell=True,
        cwd="/codegen/codegen-frontend/app/modules/syntaxHighlight",
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        bufsize=1,
        universal_newlines=True,
    ) as highlighter:
        poll = select.poll()
        poll.register(highlighter.stdout, select.POLLIN)

        # Pass 1: highlight the post-change version of each file.
        # Removed files have no post-change content, so they map to "".
        # The target_file path is stripped of its "b/" diff prefix.
        for file in modified_files:
            filename = file.path
            modified_filename = file.target_file
            highlighted_files[filename] = (
                "" if file.is_removed_file else _highlight_file(highlighter, poll, modified_filename if not modified_filename.startswith("b/") else modified_filename[2:], flags)
            )

        # Stash working-tree changes so the pre-change contents are on disk.
        codebase.stash_changes()

        # Pass 2: highlight the pre-change version ("" for added files) and
        # merge both sides into per-hunk highlighted diffs.
        for file in modified_files:
            filename = file.path
            original_filename = file.source_file
            original = "" if file.is_added_file else _highlight_file(highlighter, poll, original_filename if not original_filename.startswith("a/") else original_filename[2:], flags)
            modified = highlighted_files[filename]
            highlighted_hunks = _construct_diff_highlight(codebase, original.splitlines(), modified.splitlines(), file)
            highlighted_diff_files[filename] = highlighted_hunks

        try:
            codebase.restore_stashed_changes()
        except Exception as e:
            # This can happen if there are no changes stashed in the first place
            logger.warning(f"Error restoring stashed changes: {e}")

        # Close stdin and collect stderr; a non-empty stderr is logged, and a
        # non-zero exit code is treated as a hard failure.
        _, err = highlighter.communicate()
        returncode = highlighter.returncode

        if err:
            logger.error(f"Highlighter exited with error: {err}")

        if returncode != 0:
            raise Exception(f"Highlighter exited with code {returncode}")

    highlighted_diff = json.dumps(highlighted_diff_files)
    logger.info(f"Generated highlighted diff (size={len(highlighted_diff)})")
    return highlighted_diff


@stopwatch
def _highlight_file(highlighter: subprocess.Popen[str], poll: select.poll, filename: str, flags: list[dict]) -> str:
    """Ask the highlighter subprocess to highlight *filename* and return its output.

    Sends one JSON request (absolute file path plus the flags raised in that
    file) on the subprocess's stdin, then reads stdout line by line until the
    ETX sentinel line ("\\x03\\n") that terminates one response.
    """
    stdin_input = {
        "file": f"{os.getcwd()}/{filename}",
        # forward only the flags raised in this file
        "flags": [flag for flag in flags if flag["filepath"] == filename],
    }
    stdin_input = json.dumps(stdin_input)

    logger.info(f"> Highlighting {filename}...")
    highlighter.stdin.write(f"{stdin_input}\n")
    highlighter.stdin.flush()
    highlighted = ""

    while True:
        # TODO: this can deadlock in case the subprocess does not write a newline
        line = highlighter.stdout.readline()

        if not line:
            # No output yet — back off briefly before reading again.
            time.sleep(0.01)
            continue

        if line == "\x03\n":
            # ETX sentinel: the subprocess finished this file.
            break

        highlighted += line

    return highlighted


def _construct_diff_highlight(codebase, source: list[str], target: list[str], patched_file: PatchedFile) -> list:
    """Merge highlighted pre/post file lines into per-hunk diff strings.

    Walks the hunks of *patched_file* with two cursors — one into *source*
    (highlighted pre-change lines) and one into *target* (highlighted
    post-change lines) — prefixing each emitted line with "-", "+" or " ".
    Returns a list of dicts with keys "lines", "starts_at" and "ends_at"
    (positions within the reconstructed full file, 0-based).
    """
    original_lines = 0  # cursor into source (pre-change lines consumed)
    modified_lines = 0  # cursor into target (post-change lines consumed)
    full_file = ""  # reconstructed file text (built but not returned)
    full_file_lines = 0  # number of lines emitted into full_file so far
    highlighted_hunks = []

    for hunk in patched_file:
        hunk_lines = ""

        # Copy unchanged context lines that precede this hunk; both cursors
        # advance in lockstep (source_start is 1-based).
        while original_lines < (hunk.source_start - 1):
            full_file += f" {source[original_lines]}\n"
            full_file_lines += 1
            original_lines += 1
            modified_lines += 1

        for line in hunk:
            if line.is_removed:
                # Removed line: comes from source only; target cursor stays put.
                full_file += f"-{source[original_lines]}\n"
                hunk_lines += f"-{source[original_lines]}\n"
                original_lines += 1
                full_file_lines += 1
            elif line.is_added:
                # Added line: comes from target only; source cursor stays put.
                full_file += f"+{target[modified_lines]}\n"
                hunk_lines += f"+{target[modified_lines]}\n"
                modified_lines += 1
                full_file_lines += 1
            else:
                # Context line: prefer source, fall back to target if source is
                # exhausted; both cursors advance regardless.
                if len(source) > original_lines:
                    full_file += f" {source[original_lines]}\n"
                    hunk_lines += f" {source[original_lines]}\n"
                elif len(target) > modified_lines:
                    full_file += f" {target[modified_lines]}\n"
                    hunk_lines += f" {target[modified_lines]}\n"
                else:
                    logger.warning(f"Lines {original_lines}/{modified_lines} not found in {patched_file.path} in {codebase.current_commit.hexsha}: {line}")
                original_lines += 1
                modified_lines += 1
                full_file_lines += 1

        # Drop the trailing newline so the hunk text joins cleanly.
        if hunk_lines.endswith("\n"):
            hunk_lines = hunk_lines[:-1]

        highlighted_hunks.append({"lines": hunk_lines, "starts_at": full_file_lines - len(hunk), "ends_at": full_file_lines - 1})

    # Append any remaining unchanged source lines to the reconstructed file.
    if original_lines < len(source):
        full_file += "\n ".join(source[original_lines:])

    # TODO: we know the file length so we can add a property to diff and determine if we can expand down even if we haven't loaded the entire file on FE yet

    return highlighted_hunks
7 changes: 7 additions & 0 deletions src/codegen/runner/enums/warmup_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from enum import StrEnum


class WarmupState(StrEnum):
    """Lifecycle state of a sandbox warmup (values mirror member names)."""

    PENDING = "PENDING"
    COMPLETED = "COMPLETED"
    FAILED = "FAILED"
Loading
Loading