Skip to content

Commit 17ac9a5

Browse files
authored
Merge branch 'master' into checkdir
2 parents 09505eb + 39c6d5a commit 17ac9a5

File tree

5 files changed

+215
-2
lines changed

5 files changed

+215
-2
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def read_version():
3939
"numpy>=1.9.0",
4040
"protobuf>=3.1",
4141
"protobuf3-to-dict>=0.1.5",
42-
"smdebug_rulesconfig==1.0.0",
42+
"smdebug_rulesconfig==1.0.1",
4343
"importlib-metadata>=1.4.0",
4444
"packaging>=20.0",
4545
]

src/sagemaker/debugger/debugger.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ def __init__(
150150
volume_size_in_gb,
151151
rule_parameters,
152152
collections_to_save,
153+
actions=None,
153154
):
154155
"""Configure the debugging rules using the following classmethods.
155156
@@ -170,6 +171,7 @@ def __init__(
170171
rule_parameters,
171172
)
172173
self.collection_configs = collections_to_save
174+
self.actions = actions
173175

174176
@classmethod
175177
def sagemaker(
@@ -181,6 +183,7 @@ def sagemaker(
181183
other_trials_s3_input_paths=None,
182184
rule_parameters=None,
183185
collections_to_save=None,
186+
actions=None,
184187
):
185188
"""Initialize a ``Rule`` object for a *built-in* debugging rule.
186189
@@ -268,6 +271,9 @@ def sagemaker(
268271
"""
269272
)
270273

274+
if actions is not None and not rule_configs.is_valid_action_object(actions):
275+
raise RuntimeError("""`actions` must be of type `Action` or `ActionList`!""")
276+
271277
if other_trials_s3_input_paths is not None:
272278
for index, s3_input_path in enumerate(other_trials_s3_input_paths):
273279
merged_rule_params["other_trial_{}".format(str(index))] = s3_input_path
@@ -298,6 +304,7 @@ def sagemaker(
298304
volume_size_in_gb=None,
299305
rule_parameters=merged_rule_params,
300306
collections_to_save=collections_to_save or base_config_collections,
307+
actions=actions,
301308
)
302309

303310
@classmethod
@@ -314,6 +321,7 @@ def custom(
314321
other_trials_s3_input_paths=None,
315322
rule_parameters=None,
316323
collections_to_save=None,
324+
actions=None,
317325
):
318326
"""Initialize a ``Rule`` object for a *custom* debugging rule.
319327
@@ -352,6 +360,9 @@ def custom(
352360
:class:`~sagemaker.debugger.Rule`: The instance of the custom rule.
353361
354362
"""
363+
if actions is not None and not rule_configs.is_valid_action_object(actions):
364+
raise RuntimeError("""`actions` must be of type `Action` or `ActionList`!""")
365+
355366
merged_rule_params = cls._set_rule_parameters(
356367
source, rule_to_invoke, other_trials_s3_input_paths, rule_parameters
357368
)
@@ -365,8 +376,25 @@ def custom(
365376
volume_size_in_gb=volume_size_in_gb,
366377
rule_parameters=merged_rule_params,
367378
collections_to_save=collections_to_save or [],
379+
actions=actions,
368380
)
369381

382+
def prepare_actions(self, training_job_name):
383+
"""Prepare actions for Debugger Rule.
384+
385+
Args:
386+
training_job_name (str): The training job name. To be set as the default training job
387+
prefix for the StopTraining action if it is specified.
388+
"""
389+
if self.actions is None:
390+
# user cannot manually specify action_json in rule_parameters for actions.
391+
self.rule_parameters.pop("action_json", None)
392+
return
393+
394+
self.actions.update_training_job_prefix_if_not_specified(training_job_name)
395+
action_params = {"action_json": self.actions.serialize()}
396+
self.rule_parameters.update(action_params)
397+
370398
@staticmethod
371399
def _set_rule_parameters(source, rule_to_invoke, other_trials_s3_input_paths, rule_parameters):
372400
"""Set rule parameters for Debugger Rule.

src/sagemaker/estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,7 @@ def _prepare_debugger_rules(self):
473473
for rule in self.debugger_rules:
474474
self._set_default_rule_config(rule)
475475
self._set_source_s3_uri(rule)
476+
rule.prepare_actions(self._current_job_name)
476477
debugger_rule_configs.append(rule.to_debugger_rule_config_dict())
477478
return debugger_rule_configs
478479

tests/integ/test_debugger.py

Lines changed: 133 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,15 @@
6060
# TODO-reinvent-2019: test get_debugger_artifacts_path and get_tensorboard_artifacts_path
6161

6262

63+
@pytest.fixture
64+
def actions():
65+
return rule_configs.ActionList(
66+
rule_configs.StopTraining(),
67+
rule_configs.Email("[email protected]"),
68+
rule_configs.SMS("+01234567890"),
69+
)
70+
71+
6372
def test_mxnet_with_rules(
6473
sagemaker_session,
6574
mxnet_training_latest_version,
@@ -125,6 +134,74 @@ def test_mxnet_with_rules(
125134
_wait_and_assert_that_no_rule_jobs_errored(training_job=mx.latest_training_job)
126135

127136

137+
def test_mxnet_with_rules_and_actions(
138+
sagemaker_session,
139+
mxnet_training_latest_version,
140+
mxnet_training_latest_py_version,
141+
cpu_instance_type,
142+
actions,
143+
):
144+
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
145+
rules = [
146+
Rule.sagemaker(rule_configs.vanishing_gradient(), actions=actions),
147+
Rule.sagemaker(
148+
base_config=rule_configs.all_zero(),
149+
rule_parameters={"tensor_regex": ".*"},
150+
actions=actions,
151+
),
152+
Rule.sagemaker(rule_configs.loss_not_decreasing(), actions=actions),
153+
]
154+
155+
script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist_gluon.py")
156+
data_path = os.path.join(DATA_DIR, "mxnet_mnist")
157+
158+
mx = MXNet(
159+
entry_point=script_path,
160+
role="SageMakerRole",
161+
framework_version=mxnet_training_latest_version,
162+
py_version=mxnet_training_latest_py_version,
163+
instance_count=1,
164+
instance_type=cpu_instance_type,
165+
sagemaker_session=sagemaker_session,
166+
rules=rules,
167+
)
168+
169+
train_input = mx.sagemaker_session.upload_data(
170+
path=os.path.join(data_path, "train"), key_prefix="integ-test-data/mxnet_mnist/train"
171+
)
172+
test_input = mx.sagemaker_session.upload_data(
173+
path=os.path.join(data_path, "test"), key_prefix="integ-test-data/mxnet_mnist/test"
174+
)
175+
176+
mx.fit({"train": train_input, "test": test_input})
177+
178+
job_description = mx.latest_training_job.describe()
179+
180+
for index, rule in enumerate(rules):
181+
assert (
182+
job_description["DebugRuleConfigurations"][index]["RuleConfigurationName"]
183+
== rule.name
184+
)
185+
assert (
186+
job_description["DebugRuleConfigurations"][index]["RuleEvaluatorImage"]
187+
== rule.image_uri
188+
)
189+
assert job_description["DebugRuleConfigurations"][index]["VolumeSizeInGB"] == 0
190+
assert (
191+
job_description["DebugRuleConfigurations"][index]["RuleParameters"][
192+
"rule_to_invoke"
193+
]
194+
== rule.rule_parameters["rule_to_invoke"]
195+
)
196+
197+
assert (
198+
_get_rule_evaluation_statuses(job_description)
199+
== mx.latest_training_job.rule_job_summary()
200+
)
201+
202+
_wait_and_assert_that_no_rule_jobs_errored(training_job=mx.latest_training_job)
203+
204+
128205
def test_mxnet_with_custom_rule(
129206
sagemaker_session,
130207
mxnet_training_latest_version,
@@ -178,6 +255,60 @@ def test_mxnet_with_custom_rule(
178255
_wait_and_assert_that_no_rule_jobs_errored(training_job=mx.latest_training_job)
179256

180257

258+
def test_mxnet_with_custom_rule_and_actions(
259+
sagemaker_session,
260+
mxnet_training_latest_version,
261+
mxnet_training_latest_py_version,
262+
cpu_instance_type,
263+
actions,
264+
):
265+
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
266+
rules = [_get_custom_rule(sagemaker_session, actions)]
267+
268+
script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist_gluon.py")
269+
data_path = os.path.join(DATA_DIR, "mxnet_mnist")
270+
271+
mx = MXNet(
272+
entry_point=script_path,
273+
role="SageMakerRole",
274+
framework_version=mxnet_training_latest_version,
275+
py_version=mxnet_training_latest_py_version,
276+
instance_count=1,
277+
instance_type=cpu_instance_type,
278+
sagemaker_session=sagemaker_session,
279+
rules=rules,
280+
)
281+
282+
train_input = mx.sagemaker_session.upload_data(
283+
path=os.path.join(data_path, "train"), key_prefix="integ-test-data/mxnet_mnist/train"
284+
)
285+
test_input = mx.sagemaker_session.upload_data(
286+
path=os.path.join(data_path, "test"), key_prefix="integ-test-data/mxnet_mnist/test"
287+
)
288+
289+
mx.fit({"train": train_input, "test": test_input})
290+
291+
job_description = mx.latest_training_job.describe()
292+
293+
for index, rule in enumerate(rules):
294+
assert (
295+
job_description["DebugRuleConfigurations"][index]["RuleConfigurationName"]
296+
== rule.name
297+
)
298+
assert (
299+
job_description["DebugRuleConfigurations"][index]["RuleEvaluatorImage"]
300+
== rule.image_uri
301+
)
302+
assert job_description["DebugRuleConfigurations"][index]["VolumeSizeInGB"] == 30
303+
304+
assert (
305+
_get_rule_evaluation_statuses(job_description)
306+
== mx.latest_training_job.rule_job_summary()
307+
)
308+
309+
_wait_and_assert_that_no_rule_jobs_errored(training_job=mx.latest_training_job)
310+
311+
181312
def test_mxnet_with_debugger_hook_config(
182313
sagemaker_session,
183314
mxnet_training_latest_version,
@@ -514,7 +645,7 @@ def _get_rule_evaluation_statuses(job_description):
514645
return debug_rule_eval_statuses + profiler_rule_eval_statuses
515646

516647

517-
def _get_custom_rule(session):
648+
def _get_custom_rule(session, actions=None):
518649
script_path = os.path.join(DATA_DIR, "mxnet_mnist", "my_custom_rule.py")
519650

520651
return Rule.custom(
@@ -526,6 +657,7 @@ def _get_custom_rule(session):
526657
image_uri=CUSTOM_RULE_REPO_WITH_PLACEHOLDERS.format(
527658
CUSTOM_RULE_CONTAINERS_ACCOUNTS_MAP[session.boto_region_name], session.boto_region_name
528659
),
660+
actions=actions,
529661
)
530662

531663

tests/unit/test_estimator.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,58 @@ def test_framework_with_only_debugger_rule(sagemaker_session):
385385
}
386386

387387

388+
def test_framework_with_debugger_rule_and_single_action(sagemaker_session):
389+
stop_training_action = rule_configs.StopTraining()
390+
f = DummyFramework(
391+
entry_point=SCRIPT_PATH,
392+
role=ROLE,
393+
sagemaker_session=sagemaker_session,
394+
instance_count=INSTANCE_COUNT,
395+
instance_type=INSTANCE_TYPE,
396+
rules=[Rule.sagemaker(rule_configs.stalled_training_rule(), actions=stop_training_action)],
397+
)
398+
f.fit("s3://mydata")
399+
sagemaker_session.train.assert_called_once()
400+
_, args = sagemaker_session.train.call_args
401+
assert args["debugger_rule_configs"][0]["RuleParameters"] == {
402+
"rule_to_invoke": "StalledTrainingRule",
403+
"action_json": stop_training_action.serialize(),
404+
}
405+
assert stop_training_action.action_parameters["training_job_prefix"] == f._current_job_name
406+
assert args["debugger_hook_config"] == {
407+
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),
408+
"CollectionConfigurations": [],
409+
}
410+
411+
412+
def test_framework_with_debugger_rule_and_multiple_actions(sagemaker_session):
413+
action_list = rule_configs.ActionList(
414+
rule_configs.StopTraining(),
415+
rule_configs.Email("[email protected]"),
416+
rule_configs.SMS("+1234567890"),
417+
)
418+
f = DummyFramework(
419+
entry_point=SCRIPT_PATH,
420+
role=ROLE,
421+
sagemaker_session=sagemaker_session,
422+
instance_count=INSTANCE_COUNT,
423+
instance_type=INSTANCE_TYPE,
424+
rules=[Rule.sagemaker(rule_configs.stalled_training_rule(), actions=action_list)],
425+
)
426+
f.fit("s3://mydata")
427+
sagemaker_session.train.assert_called_once()
428+
_, args = sagemaker_session.train.call_args
429+
assert args["debugger_rule_configs"][0]["RuleParameters"] == {
430+
"rule_to_invoke": "StalledTrainingRule",
431+
"action_json": action_list.serialize(),
432+
}
433+
assert action_list.actions[0].action_parameters["training_job_prefix"] == f._current_job_name
434+
assert args["debugger_hook_config"] == {
435+
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),
436+
"CollectionConfigurations": [],
437+
}
438+
439+
388440
def test_framework_with_only_debugger_hook_config(sagemaker_session):
389441
hook_config = DebuggerHookConfig(
390442
s3_output_path="s3://output", collection_configs=[CollectionConfig(name="weights")]

0 commit comments

Comments
 (0)