Skip to content

Commit f9727de

Browse files
authored
Merge pull request #20 from commit-0/modal-clean
[Draft] Refactor modal code
2 parents e9cfc06 + a1bdbc1 commit f9727de

File tree

3 files changed

+328
-188
lines changed

3 files changed

+328
-188
lines changed

commit0/configs/user.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
defaults:
22
- base
33
- _self_
4+
5+
backend: local

commit0/harness/execution_context.py

Lines changed: 287 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,287 @@
1+
"""Remote code execution contexts
2+
3+
Implements the interface for local docker containers, remote modal sandboxes,
4+
and HTTP servers.
5+
"""
6+
7+
from abc import ABC, abstractmethod
8+
import docker
9+
import logging
10+
import os
11+
import modal
12+
import modal.io_streams
13+
from pathlib import Path
14+
from typing import Optional, Type
15+
from types import TracebackType
16+
17+
from commit0.harness.spec import Spec
18+
from commit0.harness.utils import (
19+
EvaluationError,
20+
)
21+
from commit0.harness.docker_build import (
22+
close_logger,
23+
)
24+
from commit0.harness.docker_utils import (
25+
cleanup_container,
26+
create_container,
27+
copy_from_container,
28+
copy_to_container,
29+
copy_ssh_pubkey_from_container,
30+
delete_file_from_container,
31+
exec_run_with_timeout,
32+
)
33+
34+
35+
def read_stream(stream: modal.io_streams.StreamReader) -> str:
36+
"""Read stream"""
37+
strings = []
38+
for line in stream:
39+
strings.append(line)
40+
return "\n".join(strings)
41+
42+
43+
class ExecutionContext(ABC):
44+
def __init__(
45+
self,
46+
spec: Spec,
47+
logger: logging.Logger,
48+
eval_file: Path,
49+
timeout: int,
50+
log_dir: Path,
51+
):
52+
"""Create the remote execution context
53+
54+
The execution context will persist for the lifetime of this object.
55+
The execution context can be a Docker container or Modal sandbox.
56+
"""
57+
self.spec = spec
58+
self.logger = logger
59+
self.eval_file = eval_file
60+
self.timeout = timeout
61+
self.log_dir = log_dir
62+
63+
@abstractmethod
64+
def copy_ssh_pubkey_from_remote(self) -> None:
65+
"""Copy"""
66+
raise NotImplementedError
67+
68+
@abstractmethod
69+
def copy_to_remote(self, local_path: Path, remote_path: Path) -> None:
70+
"""Copy"""
71+
raise NotImplementedError
72+
73+
@abstractmethod
74+
def exec_run_with_timeout(
75+
self, command: str, timeout: int
76+
) -> tuple[str, bool, float]:
77+
"""Exec"""
78+
raise NotImplementedError
79+
80+
@abstractmethod
81+
def exec_run(self, command: str) -> tuple[int, str]:
82+
"""Exec"""
83+
raise NotImplementedError
84+
85+
@abstractmethod
86+
def copy_from_remote(self, remote_path: Path, local_path: Path) -> None:
87+
"""Copy"""
88+
raise NotImplementedError
89+
90+
@abstractmethod
91+
def delete_file_from_remote(self, remote_path: Path) -> None:
92+
"""Delete"""
93+
raise NotImplementedError
94+
95+
def write_test_output(self, test_output: str, timed_out: bool) -> None:
96+
"""Write test output"""
97+
test_output_path = self.log_dir / "test_output.txt"
98+
with open(test_output_path, "w") as f:
99+
f.write(test_output)
100+
if timed_out:
101+
f.write(f"\n\nTimeout error: {self.timeout} seconds exceeded.")
102+
raise EvaluationError(
103+
self.spec.repo,
104+
f"Test timed out after {self.timeout} seconds.",
105+
self.logger,
106+
)
107+
108+
# copy back report.json if there is any
109+
report_file = Path(self.spec.repo_directory) / "report.json"
110+
# Run the test command inside the container to check if the file exists
111+
exit_code, output = self.exec_run(f"test -e {report_file}")
112+
# Check the exit code of the command
113+
if exit_code == 0:
114+
self.copy_from_remote(report_file, self.log_dir / "report.json")
115+
self.delete_file_from_remote(report_file)
116+
117+
def __enter__(self):
118+
return self
119+
120+
@abstractmethod
121+
def __exit__(
122+
self,
123+
exctype: Optional[Type[BaseException]],
124+
excinst: Optional[BaseException],
125+
exctb: Optional[TracebackType],
126+
) -> None:
127+
raise NotImplementedError
128+
129+
130+
class Docker(ExecutionContext):
131+
def __init__(
132+
self,
133+
spec: Spec,
134+
logger: logging.Logger,
135+
eval_file: Path,
136+
timeout: int,
137+
log_dir: Path,
138+
):
139+
super().__init__(spec, logger, eval_file, timeout, log_dir)
140+
141+
self.client = docker.from_env()
142+
self.container = create_container(
143+
client=self.client,
144+
image_name=spec.repo_image_key,
145+
container_name=spec.get_container_name(),
146+
logger=logger,
147+
)
148+
self.container.start()
149+
self.copy_ssh_pubkey_from_remote()
150+
copy_to_container(self.container, eval_file, Path("/eval.sh"))
151+
152+
def copy_ssh_pubkey_from_remote(self) -> None:
153+
"""Copy"""
154+
copy_ssh_pubkey_from_container(self.container)
155+
156+
def copy_to_remote(self, local_file: Path, remote_path: Path) -> None:
157+
"""Copy"""
158+
copy_to_container(self.container, local_file, remote_path)
159+
160+
def exec_run_with_timeout(
161+
self, command: str, timeout: int
162+
) -> tuple[str, bool, float]:
163+
"""Exec"""
164+
return exec_run_with_timeout(self.container, command, timeout)
165+
166+
def exec_run(self, command: str) -> tuple[int, str]:
167+
"""Exec"""
168+
return self.container.exec_run(command, demux=True)
169+
170+
def copy_from_remote(self, remote_path: Path, local_path: Path) -> None:
171+
"""Copy"""
172+
copy_from_container(self.container, remote_path, local_path)
173+
174+
def delete_file_from_remote(self, remote_path: Path) -> None:
175+
"""Delete"""
176+
delete_file_from_container(self.container, str(remote_path))
177+
178+
def __exit__(
179+
self,
180+
exctype: Optional[Type[BaseException]],
181+
excinst: Optional[BaseException],
182+
exctb: Optional[TracebackType],
183+
) -> None:
184+
cleanup_container(self.client, self.container, self.logger)
185+
close_logger(self.logger)
186+
187+
188+
class Modal(ExecutionContext):
189+
def __init__(
190+
self,
191+
spec: Spec,
192+
logger: logging.Logger,
193+
eval_file: Path,
194+
timeout: int,
195+
log_dir: Path,
196+
):
197+
super().__init__(spec, logger, eval_file, timeout, log_dir)
198+
199+
# the image must exist on dockerhub
200+
reponame = spec.repo.split("/")[-1]
201+
image_name = f"wentingzhao/{reponame}"
202+
image = modal.Image.from_registry(image_name).copy_local_file(
203+
eval_file, "/eval.sh"
204+
)
205+
206+
self.sandbox = modal.Sandbox.create(
207+
"sleep",
208+
"infinity",
209+
image=image,
210+
cpu=4.0,
211+
timeout=timeout,
212+
)
213+
214+
self.copy_ssh_pubkey_from_remote()
215+
216+
def copy_ssh_pubkey_from_remote(self) -> None:
217+
"""Copy ssh pubkey"""
218+
process = self.sandbox.exec("bash", "-c", "cat /root/.ssh/id_rsa.pub")
219+
public_key = "".join([line for line in process.stdout]).strip()
220+
221+
# add to authorized keys locally. copy-pasted from utils
222+
local_authorized_keys_path = os.path.expanduser("~/.ssh/authorized_keys")
223+
os.makedirs(os.path.dirname(local_authorized_keys_path), exist_ok=True)
224+
if not os.path.exists(local_authorized_keys_path):
225+
# Since the file does not exist, create it
226+
open(local_authorized_keys_path, "a").close()
227+
write = True
228+
else:
229+
with open(local_authorized_keys_path, "r") as authorized_keys_file:
230+
content = authorized_keys_file.read()
231+
if public_key not in content:
232+
write = True
233+
else:
234+
write = False
235+
if write:
236+
with open(local_authorized_keys_path, "a") as authorized_keys_file:
237+
authorized_keys_file.write(public_key + "\n")
238+
239+
def copy_to_remote(self, local_path: Path, remote_path: Path) -> None:
240+
"""Copy"""
241+
raise NotImplementedError
242+
# tempname = "tmpfile"
243+
# with local_path.open("rb") as f:
244+
# self.nfs.write_file(tempname, f)
245+
# self.sandbox.exec("bash", "-c", f"cp /vol/{tempname} {str(remote_path)}")
246+
247+
def exec_run_with_timeout(
248+
self, command: str, timeout: int
249+
) -> tuple[str, bool, float]:
250+
"""Execute command on modal sandbox"""
251+
print("Executing:", command)
252+
process = self.sandbox.exec("bash", "-c", command)
253+
print("stdout")
254+
stdout = read_stream(process.stdout)
255+
print("stderr")
256+
stderr = read_stream(process.stderr)
257+
print(stderr)
258+
return stdout, False, 1.0
259+
return stdout, stderr
260+
261+
def exec_run(self, command: str) -> tuple[int, str]:
262+
"""Execute command on modal sandbox"""
263+
process = self.sandbox.exec("bash", "-c", command)
264+
stdout = read_stream(process.stdout)
265+
stderr = read_stream(process.stderr)
266+
print(stderr)
267+
return 1, stdout
268+
269+
def copy_from_remote(self, remote_path: Path, local_path: Path) -> None:
270+
"""Copy file from modal sandbox"""
271+
process = self.sandbox.exec("bash", "-c", f"cat {str(remote_path)}")
272+
output = "".join([line for line in process.stdout]).strip()
273+
with local_path.open("w") as f:
274+
f.write(output)
275+
276+
def delete_file_from_remote(self, remote_path: Path) -> None:
277+
"""Delete"""
278+
self.sandbox.exec("bash", "-c", f"rm {str(remote_path)}")
279+
280+
def __exit__(
281+
self,
282+
exctype: Optional[Type[BaseException]],
283+
excinst: Optional[BaseException],
284+
exctb: Optional[TracebackType],
285+
) -> None:
286+
self.sandbox.terminate()
287+
close_logger(self.logger)

0 commit comments

Comments
 (0)