Skip to content

feat: Add support for actions in debugger rules. #2047

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Dec 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def read_version():
"numpy>=1.9.0",
"protobuf>=3.1",
"protobuf3-to-dict>=0.1.5",
"smdebug_rulesconfig==1.0.0",
"smdebug_rulesconfig==1.0.1",
"importlib-metadata>=1.4.0",
"packaging>=20.0",
]
Expand Down
28 changes: 28 additions & 0 deletions src/sagemaker/debugger/debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def __init__(
volume_size_in_gb,
rule_parameters,
collections_to_save,
actions=None,
):
"""Configure the debugging rules using the following classmethods.

Expand All @@ -170,6 +171,7 @@ def __init__(
rule_parameters,
)
self.collection_configs = collections_to_save
self.actions = actions

@classmethod
def sagemaker(
Expand All @@ -181,6 +183,7 @@ def sagemaker(
other_trials_s3_input_paths=None,
rule_parameters=None,
collections_to_save=None,
actions=None,
):
"""Initialize a ``Rule`` object for a *built-in* debugging rule.

Expand Down Expand Up @@ -268,6 +271,9 @@ def sagemaker(
"""
)

if actions is not None and not rule_configs.is_valid_action_object(actions):
raise RuntimeError("""`actions` must be of type `Action` or `ActionList`!""")

if other_trials_s3_input_paths is not None:
for index, s3_input_path in enumerate(other_trials_s3_input_paths):
merged_rule_params["other_trial_{}".format(str(index))] = s3_input_path
Expand Down Expand Up @@ -298,6 +304,7 @@ def sagemaker(
volume_size_in_gb=None,
rule_parameters=merged_rule_params,
collections_to_save=collections_to_save or base_config_collections,
actions=actions,
)

@classmethod
Expand All @@ -314,6 +321,7 @@ def custom(
other_trials_s3_input_paths=None,
rule_parameters=None,
collections_to_save=None,
actions=None,
):
"""Initialize a ``Rule`` object for a *custom* debugging rule.

Expand Down Expand Up @@ -352,6 +360,9 @@ def custom(
:class:`~sagemaker.debugger.Rule`: The instance of the custom rule.

"""
if actions is not None and not rule_configs.is_valid_action_object(actions):
raise RuntimeError("""`actions` must be of type `Action` or `ActionList`!""")

merged_rule_params = cls._set_rule_parameters(
source, rule_to_invoke, other_trials_s3_input_paths, rule_parameters
)
Expand All @@ -365,8 +376,25 @@ def custom(
volume_size_in_gb=volume_size_in_gb,
rule_parameters=merged_rule_params,
collections_to_save=collections_to_save or [],
actions=actions,
)

def prepare_actions(self, training_job_name):
"""Prepare actions for Debugger Rule.

Args:
training_job_name (str): The training job name. To be set as the default training job
prefix for the StopTraining action if it is specified.
"""
if self.actions is None:
# user cannot manually specify action_json in rule_parameters for actions.
self.rule_parameters.pop("action_json", None)
return

self.actions.update_training_job_prefix_if_not_specified(training_job_name)
action_params = {"action_json": self.actions.serialize()}
self.rule_parameters.update(action_params)

@staticmethod
def _set_rule_parameters(source, rule_to_invoke, other_trials_s3_input_paths, rule_parameters):
"""Set rule parameters for Debugger Rule.
Expand Down
1 change: 1 addition & 0 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,7 @@ def _prepare_debugger_rules(self):
for rule in self.debugger_rules:
self._set_default_rule_config(rule)
self._set_source_s3_uri(rule)
rule.prepare_actions(self._current_job_name)
debugger_rule_configs.append(rule.to_debugger_rule_config_dict())
return debugger_rule_configs

Expand Down
134 changes: 133 additions & 1 deletion tests/integ/test_debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,15 @@
# TODO-reinvent-2019: test get_debugger_artifacts_path and get_tensorboard_artifacts_path


@pytest.fixture
def actions():
return rule_configs.ActionList(
rule_configs.StopTraining(),
rule_configs.Email("[email protected]"),
rule_configs.SMS("+01234567890"),
)


def test_mxnet_with_rules(
sagemaker_session,
mxnet_training_latest_version,
Expand Down Expand Up @@ -125,6 +134,74 @@ def test_mxnet_with_rules(
_wait_and_assert_that_no_rule_jobs_errored(training_job=mx.latest_training_job)


def test_mxnet_with_rules_and_actions(
sagemaker_session,
mxnet_training_latest_version,
mxnet_training_latest_py_version,
cpu_instance_type,
actions,
):
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
rules = [
Rule.sagemaker(rule_configs.vanishing_gradient(), actions=actions),
Rule.sagemaker(
base_config=rule_configs.all_zero(),
rule_parameters={"tensor_regex": ".*"},
actions=actions,
),
Rule.sagemaker(rule_configs.loss_not_decreasing(), actions=actions),
]

script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist_gluon.py")
data_path = os.path.join(DATA_DIR, "mxnet_mnist")

mx = MXNet(
entry_point=script_path,
role="SageMakerRole",
framework_version=mxnet_training_latest_version,
py_version=mxnet_training_latest_py_version,
instance_count=1,
instance_type=cpu_instance_type,
sagemaker_session=sagemaker_session,
rules=rules,
)

train_input = mx.sagemaker_session.upload_data(
path=os.path.join(data_path, "train"), key_prefix="integ-test-data/mxnet_mnist/train"
)
test_input = mx.sagemaker_session.upload_data(
path=os.path.join(data_path, "test"), key_prefix="integ-test-data/mxnet_mnist/test"
)

mx.fit({"train": train_input, "test": test_input})

job_description = mx.latest_training_job.describe()

for index, rule in enumerate(rules):
assert (
job_description["DebugRuleConfigurations"][index]["RuleConfigurationName"]
== rule.name
)
assert (
job_description["DebugRuleConfigurations"][index]["RuleEvaluatorImage"]
== rule.image_uri
)
assert job_description["DebugRuleConfigurations"][index]["VolumeSizeInGB"] == 0
assert (
job_description["DebugRuleConfigurations"][index]["RuleParameters"][
"rule_to_invoke"
]
== rule.rule_parameters["rule_to_invoke"]
)

assert (
_get_rule_evaluation_statuses(job_description)
== mx.latest_training_job.rule_job_summary()
)

_wait_and_assert_that_no_rule_jobs_errored(training_job=mx.latest_training_job)


def test_mxnet_with_custom_rule(
sagemaker_session,
mxnet_training_latest_version,
Expand Down Expand Up @@ -178,6 +255,60 @@ def test_mxnet_with_custom_rule(
_wait_and_assert_that_no_rule_jobs_errored(training_job=mx.latest_training_job)


def test_mxnet_with_custom_rule_and_actions(
sagemaker_session,
mxnet_training_latest_version,
mxnet_training_latest_py_version,
cpu_instance_type,
actions,
):
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
rules = [_get_custom_rule(sagemaker_session, actions)]

script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist_gluon.py")
data_path = os.path.join(DATA_DIR, "mxnet_mnist")

mx = MXNet(
entry_point=script_path,
role="SageMakerRole",
framework_version=mxnet_training_latest_version,
py_version=mxnet_training_latest_py_version,
instance_count=1,
instance_type=cpu_instance_type,
sagemaker_session=sagemaker_session,
rules=rules,
)

train_input = mx.sagemaker_session.upload_data(
path=os.path.join(data_path, "train"), key_prefix="integ-test-data/mxnet_mnist/train"
)
test_input = mx.sagemaker_session.upload_data(
path=os.path.join(data_path, "test"), key_prefix="integ-test-data/mxnet_mnist/test"
)

mx.fit({"train": train_input, "test": test_input})

job_description = mx.latest_training_job.describe()

for index, rule in enumerate(rules):
assert (
job_description["DebugRuleConfigurations"][index]["RuleConfigurationName"]
== rule.name
)
assert (
job_description["DebugRuleConfigurations"][index]["RuleEvaluatorImage"]
== rule.image_uri
)
assert job_description["DebugRuleConfigurations"][index]["VolumeSizeInGB"] == 30

assert (
_get_rule_evaluation_statuses(job_description)
== mx.latest_training_job.rule_job_summary()
)

_wait_and_assert_that_no_rule_jobs_errored(training_job=mx.latest_training_job)


def test_mxnet_with_debugger_hook_config(
sagemaker_session,
mxnet_training_latest_version,
Expand Down Expand Up @@ -514,7 +645,7 @@ def _get_rule_evaluation_statuses(job_description):
return debug_rule_eval_statuses + profiler_rule_eval_statuses


def _get_custom_rule(session):
def _get_custom_rule(session, actions=None):
script_path = os.path.join(DATA_DIR, "mxnet_mnist", "my_custom_rule.py")

return Rule.custom(
Expand All @@ -526,6 +657,7 @@ def _get_custom_rule(session):
image_uri=CUSTOM_RULE_REPO_WITH_PLACEHOLDERS.format(
CUSTOM_RULE_CONTAINERS_ACCOUNTS_MAP[session.boto_region_name], session.boto_region_name
),
actions=actions,
)


Expand Down
52 changes: 52 additions & 0 deletions tests/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,58 @@ def test_framework_with_only_debugger_rule(sagemaker_session):
}


def test_framework_with_debugger_rule_and_single_action(sagemaker_session):
stop_training_action = rule_configs.StopTraining()
f = DummyFramework(
entry_point=SCRIPT_PATH,
role=ROLE,
sagemaker_session=sagemaker_session,
instance_count=INSTANCE_COUNT,
instance_type=INSTANCE_TYPE,
rules=[Rule.sagemaker(rule_configs.stalled_training_rule(), actions=stop_training_action)],
)
f.fit("s3://mydata")
sagemaker_session.train.assert_called_once()
_, args = sagemaker_session.train.call_args
assert args["debugger_rule_configs"][0]["RuleParameters"] == {
"rule_to_invoke": "StalledTrainingRule",
"action_json": stop_training_action.serialize(),
}
assert stop_training_action.action_parameters["training_job_prefix"] == f._current_job_name
assert args["debugger_hook_config"] == {
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),
"CollectionConfigurations": [],
}


def test_framework_with_debugger_rule_and_multiple_actions(sagemaker_session):
action_list = rule_configs.ActionList(
rule_configs.StopTraining(),
rule_configs.Email("[email protected]"),
rule_configs.SMS("+1234567890"),
)
f = DummyFramework(
entry_point=SCRIPT_PATH,
role=ROLE,
sagemaker_session=sagemaker_session,
instance_count=INSTANCE_COUNT,
instance_type=INSTANCE_TYPE,
rules=[Rule.sagemaker(rule_configs.stalled_training_rule(), actions=action_list)],
)
f.fit("s3://mydata")
sagemaker_session.train.assert_called_once()
_, args = sagemaker_session.train.call_args
assert args["debugger_rule_configs"][0]["RuleParameters"] == {
"rule_to_invoke": "StalledTrainingRule",
"action_json": action_list.serialize(),
}
assert action_list.actions[0].action_parameters["training_job_prefix"] == f._current_job_name
assert args["debugger_hook_config"] == {
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),
"CollectionConfigurations": [],
}


def test_framework_with_only_debugger_hook_config(sagemaker_session):
hook_config = DebuggerHookConfig(
s3_output_path="s3://output", collection_configs=[CollectionConfig(name="weights")]
Expand Down