
Commit 33b15df

atskae, krfricke, and svekars authored
Fix Checkpoint in Hyperparameter Tuning (#2782)
* Save/read Checkpoint to file

Co-authored-by: Kai Fricke <[email protected]>
Co-authored-by: Svetlana Karslioglu <[email protected]>
1 parent 5e772fa commit 33b15df

1 file changed: +50 -34 lines changed

beginner_source/hyperparameter_tuning_tutorial.py

Lines changed: 50 additions & 34 deletions
@@ -41,6 +41,8 @@
 """
 from functools import partial
 import os
+import tempfile
+from pathlib import Path
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -57,14 +59,13 @@
 sys.stdout.fileno = lambda: 0
 # sphinx_gallery_end_ignore
 from ray import tune
-from ray.air import Checkpoint, session
+from ray import train
+from ray.train import Checkpoint, get_checkpoint
 from ray.tune.schedulers import ASHAScheduler
-
-# TODO: Migrate to ray.train.Checkpoint and remove following line
-os.environ["RAY_AIR_NEW_PERSISTENCE_MODE"]="0"
+import ray.cloudpickle as pickle

 ######################################################################
-# Most of the imports are needed for building the PyTorch model. Only the last three
+# Most of the imports are needed for building the PyTorch model. Only the last
 # imports are for Ray Tune.
 #
 # Data loaders
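Note: this commit replaces the old dict-based checkpoint API (`Checkpoint.from_dict` / `checkpoint.to_dict`) with directory-backed checkpoints plus `ray.cloudpickle` throughout. A minimal, self-contained sketch of the round trip being adopted (assumes Ray >= 2.7; the `data.pkl` filename matches the tutorial, the rest is illustrative):

    import tempfile
    from pathlib import Path

    import ray.cloudpickle as pickle
    from ray.train import Checkpoint

    state = {"epoch": 3}

    # Save: pickle the state into a scratch directory, then wrap that
    # directory as a Checkpoint while it still exists.
    with tempfile.TemporaryDirectory() as tmp:
        with open(Path(tmp) / "data.pkl", "wb") as fp:
            pickle.dump(state, fp)
        checkpoint = Checkpoint.from_directory(tmp)

        # Read: as_directory() yields a local path to the checkpoint
        # files, mirroring what train_cifar() does below.
        with checkpoint.as_directory() as checkpoint_dir:
            with open(Path(checkpoint_dir) / "data.pkl", "rb") as fp:
                restored = pickle.load(fp)

    assert restored == {"epoch": 3}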
@@ -135,13 +136,15 @@ def forward(self, x):
 #
 # net = Net(config["l1"], config["l2"])
 #
-# checkpoint = session.get_checkpoint()
-#
+# checkpoint = get_checkpoint()
 # if checkpoint:
-#     checkpoint_state = checkpoint.to_dict()
-#     start_epoch = checkpoint_state["epoch"]
-#     net.load_state_dict(checkpoint_state["net_state_dict"])
-#     optimizer.load_state_dict(checkpoint_state["optimizer_state_dict"])
+#     with checkpoint.as_directory() as checkpoint_dir:
+#         data_path = Path(checkpoint_dir) / "data.pkl"
+#         with open(data_path, "rb") as fp:
+#             checkpoint_state = pickle.load(fp)
+#         start_epoch = checkpoint_state["epoch"]
+#         net.load_state_dict(checkpoint_state["net_state_dict"])
+#         optimizer.load_state_dict(checkpoint_state["optimizer_state_dict"])
 # else:
 #     start_epoch = 0
 #
@@ -197,12 +200,16 @@ def forward(self, x):
 #     "net_state_dict": net.state_dict(),
 #     "optimizer_state_dict": optimizer.state_dict(),
 # }
-# checkpoint = Checkpoint.from_dict(checkpoint_data)
+# with tempfile.TemporaryDirectory() as checkpoint_dir:
+#     data_path = Path(checkpoint_dir) / "data.pkl"
+#     with open(data_path, "wb") as fp:
+#         pickle.dump(checkpoint_data, fp)
 #
-# session.report(
-#     {"loss": val_loss / val_steps, "accuracy": correct / total},
-#     checkpoint=checkpoint,
-# )
+#     checkpoint = Checkpoint.from_directory(checkpoint_dir)
+#     train.report(
+#         {"loss": val_loss / val_steps, "accuracy": correct / total},
+#         checkpoint=checkpoint,
+#     )
 #
 # Here we first save a checkpoint and then report some metrics back to Ray Tune. Specifically,
 # we send the validation loss and accuracy back to Ray Tune. Ray Tune can then use these metrics
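To see the whole save-and-report loop in one place, here is a stripped-down trainable using the same pattern (a sketch, not tutorial code: the toy objective, metric name, and search space are invented; only the checkpointing calls mirror this commit). `train.report` is deliberately called while the temporary directory still exists, since Ray persists the checkpoint files during the call:

    import tempfile
    from pathlib import Path

    import ray.cloudpickle as pickle
    from ray import train, tune
    from ray.train import Checkpoint, get_checkpoint


    def toy_trainable(config):
        # Resume from the last reported checkpoint, if any.
        start = 0
        checkpoint = get_checkpoint()
        if checkpoint:
            with checkpoint.as_directory() as checkpoint_dir:
                with open(Path(checkpoint_dir) / "data.pkl", "rb") as fp:
                    start = pickle.load(fp)["step"]

        for step in range(start, 10):
            loss = (1.0 - config["lr"]) ** step  # stand-in for a real metric
            with tempfile.TemporaryDirectory() as checkpoint_dir:
                with open(Path(checkpoint_dir) / "data.pkl", "wb") as fp:
                    pickle.dump({"step": step}, fp)
                # Report inside the `with` block, before the directory is deleted.
                train.report(
                    {"loss": loss},
                    checkpoint=Checkpoint.from_directory(checkpoint_dir),
                )


    result = tune.run(toy_trainable, config={"lr": tune.uniform(0.01, 0.1)}, num_samples=2)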
@@ -236,13 +243,15 @@ def train_cifar(config, data_dir=None):
     criterion = nn.CrossEntropyLoss()
     optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)

-    checkpoint = session.get_checkpoint()
-
+    checkpoint = get_checkpoint()
     if checkpoint:
-        checkpoint_state = checkpoint.to_dict()
-        start_epoch = checkpoint_state["epoch"]
-        net.load_state_dict(checkpoint_state["net_state_dict"])
-        optimizer.load_state_dict(checkpoint_state["optimizer_state_dict"])
+        with checkpoint.as_directory() as checkpoint_dir:
+            data_path = Path(checkpoint_dir) / "data.pkl"
+            with open(data_path, "rb") as fp:
+                checkpoint_state = pickle.load(fp)
+            start_epoch = checkpoint_state["epoch"]
+            net.load_state_dict(checkpoint_state["net_state_dict"])
+            optimizer.load_state_dict(checkpoint_state["optimizer_state_dict"])
     else:
         start_epoch = 0

@@ -311,12 +320,17 @@ def train_cifar(config, data_dir=None):
             "net_state_dict": net.state_dict(),
             "optimizer_state_dict": optimizer.state_dict(),
         }
-        checkpoint = Checkpoint.from_dict(checkpoint_data)
-
-        session.report(
-            {"loss": val_loss / val_steps, "accuracy": correct / total},
-            checkpoint=checkpoint,
-        )
+        with tempfile.TemporaryDirectory() as checkpoint_dir:
+            data_path = Path(checkpoint_dir) / "data.pkl"
+            with open(data_path, "wb") as fp:
+                pickle.dump(checkpoint_data, fp)
+
+            checkpoint = Checkpoint.from_directory(checkpoint_dir)
+            train.report(
+                {"loss": val_loss / val_steps, "accuracy": correct / total},
+                checkpoint=checkpoint,
+            )
+
     print("Finished Training")

@@ -449,13 +463,15 @@ def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2):
         best_trained_model = nn.DataParallel(best_trained_model)
     best_trained_model.to(device)

-    best_checkpoint = best_trial.checkpoint.to_air_checkpoint()
-    best_checkpoint_data = best_checkpoint.to_dict()
-
-    best_trained_model.load_state_dict(best_checkpoint_data["net_state_dict"])
+    best_checkpoint = result.get_best_checkpoint(trial=best_trial, metric="accuracy", mode="max")
+    with best_checkpoint.as_directory() as checkpoint_dir:
+        data_path = Path(checkpoint_dir) / "data.pkl"
+        with open(data_path, "rb") as fp:
+            best_checkpoint_data = pickle.load(fp)

-    test_acc = test_accuracy(best_trained_model, device)
-    print("Best trial test set accuracy: {}".format(test_acc))
+        best_trained_model.load_state_dict(best_checkpoint_data["net_state_dict"])
+        test_acc = test_accuracy(best_trained_model, device)
+        print("Best trial test set accuracy: {}".format(test_acc))


 if __name__ == "__main__":
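After tuning, the best checkpoint is read back the same way. Continuing the sketch above (hypothetical usage; assumes `result` is the ExperimentAnalysis returned by `tune.run`, as in the tutorial's `main()`, and that each checkpoint stores a `data.pkl` as in this commit):

    # Pick the best trial by the reported metric, then load its checkpoint.
    best_trial = result.get_best_trial("loss", "min", "last")
    best_checkpoint = result.get_best_checkpoint(trial=best_trial, metric="loss", mode="min")
    with best_checkpoint.as_directory() as checkpoint_dir:
        with open(Path(checkpoint_dir) / "data.pkl", "rb") as fp:
            best_checkpoint_data = pickle.load(fp)
    print(best_checkpoint_data["step"])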
