Commit 456df3b

pre-commit fixes
1 parent 08e6cf5 commit 456df3b

4 files changed, +99 -33 lines changed

examples/star/inference.py

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 from typing import List
 from datasets import Dataset
 from vllm import LLM, SamplingParams
-from utils import generate_prompt, cleanup
+from examples.star.utils import generate_prompt, cleanup


 def generate_predictions(
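
The hunk above only switches the import to an absolute path; the body of generate_predictions is untouched by this commit. For orientation, here is a minimal, hypothetical sketch of a vLLM-backed implementation that matches the call site in star.py (generate_predictions(model_name, ds["train"], args.temperature, args.n)). The max_tokens value and the cleanup call are assumptions, not code from this repository.

# Hypothetical sketch -- not part of this commit.
from typing import List

from datasets import Dataset
from vllm import LLM, SamplingParams

from examples.star.utils import cleanup, generate_prompt


def generate_predictions(
    model_name: str, dataset: Dataset, temperature: float = 1.0, n: int = 1
) -> List[List[str]]:
    """Sample n completions per example with vLLM (illustrative sketch)."""
    sampling_params = SamplingParams(n=n, temperature=temperature, max_tokens=512)  # max_tokens assumed
    llm = LLM(model=model_name)

    prompts = [generate_prompt(ex["prompt"], ex["test"]) for ex in dataset]
    outputs = llm.generate(prompts, sampling_params)

    # One inner list holding the n sampled completions for each input example.
    results = [[completion.text for completion in output.outputs] for output in outputs]
    cleanup(llm, vllm=True)  # assumed usage: free GPU state between STaR iterations
    return results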

examples/star/star.py

Lines changed: 51 additions & 11 deletions

@@ -2,22 +2,37 @@

 from copy import deepcopy
 from datasets import Dataset, DatasetDict, load_dataset
-from inference import generate_predictions
-from train import train
-from utils import execute_tests, format_solution, generate_prompt, parse_args
+from examples.star.inference import generate_predictions
+from examples.star.train import train
+from examples.star.utils import (
+    execute_tests,
+    format_solution,
+    generate_prompt,
+    parse_args,
+)


-def main():
+def main() -> None:
     args = parse_args()
     ds = load_dataset(args.dataset_name, args.dataset_config_name)
     assert "train" in ds
     # format the dataset for training and evaluation
     for split in ds:
         texts = []
-        if split == "train": continue
+        if split == "train":
+            continue
         for example in ds[split]:
             canonical_solution = f"```python\n{example['canonical_solution']}\n```"
-            text = [{"role": "user", "message": generate_prompt(example["prompt"], example["test"])}, {"role": "assistant", "message": format_solution(canonical_solution, example["prompt"])}]
+            text = [
+                {
+                    "role": "user",
+                    "message": generate_prompt(example["prompt"], example["test"]),
+                },
+                {
+                    "role": "assistant",
+                    "message": format_solution(canonical_solution, example["prompt"]),
+                },
+            ]
             texts.append(text)
         ds[split] = ds[split].add_column(name="text", column=texts)

@@ -28,23 +43,45 @@ def main():
         all_samples = generate_predictions(
             model_name, ds["train"], args.temperature, args.n
         )
-        ds["train"].add_column(name="sample", column=all_samples).to_json(f"{output_dir}/data/samples-iter{i}.json")
+        ds["train"].add_column(name="sample", column=all_samples).to_json(
+            f"{output_dir}/data/samples-iter{i}.json"
+        )
         assert len(ds["train"]) == len(all_samples)

         # verify and construct the training set
-        all_traces, all_execution_results = execute_tests(ds["train"], all_samples, max_workers=args.max_workers)
+        all_traces, all_execution_results = execute_tests(
+            ds["train"], all_samples, max_workers=args.max_workers
+        )
         passed_examples = []
         for example, execution_results, samples in zip(
             ds["train"], all_execution_results, all_samples
         ):
             for execution_result, sample in zip(execution_results, samples):
                 # pytest exit code: https://docs.pytest.org/en/stable/reference/exit-codes.html
                 if execution_result == 0:
-                    example["text"] = [{"role": "user", "message": generate_prompt(example["prompt"], example["test"])}, {"role": "assistant", "message": format_solution(sample, example["prompt"])}]
+                    example["text"] = [
+                        {
+                            "role": "user",
+                            "message": generate_prompt(
+                                example["prompt"], example["test"]
+                            ),
+                        },
+                        {
+                            "role": "assistant",
+                            "message": format_solution(sample, example["prompt"]),
+                        },
+                    ]
                     passed_examples.append(example)
                     break
-        raw_datasets = DatasetDict({"train": Dataset.from_list(passed_examples), "validation": ds["validation"]})
-        raw_datasets["train"].to_json(f"{output_dir}/data/verified-samples-iter{i}.json")
+        raw_datasets = DatasetDict(
+            {
+                "train": Dataset.from_list(passed_examples),
+                "validation": ds["validation"],
+            }
+        )
+        raw_datasets["train"].to_json(
+            f"{output_dir}/data/verified-samples-iter{i}.json"
+        )

         # train
         args.output_dir = f"{output_dir}/models-iter{i}"
@@ -54,3 +91,6 @@ def main():

 if __name__ == "__main__":
     main()
+
+
+__all__ = []
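
The heart of the loop above is a rejection-sampling filter in the STaR style: for every training example, the first generated sample whose pytest run exits with code 0 is kept and stored as a user/assistant pair, and examples with no passing sample are dropped. A self-contained toy run of that filter, with made-up data rather than anything from the commit, shows the shape of the result:

# Toy illustration of the keep-first-passing-sample filter used above.
from typing import Dict, List

examples: List[Dict[str, str]] = [
    {"prompt": "add(a, b)", "test": "assert add(1, 2) == 3"}
]
all_samples = [["def add(a, b): return a - b", "def add(a, b): return a + b"]]
all_execution_results = [[1, 0]]  # pytest exit codes: 1 = tests failed, 0 = all passed

passed_examples = []
for example, execution_results, samples in zip(
    examples, all_execution_results, all_samples
):
    for execution_result, sample in zip(execution_results, samples):
        if execution_result == 0:  # keep the first sample that passes its tests
            example["text"] = [
                {"role": "user", "message": example["prompt"]},
                {"role": "assistant", "message": sample},
            ]
            passed_examples.append(example)
            break

print(len(passed_examples))  # -> 1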

examples/star/train.py

Lines changed: 9 additions & 9 deletions

@@ -21,7 +21,6 @@
 """
 # You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.

-import argparse
 import json
 import logging
 import math
@@ -34,28 +33,26 @@
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import set_seed
-from datasets import load_dataset
 from huggingface_hub import HfApi
 from torch.utils.data import DataLoader
 from tqdm.auto import tqdm

 import transformers
 from transformers import (
-    CONFIG_MAPPING,
     AutoConfig,
     AutoModelForCausalLM,
     AutoTokenizer,
     DataCollatorForSeq2Seq,
     get_scheduler,
 )

-from utils import cleanup
+from examples.star.utils import cleanup


 logger = get_logger(__name__)


-def train(raw_datasets, model_name_or_path, args):
+def train(raw_datasets, model_name_or_path, args) -> None:
     # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
     # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
     # in the environment
@@ -289,7 +286,9 @@ def tokenize_function(examples):
         for step, batch in enumerate(active_dataloader):
             with accelerator.accumulate(model):
                 batch["labels"] = batch["input_ids"].clone().detach()
-                indices = (batch["input_ids"] == tokenizer.eos_token_id).cumsum(dim=1) == 0
+                indices = (batch["input_ids"] == tokenizer.eos_token_id).cumsum(
+                    dim=1
+                ) == 0
                 batch["labels"][indices] = -100
                 outputs = model(**batch)
                 loss = outputs.loss
@@ -323,7 +322,9 @@ def tokenize_function(examples):
         for step, batch in enumerate(eval_dataloader):
             with torch.no_grad():
                 batch["labels"] = batch["input_ids"].clone().detach()
-                indices = (batch["input_ids"] == tokenizer.eos_token_id).cumsum(dim=1) == 0
+                indices = (batch["input_ids"] == tokenizer.eos_token_id).cumsum(
+                    dim=1
+                ) == 0
                 batch["labels"][indices] = -100
                 outputs = model(**batch)

@@ -405,5 +406,4 @@ def tokenize_function(examples):
     cleanup(model)


-if __name__ == "__main__":
-    main()
+__all__ = []
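
The reformatted indices expression is the label-masking step: labels start as a copy of input_ids, and every position whose running count of eos_token_id occurrences is still zero (that is, everything before the first EOS, which here terminates the prompt) is set to -100 so the cross-entropy loss ignores it. A small standalone check of that behavior, with an EOS id of 2 assumed purely for illustration:

import torch

eos_token_id = 2  # assumed EOS id for this example
input_ids = torch.tensor([[5, 7, 9, 2, 11, 13, 2, 0]])  # prompt | eos | completion | eos | pad

labels = input_ids.clone().detach()
# True wherever no EOS has been seen yet in the row, i.e. the prompt tokens.
indices = (input_ids == eos_token_id).cumsum(dim=1) == 0
labels[indices] = -100  # -100 is the ignore index of torch.nn.CrossEntropyLoss

print(labels)  # tensor([[-100, -100, -100, 2, 11, 13, 2, 0]])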

examples/star/utils.py

Lines changed: 38 additions & 12 deletions

@@ -4,8 +4,8 @@
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from datasets import Dataset
 from tqdm import tqdm
-from typing import List, Tuple
-from transformers import MODEL_MAPPING, SchedulerType
+from typing import Any, List, Tuple
+from transformers import SchedulerType
 from commit0.harness.utils import extract_code_blocks


@@ -30,7 +30,7 @@ def execute_tests(

     Args:
     ----
-        ds (Dataset): A Dataset object.
+        examples (Dataset): A Dataset object.
         all_samples (List[List[str]]): A 2D list of strings, where `all_samples[i]` corresponds to the samples associated with `ds[i]`.
         max_workers (int): The number of worker threads to use for parallel execution. Default is 100.

@@ -82,9 +82,7 @@ def execute_tests(


 def generate_prompt(prompt: str, test: str) -> str:
-    """
-    Generate a Python code request prompt string.
-    """
+    """Generate a Python code request prompt string."""
     return f"""Write a Python function implementation for the following prompt:

 {prompt}
@@ -100,7 +98,19 @@ def generate_prompt(prompt: str, test: str) -> str:
 """


-def format_solution(text, prompt):
+def format_solution(text: str, prompt: str) -> str:
+    """Extracts a code block from the given text and formats it as a Python code snippet.
+
+    Args:
+    ----
+        text (str): The input text which may contain code blocks.
+        prompt (str): A string that will be returned if no code block is found.
+
+    Returns:
+    -------
+        str: A formatted code snippet if a code block exists, otherwise the prompt and text.
+
+    """
     matches = extract_code_blocks(text)
     if len(matches) > 0:
         solution = matches[0]
@@ -110,7 +120,14 @@ def format_solution(text, prompt):
     return solution


-def parse_args():
+def parse_args() -> argparse.Namespace:
+    """Parse command-line arguments.
+
+    Returns
+    -------
+        argparse.Namespace: The parsed command-line arguments.
+
+    """
     parser = argparse.ArgumentParser(
         description="Finetune a transformers model on a causal language modeling task"
     )
@@ -279,23 +296,32 @@ def parse_args():
     return args


-def cleanup(model, vllm=False):
-    """
-    Clean up resources associated with the given model.
+def cleanup(model: Any, vllm: bool = False) -> None:
+    """Clean up resources associated with the given model.

     Parameters
     ----------
     model : Any
         The model object whose resources are to be cleaned up.
+    vllm : Boolean
+        The model object whose resources are to be cleaned up.
+
+    Returns
+    -------
+    None
+
     """
     try:
         import torch
         import contextlib
+
         if torch.cuda.is_available():
             if vllm:
                 from vllm.distributed.parallel_state import (
-                    destroy_model_parallel, destroy_distributed_environment
+                    destroy_model_parallel,
+                    destroy_distributed_environment,
                 )
+
                 destroy_model_parallel()
                 destroy_distributed_environment()
                 del model.llm_engine.model_executor
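
The new docstring on format_solution states its contract: if the model output contains a fenced code block, return that block formatted as a Python snippet, otherwise fall back to the prompt plus the raw text. A hedged sketch of that contract follows; the regex stands in for commit0.harness.utils.extract_code_blocks, and the re-wrapping and fallback joins are assumptions, since those exact lines are not shown in this diff.

# Sketch of the format_solution contract; extract_code_blocks is stubbed with a regex.
import re
from typing import List


def extract_code_blocks(text: str) -> List[str]:
    # Stand-in for commit0.harness.utils.extract_code_blocks (illustration only).
    return re.findall(r"```(?:python)?\n(.*?)```", text, flags=re.DOTALL)


def format_solution(text: str, prompt: str) -> str:
    matches = extract_code_blocks(text)
    if len(matches) > 0:
        solution = matches[0]
        solution = f"```python\n{solution}\n```"  # assumed re-wrapping as a Python snippet
    else:
        solution = prompt + "\n" + text  # assumed fallback: "the prompt and text"
    return solution


print(format_solution("Here you go:\n```python\ndef f():\n    return 1\n```", "def f():"))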
