
Commit 0176605

Support for spot training. (aws#303)
* Support for spot training.
* Updated the comment
* Fixing the CI build for pre-commit failures.
* Fixing the CI.
* Adding the test file for testing spot training
* Support for spot training. Addressed the review comments. Added the test script.
* Added the log statement as per review comment.
* Fixing the test to run correctly in CI
* Emit the end of training file only if the job is not running under SageMaker
* Reorganized code for better readability.
* Updated the implementation to avoid global variables.
* Addressed the review comments
* Addressed the review comments to refactor the code.
* Updated the checkpoint timestamp to look for all the modified files in the directory.
* Avoided one disk access to compute timestamp
1 parent fdf817c commit 0176605
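
Two of the squashed commits ("Updated the checkpoint timestamp to look for all the modified files in the directory" and "Avoided one disk access to compute timestamp") describe how checkpoint freshness is detected. A minimal sketch of that idea, using a hypothetical helper name (the actual logic lives in the StateStore, which is not part of this diff):

import os


def latest_checkpoint_timestamp(checkpoint_dir):
    # Hypothetical sketch: report the newest modification time across all files
    # in the checkpoint directory, so that any newly written checkpoint file
    # counts as an update. os.stat() returns the mtime without opening the file.
    latest = 0.0
    for root, _, files in os.walk(checkpoint_dir):
        for name in files:
            latest = max(latest, os.stat(os.path.join(root, name)).st_mtime)
    return latest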

File tree

6 files changed: +386 -2 lines changed

tests/mxnet/test_json_configs/checkpointconfig.json

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
{
  "LocalPath" : "./savedParams"
}
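
This config appears to mirror the checkpoint configuration file that SageMaker places under /opt/ml/input/config for a training job; the constants added in tornasole/core/config_constants.py below name its path and key. A rough sketch of how such a file could be resolved and read (illustrative only, not the code in this commit):

import json
import os

CHECKPOINT_CONFIG_FILE_PATH_ENV_VAR = "CHECKPOINT_CONFIG_FILE_PATH"
CHECKPOINT_DIR_KEY = "LocalPath"
DEFAULT_CHECKPOINT_CONFIG_FILE = "/opt/ml/input/config/checkpointconfig.json"


def read_checkpoint_dir():
    # Resolve the checkpoint config path from the environment, falling back to
    # the SageMaker default, and return the configured LocalPath (or None).
    config_path = os.environ.get(
        CHECKPOINT_CONFIG_FILE_PATH_ENV_VAR, DEFAULT_CHECKPOINT_CONFIG_FILE
    )
    if not os.path.exists(config_path):
        return None
    with open(config_path) as f:
        return json.load(f).get(CHECKPOINT_DIR_KEY)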

tests/mxnet/test_spot_training.py

Lines changed: 216 additions & 0 deletions
@@ -0,0 +1,216 @@
# Using batch size 4 instead of 1024 decreases runtime from 35 secs to 4 secs.

from mxnet import gluon, init, autograd
from mxnet.gluon import nn
from mxnet.gluon.data.vision import datasets, transforms
import time
import mxnet as mx
from tornasole import modes
from tornasole.mxnet.hook import TornasoleHook as t_hook
from tornasole import SaveConfig
from tornasole.mxnet import reset_collections
from tornasole.core.access_layer.utils import has_training_ended
from tornasole.core.config_constants import CHECKPOINT_CONFIG_FILE_PATH_ENV_VAR
from tornasole.trials import create_trial
from datetime import datetime

import shutil
import os


def acc(output, label):
    return (output.argmax(axis=1) == label.astype("float32")).mean().asscalar()


def run_mnist(
    hook=None,
    set_modes=False,
    num_steps_train=None,
    num_steps_eval=None,
    epochs=2,
    save_interval=None,
    save_path="./saveParams",
):
    batch_size = 4
    normalize_mean = 0.13
    mnist_train = datasets.FashionMNIST(train=True)

    X, y = mnist_train[0]
    print("X shape: ", X.shape, "X dtype", X.dtype, "y:", y)

    text_labels = [
        "t-shirt",
        "trouser",
        "pullover",
        "dress",
        "coat",
        "sandal",
        "shirt",
        "sneaker",
        "bag",
        "ankle boot",
    ]
    transformer = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(normalize_mean, 0.31)]
    )

    mnist_train = mnist_train.transform_first(transformer)
    mnist_valid = gluon.data.vision.FashionMNIST(train=False)

    train_data = gluon.data.DataLoader(
        mnist_train, batch_size=batch_size, shuffle=True, num_workers=4
    )
    valid_data = gluon.data.DataLoader(
        mnist_valid.transform_first(transformer), batch_size=batch_size, num_workers=4
    )

    # Create the model in Gluon.
    net = nn.HybridSequential()
    net.add(
        nn.Conv2D(channels=6, kernel_size=5, activation="relu"),
        nn.MaxPool2D(pool_size=2, strides=2),
        nn.Conv2D(channels=16, kernel_size=3, activation="relu"),
        nn.MaxPool2D(pool_size=2, strides=2),
        nn.Flatten(),
        nn.Dense(120, activation="relu"),
        nn.Dense(84, activation="relu"),
        nn.Dense(10),
    )
    net.initialize(init=init.Xavier(), ctx=mx.cpu())

    if hook is not None:
        # Register the forward hook.
        hook.register_hook(net)

    softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()
    trainer = gluon.Trainer(net.collect_params(), "sgd", {"learning_rate": 0.1})
    hook.register_hook(softmax_cross_entropy)

    # Start the training.
    for epoch in range(epochs):
        train_loss, train_acc, valid_acc = 0.0, 0.0, 0.0
        tic = time.time()
        if set_modes:
            hook.set_mode(modes.TRAIN)

        i = 0
        for data, label in train_data:
            data = data.as_in_context(mx.cpu(0))
            # forward + backward
            with autograd.record():
                output = net(data)
                loss = softmax_cross_entropy(output, label)
            loss.backward()
            # update parameters
            trainer.step(batch_size)
            # calculate training metrics
            train_loss += loss.mean().asscalar()
            train_acc += acc(output, label)
            i += 1
            if num_steps_train is not None and i >= num_steps_train:
                break
        # calculate validation accuracy
        if set_modes:
            hook.set_mode(modes.EVAL)
        i = 0
        for data, label in valid_data:
            data = data.as_in_context(mx.cpu(0))
            val_output = net(data)
            valid_acc += acc(val_output, label)
            loss = softmax_cross_entropy(val_output, label)
            i += 1
            if num_steps_eval is not None and i >= num_steps_eval:
                break
        print(
            "Epoch %d: loss %.3f, train acc %.3f, test acc %.3f, in %.1f sec"
            % (
                epoch,
                train_loss / len(train_data),
                train_acc / len(train_data),
                valid_acc / len(valid_data),
                time.time() - tic,
            )
        )
        if save_interval is not None and (epoch % save_interval) == 0:
            net.save_parameters("{0}/params_{1}.params".format(save_path, epoch))


def test_spot_hook():
    reset_collections()
    os.environ[
        CHECKPOINT_CONFIG_FILE_PATH_ENV_VAR
    ] = "./tests/mxnet/test_json_configs/checkpointconfig.json"
    checkpoint_path = "./savedParams"
    if not os.path.exists(checkpoint_path):
        os.mkdir(checkpoint_path)
    save_config = SaveConfig(save_steps=[10, 11, 12, 13, 14, 40, 50, 60, 70, 80])

    """
    Run the training for 2 epochs and save the parameters after every epoch.
    We expect the save steps in the range 0 to 14 (i.e. 10-14) to be written.
    """

    run_id_1 = "trial_" + datetime.now().strftime("%Y%m%d-%H%M%S%f")
    out_dir_1 = "newlogsRunTest/" + run_id_1
    hook = t_hook(
        out_dir=out_dir_1, save_config=save_config, include_collections=["weights", "gradients"]
    )
    assert has_training_ended(out_dir_1) is False
    run_mnist(
        hook=hook,
        num_steps_train=10,
        num_steps_eval=10,
        epochs=2,
        save_interval=1,
        save_path=checkpoint_path,
    )

    """
    Run the training again for 4 epochs and save the parameters after every epoch.
    We DO NOT expect the save steps 0 to 14 to be written.
    We expect to read steps 40, 50, 60, 70 and 80.
    """
    run_id_2 = "trial_" + datetime.now().strftime("%Y%m%d-%H%M%S%f")
    out_dir_2 = "newlogsRunTest/" + run_id_2
    hook = t_hook(
        out_dir=out_dir_2, save_config=save_config, include_collections=["weights", "gradients"]
    )
    assert has_training_ended(out_dir_2) is False
    run_mnist(
        hook=hook,
        num_steps_train=10,
        num_steps_eval=10,
        epochs=4,
        save_interval=1,
        save_path=checkpoint_path,
    )
    # Unset the environment variable before validation so that it does not affect
    # the other scripts in the pytest environment.
    del os.environ[CHECKPOINT_CONFIG_FILE_PATH_ENV_VAR]

    # Validation
    print("Created the trial with out_dir {0} for the first training".format(out_dir_1))
    tr = create_trial(out_dir_1)
    assert tr
    available_steps_1 = tr.available_steps()
    assert 40 not in available_steps_1
    assert 80 not in available_steps_1
    print(available_steps_1)

    print("Created the trial with out_dir {0} for the second training".format(out_dir_2))
    tr = create_trial(out_dir_2)
    assert tr
    available_steps_2 = tr.available_steps()
    assert 40 in available_steps_2
    assert 50 in available_steps_2
    assert 60 in available_steps_2
    assert 70 in available_steps_2
    assert 80 in available_steps_2
    assert 0 not in available_steps_2
    assert 10 not in available_steps_2
    assert 11 not in available_steps_2
    assert 12 not in available_steps_2
    print(available_steps_2)

    print("Cleaning up.")
    shutil.rmtree(os.path.dirname(out_dir_1))
    shutil.rmtree(checkpoint_path)
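
The step expectations in the assertions follow from the save configuration and the lengths of the two runs, assuming the hook advances the global step roughly once per batch (train and eval alike); the numbers below are illustrative, not asserted by the test:

# Illustrative step accounting for the two runs above (assumed step semantics):
steps_per_epoch = 10 + 10                # num_steps_train + num_steps_eval
first_run = 2 * steps_per_epoch          # global steps ~0-39: only save steps 10-14 are hit
second_run_start = first_run             # approximately; resumed from the saved tornasole state
second_run = 4 * steps_per_epoch         # global steps ~40-119: save steps 40-80 are hit
print(first_run, second_run_start, second_run)  # 40 40 80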

tornasole/core/access_layer/utils.py

Lines changed: 8 additions & 1 deletion
@@ -5,14 +5,21 @@
 from tornasole.core.utils import is_s3, get_region
 from tornasole.core.logger import get_logger
 from tornasole.core.access_layer.s3handler import S3Handler, ListRequest
+from tornasole.core.sagemaker_utils import is_sagemaker_job
 import asyncio
 import aioboto3

-END_OF_JOB_FILENAME = "END_OF_JOB.ts"
+END_OF_JOB_FILENAME = "training_job_end.ts"
 logger = get_logger()


 def training_has_ended(trial_prefix):
+    # Emit the end of training file only if the job is not running under SageMaker.
+    if is_sagemaker_job():
+        logger.info(
+            f"The end of training job file will not be written for jobs running under SageMaker."
+        )
+        return
     try:
         check_dir_exists(trial_prefix)
         # if path does not exist, then we don't need to write a file
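
A quick way to see the new behaviour in isolation is to stub out is_sagemaker_job; the snippet below is a hypothetical check written for this note, not part of the commit, and it assumes the end-of-job marker is written directly under the trial prefix (tmp_path is the standard pytest fixture):

from unittest import mock

import tornasole.core.access_layer.utils as al_utils


def test_no_end_of_job_file_under_sagemaker(tmp_path):
    # When is_sagemaker_job() reports True, training_has_ended() should return
    # early and never create the training_job_end.ts marker file.
    with mock.patch.object(al_utils, "is_sagemaker_job", return_value=True):
        al_utils.training_has_ended(str(tmp_path))
    assert not (tmp_path / "training_job_end.ts").exists()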

tornasole/core/config_constants.py

Lines changed: 8 additions & 0 deletions
@@ -14,3 +14,11 @@
 TORNASOLE_CONFIG_SAVE_ALL_KEY = "save_all"
 DEFAULT_SAGEMAKER_TORNASOLE_PATH = "/opt/ml/output/tensors"
 TORNASOLE_DEFAULT_COLLECTIONS_FILE_NAME = "worker_0_collections.json"
+CHECKPOINT_CONFIG_FILE_PATH_ENV_VAR = "CHECKPOINT_CONFIG_FILE_PATH"
+CHECKPOINT_DIR_KEY = "LocalPath"
+DEFAULT_CHECKPOINT_CONFIG_FILE = "/opt/ml/input/config/checkpointconfig.json"
+TORNASOLE_META_DATA_FILE = "TornasoleMetadata.json"
+LATEST_GLOBAL_STEP_SEEN = "latest-global-step-seen"
+LATEST_GLOBAL_STEP_SAVED = "latest-global-step-saved"
+LATEST_MODE_STEP = "latest-mode-step"
+TRAINING_RUN = "training-run"
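
The LATEST_* and TRAINING_RUN keys name the fields of the tornasole state that the hook persists (see the hook.py changes below), and TORNASOLE_META_DATA_FILE looks like the file that state lives in. A rough sketch of reading and writing such a metadata file, assuming it is plain JSON inside the checkpoint directory (the real StateStore is not included in this diff):

import json
import os

TORNASOLE_META_DATA_FILE = "TornasoleMetadata.json"


def write_state(checkpoint_dir, state):
    # 'state' is a dict keyed by the TRAINING_RUN / LATEST_* constants above.
    with open(os.path.join(checkpoint_dir, TORNASOLE_META_DATA_FILE), "w") as f:
        json.dump(state, f)


def read_state(checkpoint_dir):
    # Returns the previously saved state, or None on a fresh start.
    path = os.path.join(checkpoint_dir, TORNASOLE_META_DATA_FILE)
    if not os.path.exists(path):
        return None
    with open(path) as f:
        return json.load(f)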

tornasole/core/hook.py

Lines changed: 43 additions & 1 deletion
@@ -21,6 +21,14 @@
 from tornasole.core.logger import get_logger
 from tornasole.core.reductions import get_reduction_tensor_name
 from tornasole.core.writer import FileWriter
+from tornasole.core.state_store import StateStore
+from tornasole.core.config_constants import (
+    TRAINING_RUN,
+    LATEST_GLOBAL_STEP_SAVED,
+    LATEST_GLOBAL_STEP_SEEN,
+    LATEST_MODE_STEP,
+)
+

 logger = get_logger()

@@ -126,13 +134,33 @@ def __init__(
         self.prepared_collections = False
         self.tensor_to_collections = {}
         self.step = init_step
+        self.last_saved_step = None
         self.mode = ModeKeys.GLOBAL
         self.mode_steps = {ModeKeys.GLOBAL: init_step}
         self.writer = None
         self.tb_writers = {}
         self.logger.info("Saving to {}".format(self.out_dir))
         atexit.register(self._cleanup)

+        # Check if there is any last saved tornasole state and, if so, initialize the hook from it.
+        self.training_run = 0
+        self._initialize_to_last_saved_state()
+
+    def _initialize_to_last_saved_state(self):
+        self.state_store = StateStore()
+        last_tornasole_state = self.state_store.get_last_saved_tornasole_state()
+        if last_tornasole_state is not None:
+            self.last_saved_step = last_tornasole_state[LATEST_GLOBAL_STEP_SAVED]
+            self.init_step = last_tornasole_state[LATEST_GLOBAL_STEP_SEEN]
+            self.training_run = 1 + last_tornasole_state[TRAINING_RUN]
+            for (mode, step) in last_tornasole_state[LATEST_MODE_STEP].items():
+                self.mode_steps[ModeKeys[mode]] = step
+            self.mode_steps[ModeKeys.GLOBAL] = self.init_step
+            self.step = self.init_step
+            self.logger.info(
+                f"Initialized the hook with the last saved state: last_saved_step={self.last_saved_step}, init_step={self.init_step}, step={self.step}, mode_steps={str(self.mode_steps)}"
+            )
+
     def __repr__(self):
         return (
             f"<{self.__class__.__module__}.{self.__class__.__name__} object at {hex(id(self))}>:(\n"
@@ -309,10 +337,25 @@ def _cleanup(self):
         training_has_ended(self.out_dir)

     def _increment_step(self):
+        # Update the last tornasole state with the last step number that was saved or seen.
+        self._write_tornasole_state()
+
         self.step += 1
         self.mode_steps[self.mode] += 1
         self._collections_to_save_for_step = None

+    def _write_tornasole_state(self):
+        if self.state_store.is_checkpoint_updated():
+            current_tornasole_state = dict()
+            current_tornasole_state[TRAINING_RUN] = self.training_run
+            current_tornasole_state[LATEST_GLOBAL_STEP_SAVED] = self.last_saved_step
+            current_tornasole_state[LATEST_GLOBAL_STEP_SEEN] = self.step
+            mode_step = dict()
+            for (mode, step) in self.mode_steps.items():
+                mode_step[mode.name] = step
+            current_tornasole_state[LATEST_MODE_STEP] = mode_step
+            self.state_store.update_tornasole_state(current_tornasole_state)
+
     def set_mode(self, mode):
         # train
         if mode in ALLOWED_MODES:
@@ -521,7 +564,6 @@ def __init__(
             include_collections=include_collections,
             save_all=save_all,
         )
-        self.last_saved_step = None
         self.exported_collections = False
         self.data_type_name = data_type_name
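
To see what _initialize_to_last_saved_state does in isolation, here is a stripped-down sketch of the same bookkeeping with a stand-in ModeKeys enum and illustrative numbers (not the real hook or StateStore):

from enum import Enum


class ModeKeys(Enum):  # stand-in for tornasole's ModeKeys
    GLOBAL = 0
    TRAIN = 1
    EVAL = 2


def restore(last_state):
    # Mirrors _initialize_to_last_saved_state: resume the step counters from the
    # previously saved tornasole state instead of starting again from zero.
    init_step = last_state["latest-global-step-seen"]
    mode_steps = {ModeKeys[m]: s for m, s in last_state["latest-mode-step"].items()}
    mode_steps[ModeKeys.GLOBAL] = init_step
    return {
        "training_run": last_state["training-run"] + 1,
        "last_saved_step": last_state["latest-global-step-saved"],
        "step": init_step,
        "mode_steps": mode_steps,
    }


resumed = restore(
    {
        "training-run": 0,
        "latest-global-step-saved": 14,
        "latest-global-step-seen": 39,
        "latest-mode-step": {"TRAIN": 19, "EVAL": 19},
    }
)
print(resumed["step"], resumed["training_run"])  # 39 1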
