fix: Disable debugger when checkpointing is enabled with distributed training #2264
Changes from 2 commits
@@ -2219,7 +2219,20 @@ def _validate_and_set_debugger_configs(self):
         ):
             self.debugger_hook_config = DebuggerHookConfig(s3_output_path=self.output_path)
         elif not self.debugger_hook_config:
-            self.debugger_hook_config = None
+            self.debugger_hook_config = False
Review thread on this line:

Reviewer: Why are we changing this variable to `False`?

Author: The intent of the line above is to disable the debug hook config.

Reviewer: It doesn't make sense to change this to a boolean, since the …

Author: Documentation however indicates to customers that they need to pass `debugger_hook_config=False`. Debugger is initialized on behalf of the customer; it is … I'm making sure that when we check the value of … post the …

Reviewer: Yes, you're right: sagemaker-python-sdk/src/sagemaker/tensorflow/estimator.py, Lines 361 to 365 in b66cb98. Thank you for clarifying.
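For reference, a minimal sketch of the documented customer-side opt-out this discussion refers to. The script name, framework version, and instance type below are illustrative placeholders, not values from this PR:

# Illustrative sketch: per the SageMaker docs, a customer disables Debugger by
# passing debugger_hook_config=False to the estimator constructor.
from sagemaker.mxnet.estimator import MXNet

estimator = MXNet(
    entry_point="train.py",          # hypothetical training script
    role="SageMakerRole",
    framework_version="1.8.0",
    py_version="py37",
    instance_count=1,
    instance_type="ml.m5.xlarge",
    debugger_hook_config=False,      # explicit opt-out, as documented
)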
+
+        # Disable debugger if checkpointing is enabled by the customer
+        _should_disable_debugger = False
+        if self.checkpoint_s3_uri and self.checkpoint_local_path and self.debugger_hook_config:
+            if self.instance_count > 1:
+                _should_disable_debugger = True
+            if hasattr(self, "distribution") and self.distribution is not None:
+                _should_disable_debugger = True
+        if _should_disable_debugger:
+            logger.info(
+                "SMDebug Does Not Currently Support Distributed Training Jobs With Checkpointing Enabled"
+            )
+            self.debugger_hook_config = False
 
     def _stage_user_code_in_s3(self):
         """Upload the user training script to s3 and return the location.
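For readers skimming the hunk above, here is a minimal standalone sketch of the rule it implements; the helper function and its plain arguments are illustrative, not the SDK's API. Debugger is turned off only when checkpointing is configured, the debug hook would otherwise be active, and the job is distributed, either via instance_count > 1 or an explicit distribution setting:

# Illustrative sketch of the decision implemented in _validate_and_set_debugger_configs;
# the helper below is hypothetical and not part of the SDK.
def should_disable_debugger(checkpoint_s3_uri, checkpoint_local_path,
                            debugger_hook_config, instance_count, distribution):
    checkpointing = bool(checkpoint_s3_uri) and bool(checkpoint_local_path)
    distributed = instance_count > 1 or distribution is not None
    # Only disable when the debugger would otherwise be active.
    return checkpointing and bool(debugger_hook_config) and distributed

# Single-instance job with checkpointing: debugger stays on.
assert not should_disable_debugger("s3://bkt/ckpt", "/opt/ml/checkpoints", True, 1, None)
# Multi-instance or smdistributed job with checkpointing: debugger is disabled.
assert should_disable_debugger("s3://bkt/ckpt", "/opt/ml/checkpoints", True, 2, None)
assert should_disable_debugger(
    "s3://bkt/ckpt", "/opt/ml/checkpoints", True, 1,
    {"smdistributed": {"dataparallel": {"enabled": True}}},
)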
@@ -24,6 +24,7 @@
     TensorBoardOutputConfig,
 )
 from sagemaker.mxnet.estimator import MXNet
+from sagemaker.pytorch.estimator import PyTorch
 from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
 from tests.integ.retry import retries
 from tests.integ.timeout import timeout
@@ -351,6 +352,77 @@ def test_mxnet_with_debugger_hook_config(
         _wait_and_assert_that_no_rule_jobs_errored(training_job=mx.latest_training_job)
 
 
+def test_debug_hook_disabled_with_checkpointing(
+    sagemaker_session,
+    mxnet_training_latest_version,
+    mxnet_training_latest_py_version,
+    cpu_instance_type,
+):
+    with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
+        s3_output_path = os.path.join(
+            "s3://", sagemaker_session.default_bucket(), str(uuid.uuid4())
+        )
+        debugger_hook_config = DebuggerHookConfig(
+            s3_output_path=os.path.join(s3_output_path, "tensors")
+        )
+
+        script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist_gluon.py")
+
+        # Estimator with checkpointing enabled
+        mx = MXNet(
+            entry_point=script_path,
+            role="SageMakerRole",
+            framework_version=mxnet_training_latest_version,
+            py_version=mxnet_training_latest_py_version,
+            instance_count=1,
+            instance_type=cpu_instance_type,
+            sagemaker_session=sagemaker_session,
+            debugger_hook_config=debugger_hook_config,
+            checkpoint_local_path="/opt/ml/checkpoints",
+            checkpoint_s3_uri=os.path.join(s3_output_path, "checkpoints"),
+        )
+        mx._prepare_for_training()
+
+        # Debug Hook should be enabled
+        assert mx.debugger_hook_config is not None
+
+        # Estimator with checkpointing enabled and Instance Count > 1
+        mx = MXNet(
+            entry_point=script_path,
+            role="SageMakerRole",
+            framework_version=mxnet_training_latest_version,
+            py_version=mxnet_training_latest_py_version,
+            instance_count=2,
+            instance_type=cpu_instance_type,
+            sagemaker_session=sagemaker_session,
+            debugger_hook_config=debugger_hook_config,
+            checkpoint_local_path="/opt/ml/checkpoints",
+            checkpoint_s3_uri=os.path.join(s3_output_path, "checkpoints"),
+        )
+        mx._prepare_for_training()
+        # Debug Hook should be disabled
+        assert mx.debugger_hook_config is False
+
+        # Estimator with checkpointing enabled and Data Parallel enabled

Review thread on this comment:

Reviewer: nit: comment is wrong, this is Data Parallel

Author: Nice catch, fixed.
+        pt = PyTorch(
+            base_job_name="pytorch-smdataparallel-mnist",
+            entry_point=script_path,
+            role="SageMakerRole",
+            framework_version="1.8.0",
+            py_version="py36",
+            instance_count=1,
+            # For training with p3dn instance use - ml.p3dn.24xlarge, with p4dn instance use - ml.p4d.24xlarge
+            instance_type="ml.p3.16xlarge",
+            sagemaker_session=sagemaker_session,
+            # Training using SMDataParallel Distributed Training Framework
+            distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
+            debugger_hook_config=False,
+        )
+        pt._prepare_for_training()
+        # Debug Hook should be disabled
+        assert pt.debugger_hook_config is False
+
+
 def test_mxnet_with_rules_and_debugger_hook_config(
     sagemaker_session,
     mxnet_training_latest_version,
Review thread on `self.debugger_hook_config = False`:

Reviewer: Just to understand it, so it's possible for `self.debugger_hook_config` to be `None` or `False`. And you're changing it to `False` intentionally to disable it?

Author: The default value of `self.debugger_hook_config` is `None`. The if-statement above checks whether the debugger is supported in the region in which it is invoked, and if not, it should be set to `False` instead of `None`, in my opinion.
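To make the state space discussed above concrete, here is a small illustrative sketch (not SDK source) of the three values `debugger_hook_config` can end up holding, and why downstream truthiness checks (such as the TensorFlow estimator lines referenced earlier) treat `None` and `False` the same way:

# Illustrative only -- mirrors the semantics discussed in this thread, not the SDK's code.
# After validation, debugger_hook_config is one of:
#   * a DebuggerHookConfig object -> debugger enabled (default in supported regions)
#   * False                       -> debugger explicitly disabled (unsupported region,
#                                    customer opt-out, or checkpointing + distributed)
#   * None                        -> customer never set it (pre-validation default)

def is_debugger_enabled(debugger_hook_config):
    # Downstream code only asks "is it truthy?", so None and False
    # both mean "disabled" here.
    return bool(debugger_hook_config)

assert is_debugger_enabled(None) is False
assert is_debugger_enabled(False) is False
assert is_debugger_enabled(object()) is True  # stand-in for a DebuggerHookConfig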