Skip to content

Commit 9cec1bc

Browse files
author
atsuko
committed
Use get_best_checkpoint()
1 parent 5262b4c commit 9cec1bc

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

beginner_source/hyperparameter_tuning_tutorial.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,7 @@ def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2):
463463
best_trained_model = nn.DataParallel(best_trained_model)
464464
best_trained_model.to(device)
465465

466-
best_checkpoint = best_trial.checkpoint
466+
best_checkpoint = result.get_best_checkpoint(trial=best_trial, metric="accuracy", mode="max")
467467
with best_checkpoint.as_directory() as checkpoint_dir:
468468
data_path = Path(checkpoint_dir) / "data.pkl"
469469
with open(data_path, "rb") as fp:

0 commit comments

Comments
 (0)