Skip to content

Commit d96f5c0

Browse files
authored
Merge pull request #24 from commit-0/save
Save
2 parents 8160bf1 + 587d8c1 commit d96f5c0

File tree

7 files changed

+140
-2
lines changed

7 files changed

+140
-2
lines changed

.github/workflows/system.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,8 @@ jobs:
2828
run: uv run commit0 test-reference simpy tests/test_event.py::test_succeed
2929
- name: Evaluate
3030
run: uv run commit0 evaluate-reference simpy
31+
- name: Save
32+
env:
33+
GITHUB_TOKEN: ${{ secrets.MY_GITHUB_TOKEN }}
34+
run: |
35+
uv run commit0 save simpy test-save-commit0

commit0/__main__.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import commit0.harness.build
44
import commit0.harness.setup
55
import commit0.harness.evaluate
6+
import commit0.harness.save
67
import copy
78
import sys
89
import os
@@ -29,8 +30,8 @@ def main() -> None:
2930
# after hydra gets all configs, put command-line arguments back
3031
sys.argv = sys_argv
3132
# repo_split: split from command line has a higher priority than split in hydra
32-
if command in ["clone", "build", "evaluate", "evaluate-reference"]:
33-
if len(sys.argv) == 3:
33+
if command in ["clone", "build", "evaluate", "evaluate-reference", "save"]:
34+
if len(sys.argv) >= 3:
3435
if sys.argv[2] not in SPLIT:
3536
raise ValueError(
3637
f"repo split must be from {', '.join(SPLIT.keys())}, but you provided {sys.argv[2]}"
@@ -85,6 +86,17 @@ def main() -> None:
8586
config.timeout,
8687
config.num_workers,
8788
)
89+
elif command == "save":
90+
organization = sys.argv[3]
91+
commit0.harness.save.main(
92+
config.dataset_name,
93+
config.dataset_split,
94+
config.repo_split,
95+
config.base_dir,
96+
organization,
97+
config.branch,
98+
config.github_token,
99+
)
88100

89101

90102
if __name__ == "__main__":

commit0/configs/base.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,6 @@ num_workers: 8
1616
backend: local
1717
branch: ai
1818
timeout: 1_800
19+
20+
# save related
21+
github_token: null

commit0/configs/config_class.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from dataclasses import dataclass
2+
from typing import Optional
23

34

45
@dataclass
@@ -21,3 +22,6 @@ class Commit0Config:
2122
branch: str
2223
# timeout for running pytest
2324
timeout: int
25+
26+
# save related
27+
github_token: Optional[str]

commit0/harness/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class RepoInstance(TypedDict):
3434
"get-tests",
3535
"evaluate",
3636
"evaluate-reference",
37+
"save",
3738
]
3839
# repo splits
3940
SPLIT_MINITORCH = ["minitorch"]

