fix: Disable debugger when checkpointing is enabled with distributed training #2264

16 changes: 15 additions & 1 deletion src/sagemaker/estimator.py
@@ -2219,7 +2219,21 @@ def _validate_and_set_debugger_configs(self):
):
self.debugger_hook_config = DebuggerHookConfig(s3_output_path=self.output_path)
elif not self.debugger_hook_config:
Member:
Just to understand: so it's possible for self.debugger_hook_config to be None or False, and you're intentionally changing it to False to disable it?

Contributor (Author):
The default value of self.debugger_hook_config is None.
The if-statement above checks whether the debugger is supported in the region in which it is invoked; if not, it should be set to False instead of None, in my opinion.

self.debugger_hook_config = None
# set hook config to False if _region_supports_debugger is False
self.debugger_hook_config = False
Contributor:
Why are we changing this variable to False? Is this intentional? If so, why?

Contributor (Author) @NihalHarish, Apr 1, 2021:
The intent of the line above is to disable the debug hook config; False is a more explicit indicator of that than None.

Contributor @ajaykarpur, Apr 2, 2021:
None is already a falsy value in Python.

It doesn't make sense to change this to a boolean, since the debugger_hook_config variable can hold a DebuggerHookConfig type. You can simply check if debugger_hook_config is None instead of checking if it's False.

Contributor (Author) @NihalHarish, Apr 2, 2021:
The documentation, however, indicates to customers that they need to pass False to opt out of the debugger.

None is the default value of this variable before we reach this code-block.

The debugger is initialized on behalf of the customer if it is None.

I'm making sure that when we check the value after the _validate function call, it clearly indicates that we want the debugger off, rather than None, which might come to mean that it has not yet been initialized or explicitly set by the customer.

Contributor:
Yes, you're right:

elif self.debugger_hook_config is None and fw._region_supports_debugger(
self.sagemaker_session.boto_session.region_name
):
# Set defaults for debugging.
self.debugger_hook_config = DebuggerHookConfig(s3_output_path=self.output_path)

Thank you for clarifying.
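
For reference, the three states discussed in this thread can be summarized in a small sketch; the helper name describe_debugger_state is hypothetical and not part of the SDK:

from sagemaker.debugger import DebuggerHookConfig

def describe_debugger_state(debugger_hook_config):
    # False: the debugger was explicitly disabled (by the customer or by validation)
    if debugger_hook_config is False:
        return "disabled"
    # None: nothing has been initialized or explicitly set yet
    if debugger_hook_config is None:
        return "unset"
    # Otherwise a DebuggerHookConfig: defaults were filled in on the customer's behalf
    if isinstance(debugger_hook_config, DebuggerHookConfig):
        return "enabled"
    raise TypeError("Unexpected value for debugger_hook_config")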


# Disable debugger if checkpointing is enabled by the customer
if self.checkpoint_s3_uri and self.checkpoint_local_path and self.debugger_hook_config:
if self._framework_name in {"mxnet", "pytorch", "tensorflow"}:
if self.instance_count > 1 or (
hasattr(self, "distribution")
and self.distribution is not None # pylint: disable=no-member
):
logger.info(
"SMDebug Does Not Currently Support \
Distributed Training Jobs With Checkpointing Enabled"
)
self.debugger_hook_config = False

def _stage_user_code_in_s3(self):
"""Upload the user training script to s3 and return the location.
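Read on its own, the new guard reduces to a single predicate over the estimator's attributes. A rough sketch under simplified assumptions (the helper _should_disable_debugger and the SimpleNamespace stand-in are illustrative, not part of the SDK):

from types import SimpleNamespace

def _should_disable_debugger(estimator):
    # Checkpointing must be fully configured (both the S3 URI and the local path).
    checkpointing = bool(estimator.checkpoint_s3_uri and estimator.checkpoint_local_path)
    # Only the frameworks named in the new check are affected.
    supported_framework = estimator._framework_name in {"mxnet", "pytorch", "tensorflow"}
    # Distributed training: more than one instance, or an explicit distribution config.
    distributed = estimator.instance_count > 1 or getattr(estimator, "distribution", None) is not None
    return bool(estimator.debugger_hook_config) and checkpointing and supported_framework and distributed

# Example: two instances with checkpointing enabled should trip the guard.
example = SimpleNamespace(
    checkpoint_s3_uri="s3://my-bucket/checkpoints",
    checkpoint_local_path="/opt/ml/checkpoints",
    debugger_hook_config=True,  # stand-in for a DebuggerHookConfig
    _framework_name="pytorch",
    instance_count=2,
)
assert _should_disable_debugger(example)
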
7 changes: 1 addition & 6 deletions src/sagemaker/tensorflow/estimator.py
@@ -18,7 +18,6 @@
from packaging import version

from sagemaker import image_uris, s3, utils
from sagemaker.debugger import DebuggerHookConfig
from sagemaker.deprecations import renamed_kwargs
from sagemaker.estimator import Framework
import sagemaker.fw_utils as fw
@@ -347,6 +346,7 @@ def _validate_and_set_debugger_configs(self):

Else, set default HookConfig
"""
super(TensorFlow, self)._validate_and_set_debugger_configs()
ps_enabled = "parameter_server" in self.distribution and self.distribution[
"parameter_server"
].get("enabled", False)
@@ -358,11 +358,6 @@ def _validate_and_set_debugger_configs(self):
)
self.debugger_hook_config = None
self.debugger_rule_configs = None
elif self.debugger_hook_config is None and fw._region_supports_debugger(
self.sagemaker_session.boto_session.region_name
):
# Set defaults for debugging.
self.debugger_hook_config = DebuggerHookConfig(s3_output_path=self.output_path)

def transformer(
self,
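The removed elif duplicated the default-setting logic that already lives in the base estimator; with the added super() call, the TensorFlow override keeps only its parameter-server-specific behavior. A toy sketch of that call order, using illustrative class names that are not the SDK's:

class BaseEstimatorSketch:
    def _validate_and_set_debugger_configs(self):
        # In the real base class this covers region support, the default
        # DebuggerHookConfig, and the new checkpointing guard.
        self.debugger_hook_config = object()  # stand-in for a real DebuggerHookConfig

class TensorFlowSketch(BaseEstimatorSketch):
    def __init__(self, parameter_server_enabled):
        self.parameter_server_enabled = parameter_server_enabled
        self.debugger_rule_configs = None

    def _validate_and_set_debugger_configs(self):
        super()._validate_and_set_debugger_configs()  # defaults now come from the base class
        if self.parameter_server_enabled:  # subclass keeps only its parameter-server rule
            self.debugger_hook_config = None
            self.debugger_rule_configs = None
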
112 changes: 112 additions & 0 deletions tests/integ/test_debugger.py
@@ -24,6 +24,9 @@
TensorBoardOutputConfig,
)
from sagemaker.mxnet.estimator import MXNet
from sagemaker.pytorch.estimator import PyTorch
from sagemaker.tensorflow.estimator import TensorFlow
from sagemaker.xgboost.estimator import XGBoost
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
from tests.integ.retry import retries
from tests.integ.timeout import timeout
@@ -351,6 +354,115 @@ def test_mxnet_with_debugger_hook_config(
_wait_and_assert_that_no_rule_jobs_errored(training_job=mx.latest_training_job)


def test_debug_hook_disabled_with_checkpointing(
sagemaker_session,
mxnet_training_latest_version,
mxnet_training_latest_py_version,
cpu_instance_type,
):
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
s3_output_path = os.path.join(
"s3://", sagemaker_session.default_bucket(), str(uuid.uuid4())
)
debugger_hook_config = DebuggerHookConfig(
s3_output_path=os.path.join(s3_output_path, "tensors")
)

script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist_gluon.py")

# Estimator with checkpointing enabled
mx = MXNet(
entry_point=script_path,
role="SageMakerRole",
framework_version=mxnet_training_latest_version,
py_version=mxnet_training_latest_py_version,
instance_count=1,
instance_type=cpu_instance_type,
sagemaker_session=sagemaker_session,
debugger_hook_config=debugger_hook_config,
checkpoint_local_path="/opt/ml/checkpoints",
checkpoint_s3_uri=os.path.join(s3_output_path, "checkpoints"),
)
mx._prepare_for_training()

# Debug Hook should be enabled
assert mx.debugger_hook_config is not None

# Estimator with checkpointing enabled and Instance Count>1
mx = MXNet(
entry_point=script_path,
role="SageMakerRole",
framework_version=mxnet_training_latest_version,
py_version=mxnet_training_latest_py_version,
instance_count=2,
instance_type=cpu_instance_type,
sagemaker_session=sagemaker_session,
debugger_hook_config=debugger_hook_config,
checkpoint_local_path="/opt/ml/checkpoints",
checkpoint_s3_uri=os.path.join(s3_output_path, "checkpoints"),
)
mx._prepare_for_training()
# Debug Hook should be disabled
assert mx.debugger_hook_config is False

# Estimator with checkpointing enabled and SMDataParallel Enabled
pt = PyTorch(
base_job_name="pytorch-smdataparallel-mnist",
entry_point=script_path,
role="SageMakerRole",
framework_version="1.8.0",
py_version="py36",
instance_count=1,
# For training with p3dn instance use - ml.p3dn.24xlarge, with p4dn instance use - ml.p4d.24xlarge
instance_type="ml.p3.16xlarge",
sagemaker_session=sagemaker_session,
# Training using SMDataParallel Distributed Training Framework
distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
checkpoint_local_path="/opt/ml/checkpoints",
checkpoint_s3_uri=os.path.join(s3_output_path, "checkpoints"),
)
pt._prepare_for_training()
# Debug Hook should be disabled
assert pt.debugger_hook_config is False

# Estimator with checkpointing enabled and SMModelParallel Enabled
tf = TensorFlow(
base_job_name="tf-smdataparallel-mnist",
entry_point=script_path,
role="SageMakerRole",
framework_version="2.4.1",
py_version="py36",
instance_count=1,
# For training with p3dn instance use - ml.p3dn.24xlarge, with p4dn instance use - ml.p4d.24xlarge
instance_type="ml.p3.16xlarge",
sagemaker_session=sagemaker_session,
# Training using the SMModelParallel Distributed Training Framework
distribution={"smdistributed": {"modelparallel": {"enabled": True}}},
checkpoint_local_path="/opt/ml/checkpoints",
checkpoint_s3_uri=os.path.join(s3_output_path, "checkpoints"),
)
tf._prepare_for_training()
# Debug Hook should be disabled
assert tf.debugger_hook_config is False

# Estimator with checkpointing enabled with Xgboost Estimator
xg = XGBoost(
base_job_name="test_xgboost",
entry_point=script_path,
role="SageMakerRole",
framework_version="1.2-1",
py_version="py3",
instance_count=2,
# For training with p3dn instance use - ml.p3dn.24xlarge, with p4dn instance use - ml.p4d.24xlarge
instance_type="ml.p3.16xlarge",
sagemaker_session=sagemaker_session,
# No distribution is configured, and XGBoost is not covered by the new check
)
xg._prepare_for_training()
# Debug Hook should be enabled
assert xg.debugger_hook_config is not None


def test_mxnet_with_rules_and_debugger_hook_config(
sagemaker_session,
mxnet_training_latest_version,