Skip to content

Commit c2fafe1

Browse files
committed
pre-commit
1 parent aca446c commit c2fafe1

File tree

7 files changed

+68
-30
lines changed

7 files changed

+68
-30
lines changed

commit0/__main__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,14 @@ def main() -> None:
3131
# after hydra gets all configs, put command-line arguments back
3232
sys.argv = sys_argv
3333
# repo_split: split from command line has a higher priority than split in hydra
34-
if command in ["clone", "build", "setup-git-user", "evaluate", "evaluate-reference", "save"]:
34+
if command in [
35+
"clone",
36+
"build",
37+
"setup-git-user",
38+
"evaluate",
39+
"evaluate-reference",
40+
"save",
41+
]:
3542
if len(sys.argv) >= 3:
3643
if sys.argv[2] not in SPLIT:
3744
raise ValueError(

commit0/harness/build.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from tqdm import tqdm
88
from typing import Iterator
99

10-
from commit0.harness.constants import RepoInstance, SPLIT
10+
from commit0.harness.constants import EVAL_BACKENDS, RepoInstance, SPLIT
1111
from commit0.harness.docker_build import build_repo_images
1212
from commit0.harness.execution_context import (
1313
ExecutionBackend,
@@ -23,7 +23,12 @@
2323

2424

2525
def main(
26-
dataset_name: str, dataset_split: str, repo_split: str, num_workers: int, backend: str, key_path: str
26+
dataset_name: str,
27+
dataset_split: str,
28+
repo_split: str,
29+
num_workers: int,
30+
backend: str,
31+
key_path: str,
2732
) -> None:
2833
dataset: Iterator[RepoInstance] = load_dataset(dataset_name, split=dataset_split) # type: ignore
2934
specs = []
@@ -40,6 +45,10 @@ def main(
4045
client = docker.from_env()
4146
build_repo_images(client, specs, num_workers)
4247
execution_context = Docker
48+
else:
49+
raise ValueError(
50+
f"Evaluation must be from {', '.join(EVAL_BACKENDS)}, but {backend} is provided."
51+
)
4352

4453
# get ssh key from each docker image
4554
img2key = dict()
@@ -49,12 +58,9 @@ def main(
4958
key = context.get_ssh_pubkey_from_remote(user="root")
5059
img2key[spec.repo_image_key] = key
5160
except Exception as e:
52-
error_msg = (
53-
f"General error: {e}\n"
54-
f"{traceback.format_exc()}\n"
55-
)
61+
error_msg = f"General error: {e}\n" f"{traceback.format_exc()}\n"
5662
raise RuntimeError(error_msg)
57-
with open(key_path, 'w') as json_file:
63+
with open(key_path, "w") as json_file:
5864
json.dump(img2key, json_file, indent=4)
5965

6066

commit0/harness/docker_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import threading
99
import time
1010
import traceback
11-
import pwd
1211
from pathlib import Path
1312
from io import BytesIO
1413
from typing import Optional, List, Union

commit0/harness/execution_context.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from abc import ABC, abstractmethod
88
import docker
99
import logging
10-
import os
1110
import modal
1211
import modal.io_streams
1312
from enum import StrEnum, auto
@@ -90,7 +89,9 @@ def delete_file_from_remote(self, remote_path: Path) -> None:
9089
"""Delete"""
9190
raise NotImplementedError
9291

93-
def write_test_output(self, log_dir: Path, test_output: str, timed_out: bool) -> None:
92+
def write_test_output(
93+
self, log_dir: Path, test_output: str, timed_out: bool
94+
) -> None:
9495
"""Write test output"""
9596
test_output_path = log_dir / "test_output.txt"
9697
with open(test_output_path, "w") as f:
@@ -145,9 +146,9 @@ def __init__(
145146
self.container.start()
146147
if files_to_copy:
147148
for _, f in files_to_copy.items():
148-
copy_to_container(self.container, f['src'], Path(f['dest']))
149+
copy_to_container(self.container, f["src"], f["dest"]) # type: ignore
149150

150-
def get_ssh_pubkey_from_remote(self, user: str) -> None:
151+
def get_ssh_pubkey_from_remote(self, user: str) -> str:
151152
"""Copy"""
152153
return get_ssh_pubkey_from_container(self.container, user)
153154

@@ -195,7 +196,7 @@ def __init__(
195196
image = modal.Image.from_registry(image_name)
196197
if files_to_copy:
197198
for _, f in files_to_copy.items():
198-
image = image.copy_local_file(f['src'], f['dest'])
199+
image = image.copy_local_file(f["src"], f["dest"]) # type: ignore
199200

200201
self.sandbox = modal.Sandbox.create(
201202
"sleep",
@@ -205,7 +206,7 @@ def __init__(
205206
timeout=timeout,
206207
)
207208

208-
def get_ssh_pubkey_from_remote(self, user: str) -> None:
209+
def get_ssh_pubkey_from_remote(self, user: str) -> str:
209210
"""Copy ssh pubkey"""
210211
process = self.sandbox.exec("bash", "-c", f"cat /{user}/.ssh/id_rsa.pub")
211212
public_key = "".join([line for line in process.stdout]).strip()

commit0/harness/run_pytest_ids.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,12 @@
44

55
from typing import Iterator
66
from git import Repo
7-
from commit0.harness.constants import Files, RUN_PYTEST_LOG_DIR, RepoInstance
7+
from commit0.harness.constants import (
8+
EVAL_BACKENDS,
9+
Files,
10+
RUN_PYTEST_LOG_DIR,
11+
RepoInstance,
12+
)
813
from commit0.harness.docker_build import (
914
setup_logger,
1015
)
@@ -76,6 +81,10 @@ def main(
7681
execution_context = Modal
7782
elif ExecutionBackend(backend) == ExecutionBackend.LOCAL:
7883
execution_context = Docker
84+
else:
85+
raise ValueError(
86+
f"Evaluation must be from {', '.join(EVAL_BACKENDS)}, but {backend} is provided."
87+
)
7988

8089
files_to_copy = Files(eval_script={"src": eval_file, "dest": Path("/eval.sh")})
8190

commit0/harness/setup_git_user.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import json
22
import logging
33
import os
4-
import traceback
54

65
from datasets import load_dataset
76

@@ -22,11 +21,18 @@
2221
logger = logging.getLogger(__name__)
2322

2423

25-
def main(dataset_name: str, dataset_split: str, repo_split: str, base_dir: str, git_user: str, key_path: str) -> None:
24+
def main(
25+
dataset_name: str,
26+
dataset_split: str,
27+
repo_split: str,
28+
base_dir: str,
29+
git_user: str,
30+
key_path: str,
31+
) -> None:
2632
dataset: Iterator[RepoInstance] = load_dataset(dataset_name, split=dataset_split) # type: ignore
2733
setup_user(git_user, logger)
2834
setup_ssh_directory(git_user, logger)
29-
with open(key_path, 'r') as f:
35+
with open(key_path, "r") as f:
3036
public_keys = json.load(f)
3137

3238
for example in dataset:

commit0/harness/utils.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -114,36 +114,42 @@ def setup_user(user: str, logger: logging.Logger) -> None:
114114

115115

116116
def chmod(path: str, mode: int, logger: logging.Logger) -> None:
117-
"""
118-
A Python wrapper for the chmod command to change file or directory permissions.
117+
"""A Python wrapper for the chmod command to change file or directory permissions.
119118
120119
Args:
120+
----
121121
path (str): The path to the file or directory.
122122
mode (int): The permission mode (octal), e.g., 0o755, 0o644, etc.
123123
logger (logging.Logger): The logger object.
124+
124125
"""
125126
try:
126127
os.chmod(path, mode)
127128
logger.info(f"Permissions for '{path}' changed to {oct(mode)}")
128129
except FileNotFoundError:
129-
raise FileNotFoundError(f"Error: The file or directory '{path}' does not exist.")
130+
raise FileNotFoundError(
131+
f"Error: The file or directory '{path}' does not exist."
132+
)
130133
except PermissionError:
131-
raise PermissionError(f"Error: Permission denied when changing permissions for '{path}'")
134+
raise PermissionError(
135+
f"Error: Permission denied when changing permissions for '{path}'"
136+
)
132137
except Exception as e:
133138
raise Exception(f"An error occurred: {e}")
134139

135140

136141
def setup_ssh_directory(user: str, logger: logging.Logger) -> None:
137-
"""
138-
Sets up the .ssh directory for the user and sets appropriate permissions.
142+
"""Sets up the .ssh directory for the user and sets appropriate permissions.
139143
140144
Args:
145+
----
141146
user (str): The name of the user.
142147
logger (logging.Logger): The logger object.
148+
143149
"""
144150
home = get_home_directory(user)
145-
ssh_dir = os.path.join(home, '.ssh')
146-
authorized_keys_file = os.path.join(ssh_dir, 'authorized_keys')
151+
ssh_dir = os.path.join(home, ".ssh")
152+
authorized_keys_file = os.path.join(ssh_dir, "authorized_keys")
147153

148154
try:
149155
# Create the .ssh directory if it doesn't exist
@@ -156,19 +162,23 @@ def setup_ssh_directory(user: str, logger: logging.Logger) -> None:
156162

157163
# Create the authorized_keys file if it doesn't exist
158164
if not os.path.exists(authorized_keys_file):
159-
open(authorized_keys_file, 'a').close()
165+
open(authorized_keys_file, "a").close()
160166
logger.info(f"Created file: {authorized_keys_file}")
161167
except Exception as e:
162168
raise e
163169

164170

165171
def add_key(user: str, public_key: str) -> None:
166-
public_key = f"no-port-forwarding,no-X11-forwarding,no-agent-forwarding,no-pty {public_key}"
172+
public_key = (
173+
f"no-port-forwarding,no-X11-forwarding,no-agent-forwarding,no-pty {public_key}"
174+
)
167175

168176
home_directory = get_home_directory(user)
169177
authorized_keys_path = os.path.join(home_directory, ".ssh", "authorized_keys")
170178
if not os.path.exists(authorized_keys_path):
171-
raise FileNotFoundError(f"f{authorized_keys_path} does not exists, please call setup_ssh_directory() before adding keys")
179+
raise FileNotFoundError(
180+
f"f{authorized_keys_path} does not exists, please call setup_ssh_directory() before adding keys"
181+
)
172182
else:
173183
with open(authorized_keys_path, "r") as authorized_keys_file:
174184
content = authorized_keys_file.read()

0 commit comments

Comments
 (0)