commit0/harness/save.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import logging
2+
import os
3+
4+
import git
5+
6+
from datasets import load_dataset
7+
from typing import Iterator
8+
from commit0.harness.constants import RepoInstance, SPLIT
9+
from commit0.harness.utils import create_repo_on_github
10+
11+
12+
logging.basicConfig(
13+
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
14+
)
15+
logger = logging.getLogger(__name__)
16+
17+
18+
def main(
    dataset_name: str,
    dataset_split: str,
    repo_split: str,
    base_dir: str,
    organization: str,
    branch: str,
    github_token: str,
) -> None:
    """Commit pending work on ``branch`` and push it to ``organization`` on GitHub.

    For every dataset entry selected by ``repo_split``, the local checkout at
    ``{base_dir}/{repo_name}`` is pushed to
    ``https://github.com/{organization}/{repo_name}.git`` through a remote
    named ``progress-tracker``. The GitHub repository is created first when
    it does not exist yet.

    Args:
        dataset_name: Name passed to ``load_dataset``.
        dataset_split: Dataset split passed to ``load_dataset``.
        repo_split: Key into ``SPLIT`` selecting which repos to save, or
            ``"all"`` to save every repo in the dataset.
        base_dir: Directory containing one local checkout per repo.
        organization: GitHub organization that receives the pushes.
        branch: Local branch to check out, commit on, and push.
        github_token: GitHub token; when ``None`` the ``GITHUB_TOKEN``
            environment variable is used instead.

    Raises:
        ValueError: If no token is available, or ``branch`` does not exist
            in a local repo.
        OSError: If a selected repo has no local checkout under ``base_dir``.
        Exception: If the push fails or is rejected by the remote.
    """
    if github_token is None:
        # Get GitHub token from environment variable if not provided
        github_token = os.environ.get("GITHUB_TOKEN")
    if not github_token:
        # Without a token the remote URL would embed the literal "None" and
        # every push would fail with an opaque authentication error.
        raise ValueError(
            "A GitHub token is required to save: set github_token in the config "
            "or export the GITHUB_TOKEN environment variable."
        )

    dataset: Iterator[RepoInstance] = load_dataset(dataset_name, split=dataset_split)  # type: ignore
    for example in dataset:
        repo_name = example["repo"].split("/")[-1]
        if repo_split != "all" and repo_name not in SPLIT[repo_split]:
            continue
        local_repo_path = f"{base_dir}/{repo_name}"
        # Token-free URL: the only form that may appear in log output.
        public_repo_url = f"https://github.com/{organization}/{repo_name}.git"
        github_repo_url = public_repo_url.replace(
            "https://", f"https://x-access-token:{github_token}@"
        )

        # The local checkout must already exist; it is never created here.
        if not os.path.exists(local_repo_path):
            raise OSError(f"{local_repo_path} does not exists")
        repo = git.Repo(local_repo_path)

        # Create the GitHub repo if it is missing
        create_repo_on_github(
            organization=organization, repo=repo_name, logger=logger, token=github_token
        )

        # Point the dedicated remote at the authenticated URL
        remote_name = "progress-tracker"
        if remote_name not in [remote.name for remote in repo.remotes]:
            repo.create_remote(remote_name, url=github_repo_url)
        else:
            logger.info(
                f"Remote {remote_name} already exists, replacing it with {public_repo_url}"
            )
            repo.remote(name=remote_name).set_url(github_repo_url)

        # Only an existing branch can be saved
        if branch in repo.heads:
            repo.git.checkout(branch)
        else:
            raise ValueError(f"The branch {branch} you want save does not exist.")

        # Commit only when there is something to commit (modified, staged,
        # or untracked files); a clean tree is pushed as-is.
        if repo.is_dirty(untracked_files=True):
            repo.git.add(A=True)
            repo.index.commit("AI generated code.")

        # Push to the GitHub repository
        tracker = repo.remote(name=remote_name)
        try:
            push_results = tracker.push(refspec=f"{branch}:{branch}")
        except Exception as e:
            raise Exception(
                f"Push {branch} to {organization}/{repo_name} fails.\n{str(e)}"
            ) from e
        # A rejected ref does not raise; it is reported through the ERROR
        # flag on the returned PushInfo entries.
        rejected = [
            str(info.summary).strip()
            for info in push_results
            if info.flags & git.PushInfo.ERROR
        ]
        if rejected:
            details = "\n".join(rejected)
            raise Exception(
                f"Push {branch} to {organization}/{repo_name} fails.\n{details}"
            )
        logger.info(f"Pushed to {public_repo_url} on branch {branch}")
81+
82+
83+
__all__ = []

commit0/harness/utils.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,12 @@
55
import logging
66
import socket
77
import os
8+
import time
89
import requests
10+
from typing import Optional
11+
12+
from fastcore.net import HTTP404NotFoundError, HTTP403ForbiddenError # type: ignore
13+
from ghapi.core import GhApi
914
from commit0.harness.constants import EVAL_BACKENDS
1015

1116

@@ -170,4 +175,29 @@ def create_branch(repo: git.Repo, branch: str, logger: logging.Logger) -> None:
170175
raise RuntimeError(f"Failed to create or switch to branch '{branch}': {e}")
171176

172177

178+
def create_repo_on_github(
    organization: str, repo: str, logger: logging.Logger, token: Optional[str] = None
) -> None:
    """Ensure ``organization/repo`` exists on GitHub, creating it if needed.

    The repository is looked up first. A 404 triggers creation inside the
    organization. A 403 is treated as rate limiting only when the core quota
    is actually exhausted; in that case the function polls every 5 minutes
    until calls are available again and then retries the lookup.

    Args:
        organization: GitHub organization that owns (or will own) the repo.
        repo: Repository name.
        logger: Logger used for progress messages.
        token: GitHub token passed to ``GhApi``.

    Raises:
        HTTP403ForbiddenError: If GitHub answers 403 while core quota
            remains, since retrying cannot resolve that.
    """
    api = GhApi(token=token)
    while True:
        try:
            api.repos.get(owner=organization, repo=repo)  # type: ignore
            logger.info(f"{organization}/{repo} already exists")
            break
        except HTTP403ForbiddenError:
            rl = api.rate_limit.get()  # type: ignore
            if rl.resources.core.remaining > 0:
                # Quota is available, so this 403 is not core rate limiting;
                # retrying immediately would loop forever.
                raise
            while rl.resources.core.remaining <= 0:
                logger.info(
                    f"Rate limit exceeded for the current GitHub token, "
                    f"waiting for 5 minutes, remaining calls: {rl.resources.core.remaining}"
                )
                time.sleep(60 * 5)
                rl = api.rate_limit.get()  # type: ignore
        except HTTP404NotFoundError:
            api.repos.create_in_org(org=organization, name=repo)  # type: ignore
            logger.info(f"Created {organization}/{repo} on GitHub")
            break
201+
202+
173203
__all__ = []

0 commit comments

Comments
 (0)