Skip to content

Commit 73f46bc

Browse files
authored
Add runner module (#53)
# Motivation <!-- Why is this change necessary? --> # Content <!-- Please include a summary of the change --> # Testing <!-- How was the change tested? --> # Please check the following before marking your PR as ready for review - [x] I have added tests for my changes - [x] I have updated the documentation or added new documentation as needed - [x] I have read and agree to the [Contributor License Agreement](../CLA.md)
1 parent 38ad575 commit 73f46bc

File tree

17 files changed

+1134
-0
lines changed

17 files changed

+1134
-0
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ dependencies = [
6363
"PyGithub==2.5.0",
6464
"GitPython==3.1.44",
6565
"psutil>=5.8.0",
66+
"fastapi[standard]<1.0.0,>=0.115.2",
67+
"starlette<1.0.0,>=0.16.0",
6668
]
6769
license = {file = "LICENSE"}
6870
classifiers = [

src/codegen/git/utils/branch_sync.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import logging
2+
from enum import StrEnum
3+
4+
from git.remote import Remote
5+
6+
from codegen.git.configs.constants import HIGHSIDE_REMOTE_NAME
7+
from codegen.git.repo_operator.remote_repo_operator import RemoteRepoOperator
8+
from codegen.git.schemas.enums import FetchResult
9+
from codegen.git.schemas.github import GithubType
10+
from codegen.git.utils.clone_url import get_authenticated_clone_url_for_repo_config
11+
from codegen.utils.performance.stopwatch_utils import stopwatch
12+
13+
logger = logging.getLogger(__name__)
14+
15+
16+
class BranchSyncResult(StrEnum):
    """Outcome of syncing a local branch against the highside remote."""

    # Branch was fetched/checked out successfully.
    SUCCESS = "SUCCESS"
    # Branch does not exist on the highside remote.
    BRANCH_NOT_FOUND = "BRANCH_NOT_FOUND"
    # Sync was skipped — the caller decides when this applies (not visible here).
    SKIP = "SKIP"
20+
21+
22+
def get_highside_origin(op: RemoteRepoOperator) -> Remote:
    """Return the highside remote for *op*'s repo, creating it if missing.

    The remote's URL is always refreshed to the current authenticated clone
    URL, so a stale token on an existing remote is repaired on every call.
    """
    authenticated_url = get_authenticated_clone_url_for_repo_config(op.repo_config, github_type=GithubType.Github)

    if HIGHSIDE_REMOTE_NAME not in op.git_cli.remotes:
        # No highside remote yet — register one pointing at the authenticated URL.
        return op.git_cli.create_remote(HIGHSIDE_REMOTE_NAME, authenticated_url)

    # Remote already exists: refresh its URL (tokens rotate) and reuse it.
    existing = op.git_cli.remote(HIGHSIDE_REMOTE_NAME)
    existing.set_url(authenticated_url)
    return existing
31+
32+
33+
@stopwatch
def fetch_highside_branch(op: RemoteRepoOperator, branch_name: str) -> FetchResult:
    """Checks out a branch from the highside origin.

    Steps:
      1. Ensure a highside remote exists on the local clone.
      2. Fetch ``branch_name`` from that remote.
      3. Materialize it locally: hard-reset if the branch is already checked
         out, otherwise create/overwrite a local head pointing at it.

    Returns FetchResult.SUCCESS on success, or FetchResult.REFSPEC_NOT_FOUND
    when the branch does not exist on the highside remote.
    """
    # Step 1: create highside origin
    # NOTE(review): this duplicates get_highside_origin() above; confirm whether
    # op.create_remote also updates the URL of an existing remote.
    remote_url = get_authenticated_clone_url_for_repo_config(repo=op.repo_config, github_type=GithubType.Github)
    op.create_remote(HIGHSIDE_REMOTE_NAME, remote_url)

    # Step 2: fetch the branch from highside
    res = op.fetch_remote(HIGHSIDE_REMOTE_NAME, refspec=branch_name)
    if res == FetchResult.REFSPEC_NOT_FOUND:
        logger.warning(f"Branch: {branch_name} not found in highside. Skipping fetch.")
        return FetchResult.REFSPEC_NOT_FOUND

    # Step 3: checkout (or update existing) local branch that tracks highside remote
    if op.is_branch_checked_out(branch_name):
        # update currently checked out branch to match the latest highside branch
        op.git_cli.git.reset("--hard", f"{HIGHSIDE_REMOTE_NAME}/{branch_name}")
    else:
        # create a new local branch that tracks the remote highside branch
        # (force=True overwrites any stale local head of the same name)
        op.git_cli.create_head(branch_name, commit=f"{HIGHSIDE_REMOTE_NAME}/{branch_name}", force=True)
    return FetchResult.SUCCESS

src/codegen/runner/__init__.py

Whitespace-only changes.
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
"""Environment variables used in the sandbox."""
2+
3+
# ==== [ Environment variable names ] ====
4+
CUSTOMER_REPO_ID = "CUSTOMER_REPO_ID"
5+
FEATURE_FLAGS_BASE64 = "FEATURE_FLAGS_BASE64"
6+
REPO_CONFIG_BASE64 = "REPO_CONFIG_BASE64"
7+
LOWSIDE_TOKEN = "LOWSIDE_TOKEN"
8+
HIGHSIDE_TOKEN = "HIGHSIDE_TOKEN"
9+
IS_SANDBOX = "IS_SANDBOX"
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import io
2+
import logging
3+
4+
from unidiff import LINE_TYPE_CONTEXT, Hunk, PatchedFile, PatchSet
5+
from unidiff.patch import Line
6+
7+
from codegen.sdk.core.codebase import Codebase
8+
9+
logger = logging.getLogger(__name__)
10+
11+
12+
def append_flag(file: PatchedFile, append_at: int, line_no: int, codebase: Codebase) -> None:
    """Insert a one-line context hunk for source line *line_no* into *file*.

    The hunk contains the current content of that line (read from the
    codebase) as a single unchanged context line, and is spliced into the
    patched file at index *append_at*.
    """
    # Look up the line's current text; line_no is 1-based, the split is 0-based.
    source_line = codebase.get_file(file.path).content.split("\n")[line_no - 1]

    # A minimal 1-line hunk anchored at the flag's line on both sides.
    flag_hunk = Hunk(
        src_start=line_no,
        src_len=1,
        tgt_start=line_no,
        tgt_len=1,
    )
    flag_hunk.append(Line(f"{source_line}\n", line_type=LINE_TYPE_CONTEXT))
    file.insert(append_at, flag_hunk)
22+
23+
24+
def patch_to_limited_diff_string(patch, codebase: Codebase, max_lines=10000):
    """Render *patch* to a unified-diff string, truncated to ~*max_lines* lines.

    Ensures every code flag in the codebase appears in the output:
      * files with flags but no diff get an (empty) PatchedFile appended;
      * flags that fall outside every hunk of their file get a synthetic
        one-line context hunk inserted (see append_flag).

    NOTE: mutates *patch* in place (appends files / inserts hunks).

    Bug fix: the synthetic file headers previously rendered the literal text
    "(unknown)" because the f-strings lost their {filename} placeholders,
    leaving the computed ``filename`` unused.
    """
    diff_lines = []
    total_lines = 0

    # Add flags that are not in the diff
    filenames = [patched_file.path for patched_file in patch]
    flags_not_in_diff = list(filter(lambda flag: flag.symbol.filepath not in filenames, codebase.G.flags._flags))

    for flag in flags_not_in_diff:
        filename = flag.symbol.filepath
        patched_file = PatchedFile(
            patch_info=f"diff --git a/{filename} b/{filename}\n",
            source=f"a/{filename}",
            target=f"b/{filename}",
        )
        patch.append(patched_file)

    for patched_file in patch:
        # 1-based source line numbers of this file's flags, ascending.
        filtered_flags = filter(lambda flag: flag.symbol.filepath == patched_file.path, codebase.G.flags._flags)
        sorted_flags = list(map(lambda flag: flag.symbol.start_point.row + 1, filtered_flags))
        sorted_flags.sort()

        for flag in sorted_flags:
            is_in_diff = False

            for i, hunk in enumerate(patched_file):
                # Flag already covered by an existing hunk — nothing to insert.
                contains_flag = hunk.source_start <= flag <= hunk.source_start + hunk.source_length
                if contains_flag:
                    is_in_diff = True
                    break

                # First hunk that starts after the flag: insert the synthetic
                # context hunk just before it to keep hunks ordered.
                is_after_flag = hunk.source_start > flag
                if is_after_flag:
                    is_in_diff = True
                    append_flag(patched_file, i, flag, codebase)
                    break

            if not is_in_diff:
                # Flag lies after every hunk — append at the end of the file.
                append_flag(patched_file, len(patched_file), flag, codebase)

        # Add file header
        raw_diff = str(patched_file)
        diff_length = len(raw_diff.splitlines())

        total_lines += diff_length
        diff_lines.append(raw_diff)

        # Soft limit: the file that crosses max_lines is still kept whole.
        if total_lines >= max_lines:
            break

    return "\n".join(diff_lines)
77+
78+
79+
def get_raw_diff(codebase: Codebase, base: str = "HEAD", max_lines: int = 10000) -> str:
80+
raw_diff = codebase.get_diff(base)
81+
patch_set = PatchSet(io.StringIO(raw_diff))
82+
83+
raw_diff_length = len(raw_diff.split("\n"))
84+
logger.info(f"Truncating diff (total: {raw_diff_length}) to {max_lines} lines ...")
85+
raw_diff_trunc = patch_to_limited_diff_string(patch=patch_set, max_lines=max_lines, codebase=codebase)
86+
87+
return raw_diff_trunc
88+
89+
90+
def get_filenames_from_diff(diff: str) -> list[str]:
    """Extract the path of every file mentioned in a unified-diff string."""
    parsed = PatchSet(io.StringIO(diff))
    return [patched_file.path for patched_file in parsed]
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
import io
2+
import json
3+
import logging
4+
import os
5+
import select
6+
import subprocess
7+
import time
8+
9+
from unidiff import PatchedFile, PatchSet
10+
11+
from codegen.utils.performance.stopwatch_utils import stopwatch
12+
13+
logger = logging.getLogger(__name__)
14+
15+
HIGHLIGHTED_DIFF_FILENAME = "highlighted_diff.json"
16+
17+
18+
@stopwatch
def syntax_highlight_modified_files(codebase, raw_diff: str, flags: list[dict]) -> str:
    """Syntax-highlight every file touched by *raw_diff*; return a JSON payload.

    Spawns the frontend's node highlighter as one long-lived subprocess and
    streams each file through it twice: first the modified (working-tree)
    version, then — after stashing working-tree changes — the original
    version.  The two highlighted texts are merged hunk-by-hunk via
    _construct_diff_highlight.

    Returns a JSON string mapping filename -> list of highlighted hunks.
    Raises Exception if the highlighter subprocess exits non-zero.
    """
    modified_files = PatchSet(io.StringIO(raw_diff))
    highlighted_files = {}       # filename -> highlighted *modified* content
    highlighted_diff_files = {}  # filename -> list of highlighted hunk dicts

    # TODO: refactor this
    with subprocess.Popen(
        # Sourcing ~/.bashrc + nvm puts the expected node version on PATH.
        ". ~/.bashrc > /dev/null && nvm use > /dev/null && yarn run --silent highlight",
        shell=True,
        cwd="/codegen/codegen-frontend/app/modules/syntaxHighlight",
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        bufsize=1,  # line-buffered: the subprocess protocol is line-oriented
        universal_newlines=True,
    ) as highlighter:
        poll = select.poll()
        poll.register(highlighter.stdout, select.POLLIN)

        # Pass 1: highlight the modified (working-tree) version of each file.
        for file in modified_files:
            filename = file.path
            modified_filename = file.target_file
            # Removed files have no modified content; otherwise strip unidiff's "b/" prefix.
            highlighted_files[filename] = (
                "" if file.is_removed_file else _highlight_file(highlighter, poll, modified_filename if not modified_filename.startswith("b/") else modified_filename[2:], flags)
            )

        # Stash so the working tree shows the *original* content for pass 2.
        codebase.stash_changes()

        # Pass 2: highlight the original version and merge per hunk.
        for file in modified_files:
            filename = file.path
            original_filename = file.source_file
            # Added files have no original content; otherwise strip unidiff's "a/" prefix.
            original = "" if file.is_added_file else _highlight_file(highlighter, poll, original_filename if not original_filename.startswith("a/") else original_filename[2:], flags)
            modified = highlighted_files[filename]
            highlighted_hunks = _construct_diff_highlight(codebase, original.splitlines(), modified.splitlines(), file)
            highlighted_diff_files[filename] = highlighted_hunks

        try:
            codebase.restore_stashed_changes()
        except Exception as e:
            # This can happen if there are no changes stashed in the first place
            logger.warning(f"Error restoring stashed changes: {e}")

        # Close stdin and wait for the highlighter so its exit status is known.
        _, err = highlighter.communicate()
        returncode = highlighter.returncode

        if err:
            logger.error(f"Highlighter exited with error: {err}")

        if returncode != 0:
            raise Exception(f"Highlighter exited with code {returncode}")

    highlighted_diff = json.dumps(highlighted_diff_files)
    logger.info(f"Generated highlighted diff (size={len(highlighted_diff)})")
    return highlighted_diff
74+
75+
76+
@stopwatch
def _highlight_file(highlighter: subprocess.Popen[str], poll: select.poll, filename: str, flags: list[dict]):
    """Ask the running highlighter subprocess to highlight a single file.

    Writes one JSON request line ({"file": ..., "flags": ...}) to the
    highlighter's stdin, then reads its stdout line-by-line until the ETX
    sentinel line ("\\x03\\n") is seen.  Returns the accumulated highlighted
    text.  *poll* is currently unused (kept for a future timeout
    implementation on the read loop).

    Bug fix: the request and the log line previously contained the literal
    text "(unknown)" because the f-strings lost their {filename}
    placeholders, so the subprocess never received the real file path.
    """
    stdin_input = {
        "file": f"{os.getcwd()}/{filename}",
        # Only forward the flags that belong to this file.
        "flags": list(filter(lambda flag: flag["filepath"] == filename, flags)),
    }
    stdin_input = json.dumps(stdin_input)

    logger.info(f"> Highlighting {filename}...")
    highlighter.stdin.write(f"{stdin_input}\n")
    highlighter.stdin.flush()
    highlighted = ""

    while True:
        # TODO: this can deadlock in case the subprocess does not write a newline
        # NOTE(review): if the subprocess dies before emitting the sentinel,
        # readline() returns "" forever and this loop spins — add a timeout.
        line = highlighter.stdout.readline()

        if not line:
            time.sleep(0.01)

        if line == "\x03\n":  # ETX sentinel marks the end of one file's output
            break

        highlighted += line

    return highlighted
110+
111+
112+
def _construct_diff_highlight(codebase, source: list[str], target: list[str], patched_file: PatchedFile) -> list:
    """Merge highlighted original/modified lines into per-hunk diff views.

    *source* / *target* are the highlighted lines of the original and the
    modified file; *patched_file* supplies the hunk structure.  Returns a list
    of {"lines", "starts_at", "ends_at"} dicts, one per hunk, where the
    positions index into the reconstructed full file.
    """
    original_lines = 0   # cursor into *source*
    modified_lines = 0   # cursor into *target*
    full_file = ""       # reconstruction of the whole file with diff prefixes
    full_file_lines = 0  # number of lines emitted into full_file so far
    highlighted_hunks = []

    for hunk in patched_file:
        hunk_lines = ""

        # Emit the unchanged lines between the previous hunk and this one,
        # advancing both cursors in lockstep (hunk.source_start is 1-based).
        while original_lines < (hunk.source_start - 1):
            full_file += f" {source[original_lines]}\n"
            full_file_lines += 1
            original_lines += 1
            modified_lines += 1

        for line in hunk:
            if line.is_removed:
                # Removal: consumes a source line only.
                full_file += f"-{source[original_lines]}\n"
                hunk_lines += f"-{source[original_lines]}\n"
                original_lines += 1
                full_file_lines += 1
            elif line.is_added:
                # Addition: consumes a target line only.
                full_file += f"+{target[modified_lines]}\n"
                hunk_lines += f"+{target[modified_lines]}\n"
                modified_lines += 1
                full_file_lines += 1
            else:
                # Context line: prefer the source side; fall back to target
                # when the source is exhausted (e.g. trailing added context).
                if len(source) > original_lines:
                    full_file += f" {source[original_lines]}\n"
                    hunk_lines += f" {source[original_lines]}\n"
                elif len(target) > modified_lines:
                    full_file += f" {target[modified_lines]}\n"
                    hunk_lines += f" {target[modified_lines]}\n"
                else:
                    logger.warning(f"Lines {original_lines}/{modified_lines} not found in {patched_file.path} in {codebase.current_commit.hexsha}: {line}")
                original_lines += 1
                modified_lines += 1
                full_file_lines += 1

        # Strip the trailing newline from the hunk payload.
        if hunk_lines.endswith("\n"):
            hunk_lines = hunk_lines[:-1]

        # NOTE(review): len(hunk) counts every diff line in the hunk, which is
        # the number of lines emitted above only when the warning branch never
        # fired — confirm starts_at is right for malformed inputs.
        highlighted_hunks.append({"lines": hunk_lines, "starts_at": full_file_lines - len(hunk), "ends_at": full_file_lines - 1})

    # Append any remaining unchanged tail of the file to the reconstruction.
    if original_lines < len(source):
        full_file += "\n ".join(source[original_lines:])

    # TODO: we know the file length so we can add a property to diff and determine if we can expand down even if we haven't loaded the entire file on FE yet

    return highlighted_hunks
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from enum import StrEnum
2+
3+
4+
class WarmupState(StrEnum):
    """Lifecycle state of a warmup (used by the runner module)."""

    # Warmup has not finished yet.
    PENDING = "PENDING"
    # Warmup finished successfully.
    COMPLETED = "COMPLETED"
    # Warmup errored out.
    FAILED = "FAILED"

0 commit comments

Comments
 (0)