Skip to content

Commit 68b724b

Browse files
committed
training works
1 parent f6b2a71 commit 68b724b

File tree

3 files changed

+36
-24
lines changed

3 files changed

+36
-24
lines changed

examples/star/star.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,34 @@
11
"""Main STaR Loop"""
22

3-
from datasets import Dataset, load_dataset
3+
from datasets import Dataset, DatasetDict, load_dataset
44
from inference import generate_predictions
55
from train import train
6-
from utils import execute_tests, parse_args
6+
from utils import execute_tests, format_solution, generate_prompt, parse_args
77

88

99
def main():
1010
args = parse_args()
11-
ds = load_dataset(args.dataset_name)
11+
ds = load_dataset(args.dataset_name, args.dataset_config_name)
1212
assert "train" in ds
13+
# format the dataset for training and evaluation
14+
for split in ds:
15+
texts = []
16+
if split == "train": continue
17+
for example in ds[split]:
18+
canonical_solution = f"```python\n{example['canonical_solution']}\n```"
19+
text = [{"role": "user", "message": generate_prompt(example["prompt"], example["test"])}, {"role": "assistant", "message": format_solution(canonical_solution, example["prompt"])}]
20+
texts.append(text)
21+
print(text)
22+
ds[split] = ds[split].add_column(name="text", column=texts)
23+
ds["train"] = ds["train"].select(range(10))
24+
25+
# sample
1326
all_samples = generate_predictions(
1427
args.model_name_or_path, ds["train"], args.temperature, args.n
1528
)
1629
assert len(ds["train"]) == len(all_samples)
30+
31+
# verify and construct the training set
1732
all_traces, all_execution_results = execute_tests(ds["train"], all_samples)
1833
passed_examples = []
1934
for example, execution_results, samples in zip(
@@ -22,13 +37,13 @@ def main():
2237
for execution_result, sample in zip(execution_results, samples):
2338
# pytest exit code: https://docs.pytest.org/en/stable/reference/exit-codes.html
2439
if execution_result == 0:
25-
example["prediction"] = sample
40+
example["text"] = [{"role": "user", "message": generate_prompt(example["prompt"], example["test"])}, {"role": "assistant", "message": format_solution(sample, example["prompt"])}]
2641
passed_examples.append(example)
2742
break
28-
new_ds = Dataset.from_list(passed_examples)
29-
new_ds.to_json("star_training.json")
30-
print(len(passed_examples) / len(ds["train"]))
31-
train(args)
43+
raw_datasets = DatasetDict({"train": Dataset.from_list(passed_examples), "validation": ds["validation"]})
44+
45+
# train
46+
train(raw_datasets, args.model_name_or_path, args)
3247

3348

3449
if __name__ == "__main__":

examples/star/train.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
import datasets
3333
import torch
34-
from accelerate import Accelerator, DistributedType
34+
from accelerate import Accelerator
3535
from accelerate.logging import get_logger
3636
from accelerate.utils import set_seed
3737
from datasets import load_dataset
@@ -234,10 +234,6 @@ def tokenize_function(examples):
234234
)
235235
)
236236

237-
# On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
238-
if accelerator.distributed_type == DistributedType.TPU:
239-
model.tie_weights()
240-
241237
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
242238
num_update_steps_per_epoch = math.ceil(
243239
len(train_dataloader) / args.gradient_accumulation_steps
@@ -291,17 +287,7 @@ def tokenize_function(examples):
291287
model.train()
292288
if args.with_tracking:
293289
total_loss = 0
294-
if (
295-
args.resume_from_checkpoint
296-
and epoch == starting_epoch
297-
and resume_step is not None
298-
):
299-
# We skip the first `n` batches in the dataloader when resuming from a checkpoint
300-
active_dataloader = accelerator.skip_first_batches(
301-
train_dataloader, resume_step
302-
)
303-
else:
304-
active_dataloader = train_dataloader
290+
active_dataloader = train_dataloader
305291
for step, batch in enumerate(active_dataloader):
306292
with accelerator.accumulate(model):
307293
batch["labels"] = batch["input_ids"].clone().detach()

examples/star/utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from tqdm import tqdm
66
from typing import List, Tuple
77
from transformers import MODEL_MAPPING, SchedulerType
8+
from commit0.harness.utils import extract_code_blocks
89

910

1011
def execute_tests(
@@ -98,6 +99,16 @@ def generate_prompt(prompt: str, test: str) -> str:
9899
"""
99100

100101

102+
def format_solution(text, prompt):
    """Turn a raw completion into the solution string stored for training.

    ``text`` is passed to ``extract_code_blocks``. If that call yields at
    least one match, only the first match is kept and it is re-wrapped in a
    ```python fenced block. If it yields nothing, the result is ``prompt``
    followed by a blank line and the unmodified ``text``.

    Args:
        text: Completion (or reference solution) to normalize.
        prompt: Prompt text prepended when ``text`` has no code block.

    Returns:
        The formatted solution string.
    """
    # NOTE(review): presumably extract_code_blocks returns a list of the
    # fenced code snippets found in `text` -- confirm against commit0.
    code_blocks = extract_code_blocks(text)

    # Nothing fenced was found: fall back to prompt + raw text (unfenced).
    if len(code_blocks) == 0:
        return prompt + "\n\n" + text

    # Only the first extracted block is used; any later blocks are dropped.
    first_block = code_blocks[0]
    return f"```python\n{first_block}\n```"
110+
111+
101112
def parse_args():
102113
parser = argparse.ArgumentParser(
103114
description="Finetune a transformers model on a causal language modeling task"

0 commit comments

Comments
 (0)