Skip to content

Commit 73b1de0

Browse files
authored
Merge pull request #21 from commit-0/evaluate
Evaluate
2 parents 22633fc + 6e917bd commit 73b1de0

File tree

6 files changed

+187
-15
lines changed

6 files changed

+187
-15
lines changed

.github/workflows/system.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,5 @@ jobs:
2626
run: uv run commit0 get-tests simpy
2727
- name: Test
2828
run: uv run commit0 test-reference simpy tests/test_event.py::test_succeed
29+
- name: Evaluate
30+
run: uv run commit0 evaluate-reference simpy

commit0/__main__.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import commit0.harness.get_pytest_ids
33
import commit0.harness.build
44
import commit0.harness.setup
5+
import commit0.harness.evaluate
56
import copy
67
import sys
78
import os
@@ -28,7 +29,7 @@ def main() -> None:
2829
# after hydra gets all configs, put command-line arguments back
2930
sys.argv = sys_argv
3031
# repo_split: split from command line has a higher priority than split in hydra
31-
if command in ["clone", "build"]:
32+
if command in ["clone", "build", "evaluate", "evaluate-reference"]:
3233
if len(sys.argv) == 3:
3334
if sys.argv[2] not in SPLIT:
3435
raise ValueError(
@@ -53,7 +54,7 @@ def main() -> None:
5354
)
5455
elif command == "get-tests":
5556
repo = sys.argv[2]
56-
commit0.harness.get_pytest_ids.main(repo)
57+
commit0.harness.get_pytest_ids.main(repo, stdout=True)
5758
elif command == "test" or command == "test-reference":
5859
repo = sys.argv[2]
5960
test_ids = sys.argv[3]
@@ -68,6 +69,20 @@ def main() -> None:
6869
test_ids,
6970
config.backend,
7071
config.timeout,
72+
stdout=True,
73+
)
74+
elif command == "evaluate" or command == "evaluate-reference":
75+
if command == "evaluate-reference":
76+
config.branch = "reference"
77+
commit0.harness.evaluate.main(
78+
config.dataset_name,
79+
config.dataset_split,
80+
config.repo_split,
81+
config.base_dir,
82+
config.branch,
83+
config.backend,
84+
config.timeout,
85+
config.num_workers,
7186
)
7287

7388

commit0/harness/constants.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,15 @@ class RepoInstance(TypedDict):
2626
EVAL_BACKENDS = ["local", "modal"]
2727

2828
# available commands
29-
COMMANDS = ["clone", "build", "test", "test-reference", "get-tests"]
29+
COMMANDS = [
30+
"clone",
31+
"build",
32+
"test",
33+
"test-reference",
34+
"get-tests",
35+
"evaluate",
36+
"evaluate-reference",
37+
]
3038
# repo splits
3139
SPLIT_MINITORCH = ["minitorch"]
3240
SPLIT_SIMPY = ["simpy"]
@@ -80,7 +88,8 @@ class RepoInstance(TypedDict):
8088
"mimesis",
8189
"babel",
8290
"dnspython",
83-
"portalocker," "cookiecutter",
91+
"portalocker",
92+
"cookiecutter",
8493
"pyjwt",
8594
"python-rsa",
8695
"more-itertools",

