1
1
"""Main STaR Loop"""
2
2
3
- from datasets import Dataset , load_dataset
3
+ from datasets import Dataset , DatasetDict , load_dataset
4
4
from inference import generate_predictions
5
5
from train import train
6
- from utils import execute_tests , parse_args
6
+ from utils import execute_tests , format_solution , generate_prompt , parse_args
7
7
8
8
9
9
def main ():
10
10
args = parse_args ()
11
- ds = load_dataset (args .dataset_name )
11
+ ds = load_dataset (args .dataset_name , args . dataset_config_name )
12
12
assert "train" in ds
13
+ # format every non-train split: add a chat-style "text" column built from each example's canonical solution
14
+ for split in ds :
15
+ texts = []
16
+ if split == "train" : continue
17
+ for example in ds [split ]:
18
+ canonical_solution = f"```python\n { example ['canonical_solution' ]} \n ```"
19
+ text = [{"role" : "user" , "message" : generate_prompt (example ["prompt" ], example ["test" ])}, {"role" : "assistant" , "message" : format_solution (canonical_solution , example ["prompt" ])}]
20
+ texts .append (text )
21
+ print (text )
22
+ ds [split ] = ds [split ].add_column (name = "text" , column = texts )
23
+ ds ["train" ] = ds ["train" ].select (range (10 ))
24
+
25
+ # sample: generate predictions for every training example
13
26
all_samples = generate_predictions (
14
27
args .model_name_or_path , ds ["train" ], args .temperature , args .n
15
28
)
16
29
assert len (ds ["train" ]) == len (all_samples )
30
+
31
+ # verify and construct the training set
17
32
all_traces , all_execution_results = execute_tests (ds ["train" ], all_samples )
18
33
passed_examples = []
19
34
for example , execution_results , samples in zip (
@@ -22,13 +37,13 @@ def main():
22
37
for execution_result , sample in zip (execution_results , samples ):
23
38
# pytest exit code: https://docs.pytest.org/en/stable/reference/exit-codes.html
24
39
if execution_result == 0 :
25
- example ["prediction " ] = sample
40
+ example ["text " ] = [{ "role" : "user" , "message" : generate_prompt ( example [ "prompt" ], example [ "test" ])}, { "role" : "assistant" , "message" : format_solution ( sample , example [ "prompt" ])}]
26
41
passed_examples .append (example )
27
42
break
28
- new_ds = Dataset .from_list (passed_examples )
29
- new_ds . to_json ( "star_training.json" )
30
- print ( len ( passed_examples ) / len ( ds [ " train" ]))
31
- train (args )
43
+ raw_datasets = DatasetDict ({ "train" : Dataset .from_list (passed_examples ), "validation" : ds [ "validation" ]} )
44
+
45
+ # train
46
+ train (raw_datasets , args . model_name_or_path , args )
32
47
33
48
34
49
if __name__ == "__main__" :
0 commit comments