Skip to content

Commit 6e917bd

Browse files
committed
lint and ci
1 parent ec46653 commit 6e917bd

File tree

4 files changed

+33
-15
lines changed

4 files changed

+33
-15
lines changed

.github/workflows/system.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,5 @@ jobs:
2626
run: uv run commit0 get-tests simpy
2727
- name: Test
2828
run: uv run commit0 test-reference simpy tests/test_event.py::test_succeed
29+
- name: Evaluate
30+
run: uv run commit0 evaluate-reference simpy

commit0/harness/constants.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,15 @@ class RepoInstance(TypedDict):
2626
EVAL_BACKENDS = ["local", "modal"]
2727

2828
# available commands
29-
COMMANDS = ["clone", "build", "test", "test-reference", "get-tests", "evaluate", "evaluate-reference"]
29+
COMMANDS = [
30+
"clone",
31+
"build",
32+
"test",
33+
"test-reference",
34+
"get-tests",
35+
"evaluate",
36+
"evaluate-reference",
37+
]
3038
# repo splits
3139
SPLIT_MINITORCH = ["minitorch"]
3240
SPLIT_SIMPY = ["simpy"]

commit0/harness/evaluate.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import logging
22
import os
3+
import traceback
34
from collections import Counter
45

5-
import docker
66
from concurrent.futures import ThreadPoolExecutor, as_completed
77
from datasets import load_dataset
88
from tqdm import tqdm
@@ -18,7 +18,16 @@
1818
logger = logging.getLogger(__name__)
1919

2020

21-
def main(dataset_name: str, dataset_split: str, repo_split: str, base_dir: str, branch: str, backend: str, timeout: int, num_workers: int) -> None:
21+
def main(
22+
dataset_name: str,
23+
dataset_split: str,
24+
repo_split: str,
25+
base_dir: str,
26+
branch: str,
27+
backend: str,
28+
timeout: int,
29+
num_workers: int,
30+
) -> None:
2231
dataset: Iterator[RepoInstance] = load_dataset(dataset_name, split=dataset_split) # type: ignore
2332
repos = SPLIT[repo_split]
2433
pairs = []
@@ -54,15 +63,15 @@ def main(dataset_name: str, dataset_split: str, repo_split: str, base_dir: str,
5463
# Update progress bar, check if instance ran successfully
5564
result = future.result()
5665
log_dirs.append(result)
57-
except Exception as e:
66+
except Exception:
5867
traceback.print_exc()
5968
continue
6069

6170
# get numbers
6271
out = []
6372
for name in tqdm(log_dirs):
6473
report_file = os.path.join(name, "report.json")
65-
name = name.split('/')[2]
74+
name = name.split("/")[2]
6675
if not os.path.exists(report_file):
6776
out.append(
6877
{
@@ -73,9 +82,9 @@ def main(dataset_name: str, dataset_split: str, repo_split: str, base_dir: str,
7382
}
7483
)
7584
continue
76-
dataset: Iterator[RepoInstance] = load_dataset("json", data_files=report_file, split="train")
85+
report = load_dataset("json", data_files=report_file, split="train") # type: ignore
7786
test_ids = get_tests(name, stdout=False)
78-
tests = {x['nodeid']: x['call'] for x in dataset["tests"][0]}
87+
tests = {x["nodeid"]: x["call"] for x in report["tests"][0]} # type: ignore
7988
status = []
8089
runtimes = []
8190
no_runs = 0
@@ -100,18 +109,16 @@ def main(dataset_name: str, dataset_split: str, repo_split: str, base_dir: str,
100109
"name": name,
101110
"sum": total,
102111
"passed": passed,
103-
"num_passed": status["passed"]+status["xfail"],
104-
"num_tests": sum(status.values())
112+
"num_passed": status["passed"] + status["xfail"],
113+
"num_tests": sum(status.values()),
105114
}
106115
)
107116
print("repo,runtime,num_passed/num_tests")
108117
out = sorted(out, key=lambda x: x["sum"], reverse=True)
109118
for x in out:
110-
print(
111-
f"{x['name']},{x['sum']},{x['num_passed']}/{x['num_tests']}"
112-
)
119+
print(f"{x['name']},{x['sum']},{x['num_passed']}/{x['num_tests']}")
113120
total_runtime = sum([x["sum"] for x in out])
114-
averaged_passed = sum([x["passed"] for x in out])/len(out)
121+
averaged_passed = sum([x["passed"] for x in out]) / len(out)
115122
print(f"total runtime: {total_runtime}")
116123
print(f"average pass rate: {averaged_passed}")
117124

commit0/harness/get_pytest_ids.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import tarfile
2+
from typing import List
23

34

4-
def main(repo: str, stdout: bool) -> None:
5+
def main(repo: str, stdout: bool) -> List[str]:
56
repo = repo.lower()
67
repo = repo.replace(".", "-")
78
out = ""
@@ -14,7 +15,7 @@ def main(repo: str, stdout: bool) -> None:
1415
out += content
1516
if stdout:
1617
print(content)
17-
out = out.split('\n')
18+
out = out.split("\n")
1819
return out
1920

2021

0 commit comments

Comments
 (0)