commit0/harness/evaluate.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
import logging
2+
import os
3+
import traceback
4+
from collections import Counter
5+
6+
from concurrent.futures import ThreadPoolExecutor, as_completed
7+
from datasets import load_dataset
8+
from tqdm import tqdm
9+
from typing import Iterator
10+
11+
from commit0.harness.run_pytest_ids import main as run_tests
12+
from commit0.harness.get_pytest_ids import main as get_tests
13+
from commit0.harness.constants import RepoInstance, SPLIT
14+
15+
logging.basicConfig(
16+
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
17+
)
18+
logger = logging.getLogger(__name__)
19+
20+
21+
def _summarize_log_dir(log_dir: str) -> dict:
    """Aggregate one repo's pytest JSON report into a stats record.

    Returns a dict with keys: "name" (repo name), "sum" (total test
    runtime), "passed" (pass rate in [0, 1]), "num_passed", "num_tests".
    """
    report_file = os.path.join(log_dir, "report.json")
    # log_dir looks like <root>/<pytest-dir>/<repo>/<branch>/<hash>; the
    # third path component is the repo name. NOTE(review): confirm this
    # stays in sync with RUN_PYTEST_LOG_DIR's layout.
    name = log_dir.split("/")[2]
    if not os.path.exists(report_file):
        # The run produced no report (e.g. container/test harness failure):
        # score the repo as zero. "num_tests" must be present — the caller's
        # summary line reads it for every record.
        return {"name": name, "sum": 0, "passed": 0, "num_passed": 0, "num_tests": 0}
    report = load_dataset("json", data_files=report_file, split="train")  # type: ignore
    test_ids = get_tests(name, stdout=False)
    tests = {x["nodeid"]: x["call"] for x in report["tests"][0]}  # type: ignore
    outcomes = []
    runtimes = []
    for test_id in test_ids:
        call = tests.get(test_id)
        if call is not None:
            outcomes.append(call["outcome"])
            runtimes.append(call["duration"])
        else:
            # Test never ran (collection error, crash, missing): count as failed.
            outcomes.append("failed")
            runtimes.append(0)
    status = Counter(outcomes)
    num_tests = sum(status.values())
    # xfail is treated as passing, matching pytest's "expected failure".
    num_passed = status["passed"] + status["xfail"]
    return {
        "name": name,
        "sum": sum(runtimes),
        # Guard: a repo with zero test ids must not raise ZeroDivisionError.
        "passed": num_passed / num_tests if num_tests else 0,
        "num_passed": num_passed,
        "num_tests": num_tests,
    }


def main(
    dataset_name: str,
    dataset_split: str,
    repo_split: str,
    base_dir: str,
    branch: str,
    backend: str,
    timeout: int,
    num_workers: int,
) -> None:
    """Run the test suite of every repo in *repo_split* concurrently and
    print a per-repo runtime / pass-rate summary plus overall totals.

    Args:
        dataset_name: HuggingFace dataset holding the repo instances.
        dataset_split: split of that dataset to load.
        repo_split: key into SPLIT selecting which repos to evaluate
            ("all" evaluates every repo in the dataset).
        base_dir: directory containing the cloned repos.
        branch: git branch to evaluate (e.g. "reference").
        backend: execution backend ("local" or "modal").
        timeout: per-repo test timeout in seconds.
        num_workers: number of repos evaluated concurrently.
    """
    dataset: Iterator[RepoInstance] = load_dataset(dataset_name, split=dataset_split)  # type: ignore
    pairs = []
    for example in dataset:
        repo_name = example["repo"].split("/")[-1]
        if repo_split != "all" and repo_name not in SPLIT[repo_split]:
            continue
        pairs.append((repo_name, example["test"]["test_dir"]))

    log_dirs = []
    # total=len(pairs): one future per (repo, test_dir) pair actually
    # submitted, so the bar can reach 100% even when the split filters repos.
    with tqdm(total=len(pairs), smoothing=0, desc="Evaluating repos") as pbar:
        with ThreadPoolExecutor(max_workers=num_workers) as executor:
            # One future per repo; the mapped value (repo name) aids debugging.
            futures = {
                executor.submit(
                    run_tests,
                    dataset_name,
                    dataset_split,
                    base_dir,
                    repo,
                    branch,
                    test_dir,
                    backend,
                    timeout,
                    stdout=False,
                ): repo
                for repo, test_dir in pairs
            }
            for future in as_completed(futures):
                pbar.update(1)
                try:
                    # run_tests returns the log directory for that repo.
                    log_dirs.append(future.result())
                except Exception:
                    # One failed repo must not abort the whole evaluation.
                    traceback.print_exc()

    # Collect per-repo numbers and print a CSV-style summary, slowest first.
    out = [_summarize_log_dir(log_dir) for log_dir in tqdm(log_dirs)]
    print("repo,runtime,num_passed/num_tests")
    out = sorted(out, key=lambda x: x["sum"], reverse=True)
    for x in out:
        print(f"{x['name']},{x['sum']},{x['num_passed']}/{x['num_tests']}")
    total_runtime = sum(x["sum"] for x in out)
    # Guard: an empty result set must not raise ZeroDivisionError.
    averaged_passed = sum(x["passed"] for x in out) / len(out) if out else 0.0
    print(f"total runtime: {total_runtime}")
    print(f"average pass rate: {averaged_passed}")
