Skip to content

Commit 3268da2

Browse files
authored
[eval_llama] Add option to save checkpoint after eager transforms.
Differential Revision: D62150021 Pull Request resolved: #5045
1 parent 32d83b0 commit 3268da2

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

examples/models/llama2/eval_llama_lib.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,15 @@ def gen_eval_wrapper(
158158
else manager.model.eval().to(device="cpu")
159159
)
160160

161+
# Save the checkpoint after the eager model preparation is done.
162+
# The reason for this option is that the checkpoint can be used
163+
# to do evaluations in other evaluation platforms, or with data
164+
# that is not available in this eval_llama. We save the checkpoint
165+
# here for consistency with eval_llama. The accuracy results we
166+
# get from eval_llama can be used as a reference to other evaluations.
167+
if args.output_eager_checkpoint_file is not None:
168+
torch.save(model, args.output_eager_checkpoint_file)
169+
161170
return EagerEvalWrapper(
162171
model=model,
163172
tokenizer=tokenizer,
@@ -196,6 +205,12 @@ def build_args_parser() -> argparse.ArgumentParser:
196205
default=None,
197206
help="[For ExecuTorch] Path to the Tokenizer binary for evaluating ExecuTorch models via runtime",
198207
)
208+
parser.add_argument(
209+
"--output_eager_checkpoint_file",
210+
type=str,
211+
default=None,
212+
help="Save the checkpoint after source transformations, for other evaluation platforms to run the same checkpoint.",
213+
)
199214

200215
return parser
201216

0 commit comments

Comments
 (0)