Skip to content

Commit 79b0363

Browse files
authored
Merge pull request #26 from commit-0/ssh-modal
make ssh also work for modal
2 parents 18cb0c7 + d8c2d3b commit 79b0363

File tree

12 files changed

+266
-135
lines changed

12 files changed

+266
-135
lines changed

.github/workflows/system.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,11 @@ jobs:
1919
- name: Install the project
2020
run: uv sync
2121
- name: Set up commit0
22-
run: uv run commit0 setup simpy
22+
run: uv run commit0 clone simpy
2323
- name: Build docker images
2424
run: uv run commit0 build simpy
25+
- name: Set up git user
26+
run: sudo "$(which uv)" run commit0 setup-git-user simpy
2527
- name: Get tests
2628
run: uv run commit0 get-tests simpy
2729
- name: Test

commit0/__main__.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import commit0.harness.get_pytest_ids
33
import commit0.harness.build
44
import commit0.harness.setup
5+
import commit0.harness.setup_git_user
56
import commit0.harness.evaluate
67
import commit0.harness.save
78
import copy
@@ -30,7 +31,14 @@ def main() -> None:
3031
# after hydra gets all configs, put command-line arguments back
3132
sys.argv = sys_argv
3233
# repo_split: split from command line has a higher priority than split in hydra
33-
if command in ["setup", "build", "evaluate", "evaluate-reference", "save"]:
34+
if command in [
35+
"clone",
36+
"build",
37+
"setup-git-user",
38+
"evaluate",
39+
"evaluate-reference",
40+
"save",
41+
]:
3442
if len(sys.argv) >= 3:
3543
if sys.argv[2] not in SPLIT:
3644
raise ValueError(
@@ -39,7 +47,7 @@ def main() -> None:
3947
config.repo_split = sys.argv[2]
4048
config.base_dir = os.path.abspath(config.base_dir)
4149

42-
if command == "setup":
50+
if command == "clone":
4351
commit0.harness.setup.main(
4452
config.dataset_name,
4553
config.dataset_split,
@@ -53,6 +61,17 @@ def main() -> None:
5361
config.dataset_split,
5462
config.repo_split,
5563
config.num_workers,
64+
config.backend,
65+
config.key_path,
66+
)
67+
elif command == "setup-git-user":
68+
commit0.harness.setup_git_user.main(
69+
config.dataset_name,
70+
config.dataset_split,
71+
config.repo_split,
72+
config.base_dir,
73+
config.git_user,
74+
config.key_path,
5675
)
5776
elif command == "get-tests":
5877
repo = sys.argv[2]

commit0/configs/base.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ repo_split: all
1111

1212
# build related
1313
num_workers: 8
14+
key_path: commit0/configs/public_keys.json
15+
16+
# set up git user
17+
git_user: git # by default, git user is called git
1418

1519
# test related
1620
backend: local

commit0/configs/config_class.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ class Commit0Config:
1515
# build related
1616
# which repo to build, all or one repo
1717
num_workers: int
18+
# path to store public keys from docker images
19+
key_path: str
1820

1921
# test related
2022
backend: str

commit0/harness/build.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,20 @@
1+
import json
12
import logging
3+
import traceback
24

35
import docker
46
from datasets import load_dataset
7+
from tqdm import tqdm
58
from typing import Iterator
69

10+
from commit0.harness.constants import EVAL_BACKENDS, RepoInstance, SPLIT
711
from commit0.harness.docker_build import build_repo_images
12+
from commit0.harness.execution_context import (
13+
ExecutionBackend,
14+
Docker,
15+
Modal,
16+
)
817
from commit0.harness.spec import make_spec
9-
from commit0.harness.constants import RepoInstance, SPLIT
1018

1119
logging.basicConfig(
1220
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
@@ -15,7 +23,12 @@
1523

1624

1725
def main(
18-
dataset_name: str, dataset_split: str, repo_split: str, num_workers: int
26+
dataset_name: str,
27+
dataset_split: str,
28+
repo_split: str,
29+
num_workers: int,
30+
backend: str,
31+
key_path: str,
1932
) -> None:
2033
dataset: Iterator[RepoInstance] = load_dataset(dataset_name, split=dataset_split) # type: ignore
2134
specs = []
@@ -26,8 +39,29 @@ def main(
2639
spec = make_spec(example)
2740
specs.append(spec)
2841

29-
client = docker.from_env()
30-
build_repo_images(client, specs, num_workers)
42+
if ExecutionBackend(backend) == ExecutionBackend.MODAL:
43+
execution_context = Modal
44+
elif ExecutionBackend(backend) == ExecutionBackend.LOCAL:
45+
client = docker.from_env()
46+
build_repo_images(client, specs, num_workers)
47+
execution_context = Docker
48+
else:
49+
raise ValueError(
50+
f"Evaluation must be from {', '.join(EVAL_BACKENDS)}, but {backend} is provided."
51+
)
52+
53+
# get ssh key from each docker image
54+
img2key = dict()
55+
for spec in tqdm(specs, desc="Retrieving public keys from docker images"):
56+
try:
57+
with execution_context(spec, logger, timeout=60) as context:
58+
key = context.get_ssh_pubkey_from_remote(user="root")
59+
img2key[spec.repo_image_key] = key
60+
except Exception as e:
61+
error_msg = f"General error: {e}\n" f"{traceback.format_exc()}\n"
62+
raise RuntimeError(error_msg)
63+
with open(key_path, "w") as json_file:
64+
json.dump(img2key, json_file, indent=4)
3165

3266

3367
__all__ = []

commit0/harness/constants.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ class RepoInstance(TypedDict):
1111
test: Dict[str, str]
1212

1313

14+
class Files(TypedDict):
15+
eval_script: Dict[str, Path]
16+
17+
1418
# Constants - Evaluation Log Directories
1519
BASE_IMAGE_BUILD_DIR = Path("logs/build_images/base")
1620
REPO_IMAGE_BUILD_DIR = Path("logs/build_images/repo")
@@ -27,8 +31,9 @@ class RepoInstance(TypedDict):
2731

2832
# available commands
2933
COMMANDS = [
30-
"setup",
34+
"clone",
3135
"build",
36+
"setup-git-user",
3237
"test",
3338
"test-reference",
3439
"get-tests",

commit0/harness/docker_utils.py

Lines changed: 7 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import threading
99
import time
1010
import traceback
11-
import pwd
1211
from pathlib import Path
1312
from io import BytesIO
1413
from typing import Optional, List, Union
@@ -141,50 +140,31 @@ def delete_file_from_container(container: Container, file_path: str) -> None:
141140
raise Exception(f"General Error: {str(e)}")
142141

143142

144-
def copy_ssh_pubkey_from_container(container: Container) -> None:
143+
def get_ssh_pubkey_from_container(container: Container, user: str) -> str:
145144
"""Copy the SSH public key from a Docker container to the local authorized_keys file.
146145
147146
Args:
148147
----
149148
container (Container): Docker container to copy the key from.
149+
user (str): to get public key of which user
150+
151+
Returns:
152+
-------
153+
public_key (str): public key from docker container
150154
151155
Raises:
152156
------
153157
docker.errors.APIError: If there is an error calling the Docker API.
154-
Exception: If the file reading or writing process fails.
155158
156159
"""
157160
try:
158161
exit_code, output = container.exec_run("cat /root/.ssh/id_rsa.pub")
159162
if exit_code != 0:
160163
raise Exception(f"Error reading file: {output.decode('utf-8').strip()}")
161164
public_key = output.decode("utf-8").strip()
162-
public_key = f"no-port-forwarding,no-X11-forwarding,no-agent-forwarding,no-pty {public_key}"
163-
164-
user_info = pwd.getpwnam("git")
165-
home_directory = user_info.pw_dir
166-
authorized_keys_path = os.path.join(home_directory, ".ssh", "authorized_keys")
167-
os.makedirs(os.path.dirname(authorized_keys_path), exist_ok=True)
168-
if not os.path.exists(authorized_keys_path):
169-
# Since the file does not exist, create it
170-
open(authorized_keys_path, "a").close()
171-
write = True
172-
else:
173-
with open(authorized_keys_path, "r") as authorized_keys_file:
174-
content = authorized_keys_file.read()
175-
if public_key not in content:
176-
write = True
177-
else:
178-
write = False
179-
180-
if write:
181-
with open(authorized_keys_path, "a") as authorized_keys_file:
182-
authorized_keys_file.write(public_key + "\n")
183-
165+
return public_key
184166
except docker.errors.APIError as e:
185167
raise docker.errors.APIError(f"Docker API Error: {str(e)}")
186-
except Exception as e:
187-
raise Exception(f"General Error: {str(e)}")
188168

189169

190170
def write_to_container(container: Container, data: str, dst: Path) -> None:

0 commit comments

Comments
 (0)