124+
125+
126+
__all__ = []

commit0/harness/get_pytest_ids.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
11
import tarfile
2+
from typing import List
23

34

4-
def main(repo: str) -> None:
5+
def main(repo: str, stdout: bool) -> List[str]:
56
repo = repo.lower()
67
repo = repo.replace(".", "-")
8+
out = ""
79
with tarfile.open(f"commit0/data/test_ids/{repo}.tar.bz2", "r:bz2") as tar:
810
for member in tar.getmembers():
911
if member.isfile():
1012
file = tar.extractfile(member)
1113
if file:
12-
content = file.read()
13-
print(content.decode("utf-8"))
14+
content = file.read().decode("utf-8")
15+
out += content
16+
if stdout:
17+
print(content)
18+
out = out.split("\n")
19+
return out
1420

1521

1622
__all__ = []

commit0/harness/run_pytest_ids.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,12 @@ class ExecutionBackend(StrEnum):
3838

3939

4040
def run_docker(
41-
spec: Spec, logger: logging.Logger, eval_file: Path, timeout: int, log_dir: Path
41+
spec: Spec,
42+
logger: logging.Logger,
43+
eval_file: Path,
44+
timeout: int,
45+
log_dir: Path,
46+
stdout: bool,
4247
) -> None:
4348
client = docker.from_env()
4449
container = None
@@ -65,7 +70,8 @@ def run_docker(
6570
output, "--json-report --json-report-file=report.json"
6671
)
6772
# stdout might be more straightforward
68-
print(test_output)
73+
if stdout:
74+
print(test_output)
6975
test_output_path = log_dir / "test_output.txt"
7076
with open(test_output_path, "w") as f:
7177
f.write(test_output)
@@ -105,7 +111,12 @@ def run_docker(
105111

106112

107113
def run_modal(
108-
spec: Spec, logger: logging.Logger, eval_file: Path, timeout: int, log_dir: Path
114+
spec: Spec,
115+
logger: logging.Logger,
116+
eval_file: Path,
117+
timeout: int,
118+
log_dir: Path,
119+
stdout: bool,
109120
) -> None:
110121
# get image name to pull from dockerhub
111122
# spec.repo_image_key
@@ -182,7 +193,8 @@ def run_modal(
182193
)
183194

184195
# stdout might be more straightforward
185-
print(test_output)
196+
if stdout:
197+
print(test_output)
186198
test_output_path = log_dir / "test_output.txt"
187199
with open(test_output_path, "w") as f:
188200
f.write(test_output)
@@ -204,7 +216,8 @@ def main(
204216
test_ids: str,
205217
backend: str,
206218
timeout: int,
207-
) -> None:
219+
stdout: bool,
220+
) -> str:
208221
dataset: Iterator[RepoInstance] = load_dataset(dataset_name, split=dataset_split) # type: ignore
209222
spec = None
210223
example = None
@@ -217,7 +230,7 @@ def main(
217230

218231
hashed_test_ids = get_hash_string(test_ids)
219232
# set up logging
220-
log_dir = RUN_PYTEST_LOG_DIR / repo / hashed_test_ids
233+
log_dir = RUN_PYTEST_LOG_DIR / repo / branch / hashed_test_ids
221234
log_dir.mkdir(parents=True, exist_ok=True)
222235
log_file = log_dir / "run_pytest.log"
223236
logger = setup_logger(repo, log_file)
@@ -241,9 +254,10 @@ def main(
241254
eval_file.write_text(eval_script)
242255

243256
if ExecutionBackend(backend) == ExecutionBackend.LOCAL:
244-
run_docker(spec, logger, eval_file, timeout, log_dir)
257+
run_docker(spec, logger, eval_file, timeout, log_dir, stdout)
245258
elif ExecutionBackend(backend) == ExecutionBackend.MODAL:
246-
run_modal(spec, logger, eval_file, timeout, log_dir)
259+
run_modal(spec, logger, eval_file, timeout, log_dir, stdout)
260+
return str(log_dir)
247261

248262

249263
__all__ = []

0 commit comments

Comments
 (0)