@@ -41,6 +41,8 @@
 """
 from functools import partial
 import os
+import tempfile
+from pathlib import Path
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -57,14 +59,13 @@
 sys.stdout.fileno = lambda: 0
 # sphinx_gallery_end_ignore
 from ray import tune
-from ray.air import Checkpoint, session
+from ray import train
+from ray.train import Checkpoint, get_checkpoint
 from ray.tune.schedulers import ASHAScheduler
-
-# TODO: Migrate to ray.train.Checkpoint and remove following line
-os.environ["RAY_AIR_NEW_PERSISTENCE_MODE"]="0"
+import ray.cloudpickle as pickle
 
 ######################################################################
-# Most of the imports are needed for building the PyTorch model. Only the last three
+# Most of the imports are needed for building the PyTorch model. Only the last
 # imports are for Ray Tune.
 #
 # Data loaders
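
This hunk moves the tutorial off the deprecated `ray.air` session API and onto
`ray.train`, which also makes the `RAY_AIR_NEW_PERSISTENCE_MODE` escape hatch
unnecessary. Roughly, `session.report` becomes `train.report`,
`session.get_checkpoint` becomes `get_checkpoint`, and dict-based checkpoints
become directory-based ones. A minimal sketch of the new reporting loop, using
an illustrative toy trainable rather than the tutorial's model:

    from ray import train, tune

    def toy_trainable(config):
        # train.report replaces session.report; a Checkpoint can be
        # attached via the checkpoint= keyword, as the hunks below do.
        for step in range(3):
            train.report({"loss": config["x"] / (step + 1)})

    analysis = tune.run(toy_trainable, config={"x": 1.0})
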
@@ -135,13 +136,15 @@ def forward(self, x):
 #
 #     net = Net(config["l1"], config["l2"])
 #
-#     checkpoint = session.get_checkpoint()
-#
+#     checkpoint = get_checkpoint()
 #     if checkpoint:
-#         checkpoint_state = checkpoint.to_dict()
-#         start_epoch = checkpoint_state["epoch"]
-#         net.load_state_dict(checkpoint_state["net_state_dict"])
-#         optimizer.load_state_dict(checkpoint_state["optimizer_state_dict"])
+#         with checkpoint.as_directory() as checkpoint_dir:
+#             data_path = Path(checkpoint_dir) / "data.pkl"
+#             with open(data_path, "rb") as fp:
+#                 checkpoint_state = pickle.load(fp)
+#             start_epoch = checkpoint_state["epoch"]
+#             net.load_state_dict(checkpoint_state["net_state_dict"])
+#             optimizer.load_state_dict(checkpoint_state["optimizer_state_dict"])
 #     else:
 #         start_epoch = 0
 #
@@ -197,12 +200,16 @@ def forward(self, x):
 #             "net_state_dict": net.state_dict(),
 #             "optimizer_state_dict": optimizer.state_dict(),
 #         }
-#         checkpoint = Checkpoint.from_dict(checkpoint_data)
+#         with tempfile.TemporaryDirectory() as checkpoint_dir:
+#             data_path = Path(checkpoint_dir) / "data.pkl"
+#             with open(data_path, "wb") as fp:
+#                 pickle.dump(checkpoint_data, fp)
 #
-#         session.report(
-#             {"loss": val_loss / val_steps, "accuracy": correct / total},
-#             checkpoint=checkpoint,
-#         )
+#             checkpoint = Checkpoint.from_directory(checkpoint_dir)
+#             train.report(
+#                 {"loss": val_loss / val_steps, "accuracy": correct / total},
+#                 checkpoint=checkpoint,
+#             )
 #
 # Here we first save a checkpoint and then report some metrics back to Ray Tune. Specifically,
 # we send the validation loss and accuracy back to Ray Tune. Ray Tune can then use these metrics
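
One detail worth flagging in the new pattern: `train.report` has to be called
while the `TemporaryDirectory` still exists, because Ray persists the
checkpoint's files at report time. The same save/restore round trip in
isolation, with an illustrative state dict and file name:

    import tempfile
    from pathlib import Path

    import ray.cloudpickle as pickle
    from ray.train import Checkpoint

    state = {"epoch": 3}
    with tempfile.TemporaryDirectory() as checkpoint_dir:
        # Save: serialize state into the directory, then wrap it.
        with open(Path(checkpoint_dir) / "data.pkl", "wb") as fp:
            pickle.dump(state, fp)
        checkpoint = Checkpoint.from_directory(checkpoint_dir)
        # train.report(metrics, checkpoint=checkpoint) belongs here,
        # inside the with block; after it exits, the files are gone.

        # Restore: as_directory() exposes the checkpoint's files again.
        with checkpoint.as_directory() as restored_dir:
            with open(Path(restored_dir) / "data.pkl", "rb") as fp:
                assert pickle.load(fp)["epoch"] == 3
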
@@ -236,13 +243,15 @@ def train_cifar(config, data_dir=None):
     criterion = nn.CrossEntropyLoss()
     optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)
 
-    checkpoint = session.get_checkpoint()
-
+    checkpoint = get_checkpoint()
     if checkpoint:
-        checkpoint_state = checkpoint.to_dict()
-        start_epoch = checkpoint_state["epoch"]
-        net.load_state_dict(checkpoint_state["net_state_dict"])
-        optimizer.load_state_dict(checkpoint_state["optimizer_state_dict"])
+        with checkpoint.as_directory() as checkpoint_dir:
+            data_path = Path(checkpoint_dir) / "data.pkl"
+            with open(data_path, "rb") as fp:
+                checkpoint_state = pickle.load(fp)
+            start_epoch = checkpoint_state["epoch"]
+            net.load_state_dict(checkpoint_state["net_state_dict"])
+            optimizer.load_state_dict(checkpoint_state["optimizer_state_dict"])
     else:
         start_epoch = 0
 
@@ -311,12 +320,17 @@ def train_cifar(config, data_dir=None):
             "net_state_dict": net.state_dict(),
             "optimizer_state_dict": optimizer.state_dict(),
         }
-        checkpoint = Checkpoint.from_dict(checkpoint_data)
-
-        session.report(
-            {"loss": val_loss / val_steps, "accuracy": correct / total},
-            checkpoint=checkpoint,
-        )
+        with tempfile.TemporaryDirectory() as checkpoint_dir:
+            data_path = Path(checkpoint_dir) / "data.pkl"
+            with open(data_path, "wb") as fp:
+                pickle.dump(checkpoint_data, fp)
+
+            checkpoint = Checkpoint.from_directory(checkpoint_dir)
+            train.report(
+                {"loss": val_loss / val_steps, "accuracy": correct / total},
+                checkpoint=checkpoint,
+            )
+
     print("Finished Training")
 
 
@@ -449,13 +463,15 @@ def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2):
         best_trained_model = nn.DataParallel(best_trained_model)
     best_trained_model.to(device)
 
-    best_checkpoint = best_trial.checkpoint.to_air_checkpoint()
-    best_checkpoint_data = best_checkpoint.to_dict()
-
-    best_trained_model.load_state_dict(best_checkpoint_data["net_state_dict"])
+    best_checkpoint = result.get_best_checkpoint(trial=best_trial, metric="accuracy", mode="max")
+    with best_checkpoint.as_directory() as checkpoint_dir:
+        data_path = Path(checkpoint_dir) / "data.pkl"
+        with open(data_path, "rb") as fp:
+            best_checkpoint_data = pickle.load(fp)
 
-    test_acc = test_accuracy(best_trained_model, device)
-    print("Best trial test set accuracy: {}".format(test_acc))
+        best_trained_model.load_state_dict(best_checkpoint_data["net_state_dict"])
+        test_acc = test_accuracy(best_trained_model, device)
+        print("Best trial test set accuracy: {}".format(test_acc))
 
 
 if __name__ == "__main__":
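
A note on the retrieval side: `get_best_checkpoint` scans every checkpoint the
trial reported and returns the one with the best value of the given metric,
which is not necessarily the last checkpoint written. A self-contained sketch
of the full report-then-retrieve cycle, with illustrative accuracy values
chosen so that the best checkpoint is not the final one:

    import tempfile
    from pathlib import Path

    import ray.cloudpickle as pickle
    from ray import train, tune
    from ray.train import Checkpoint

    def toy_trainable(config):
        for step, acc in enumerate([0.1, 0.9, 0.4]):
            with tempfile.TemporaryDirectory() as d:
                with open(Path(d) / "data.pkl", "wb") as fp:
                    pickle.dump({"step": step}, fp)
                train.report({"accuracy": acc}, checkpoint=Checkpoint.from_directory(d))

    result = tune.run(toy_trainable)
    best_trial = result.get_best_trial("accuracy", "max", "last")
    best_checkpoint = result.get_best_checkpoint(trial=best_trial, metric="accuracy", mode="max")
    with best_checkpoint.as_directory() as d:
        with open(Path(d) / "data.pkl", "rb") as fp:
            print(pickle.load(fp))  # {'step': 1}, the highest-accuracy step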