
Commit 709bb18

fix: Disable debugger when checkpointing is enabled with distributed training (#2264)
1 parent b66cb98 commit 709bb18

File tree

3 files changed: 128 additions & 7 deletions


src/sagemaker/estimator.py

Lines changed: 15 additions & 1 deletion
@@ -2219,7 +2219,21 @@ def _validate_and_set_debugger_configs(self):
         ):
             self.debugger_hook_config = DebuggerHookConfig(s3_output_path=self.output_path)
         elif not self.debugger_hook_config:
-            self.debugger_hook_config = None
+            # set hook config to False if _region_supports_debugger is False
+            self.debugger_hook_config = False
+
+        # Disable debugger if checkpointing is enabled by the customer
+        if self.checkpoint_s3_uri and self.checkpoint_local_path and self.debugger_hook_config:
+            if self._framework_name in {"mxnet", "pytorch", "tensorflow"}:
+                if self.instance_count > 1 or (
+                    hasattr(self, "distribution")
+                    and self.distribution is not None  # pylint: disable=no-member
+                ):
+                    logger.info(
+                        "SMDebug Does Not Currently Support \
+                        Distributed Training Jobs With Checkpointing Enabled"
+                    )
+                    self.debugger_hook_config = False
 
     def _stage_user_code_in_s3(self):
         """Upload the user training script to s3 and return the location.

src/sagemaker/tensorflow/estimator.py

Lines changed: 1 addition & 6 deletions
@@ -18,7 +18,6 @@
 from packaging import version
 
 from sagemaker import image_uris, s3, utils
-from sagemaker.debugger import DebuggerHookConfig
 from sagemaker.deprecations import renamed_kwargs
 from sagemaker.estimator import Framework
 import sagemaker.fw_utils as fw
@@ -347,6 +346,7 @@ def _validate_and_set_debugger_configs(self):
 
         Else, set default HookConfig
         """
+        super(TensorFlow, self)._validate_and_set_debugger_configs()
         ps_enabled = "parameter_server" in self.distribution and self.distribution[
             "parameter_server"
         ].get("enabled", False)
@@ -358,11 +358,6 @@ def _validate_and_set_debugger_configs(self):
             )
             self.debugger_hook_config = None
             self.debugger_rule_configs = None
-        elif self.debugger_hook_config is None and fw._region_supports_debugger(
-            self.sagemaker_session.boto_session.region_name
-        ):
-            # Set defaults for debugging.
-            self.debugger_hook_config = DebuggerHookConfig(s3_output_path=self.output_path)
 
     def transformer(
         self,
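With this change the TensorFlow estimator no longer builds its own default `DebuggerHookConfig`; it defers to the parent `Framework` implementation (which now also applies the checkpointing check) and keeps only the parameter-server special case. A simplified sketch of the resulting override, assuming the unchanged surrounding lines match the released SDK:

```python
def _validate_and_set_debugger_configs(self):
    """Disable the Debugger hook for parameter-server jobs; otherwise defer to Framework."""
    # Default hook config and the new checkpointing/distribution check come from the parent.
    super(TensorFlow, self)._validate_and_set_debugger_configs()

    ps_enabled = "parameter_server" in self.distribution and self.distribution[
        "parameter_server"
    ].get("enabled", False)
    if ps_enabled:
        # smdebug does not support parameter-server training, so drop hook and rule configs.
        self.debugger_hook_config = None
        self.debugger_rule_configs = None
```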

tests/integ/test_debugger.py

Lines changed: 112 additions & 0 deletions
@@ -24,6 +24,9 @@
     TensorBoardOutputConfig,
 )
 from sagemaker.mxnet.estimator import MXNet
+from sagemaker.pytorch.estimator import PyTorch
+from sagemaker.tensorflow.estimator import TensorFlow
+from sagemaker.xgboost.estimator import XGBoost
 from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
 from tests.integ.retry import retries
 from tests.integ.timeout import timeout
@@ -351,6 +354,115 @@ def test_mxnet_with_debugger_hook_config(
         _wait_and_assert_that_no_rule_jobs_errored(training_job=mx.latest_training_job)
 
 
+def test_debug_hook_disabled_with_checkpointing(
+    sagemaker_session,
+    mxnet_training_latest_version,
+    mxnet_training_latest_py_version,
+    cpu_instance_type,
+):
+    with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
+        s3_output_path = os.path.join(
+            "s3://", sagemaker_session.default_bucket(), str(uuid.uuid4())
+        )
+        debugger_hook_config = DebuggerHookConfig(
+            s3_output_path=os.path.join(s3_output_path, "tensors")
+        )
+
+        script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist_gluon.py")
+
+        # Estimator with checkpointing enabled
+        mx = MXNet(
+            entry_point=script_path,
+            role="SageMakerRole",
+            framework_version=mxnet_training_latest_version,
+            py_version=mxnet_training_latest_py_version,
+            instance_count=1,
+            instance_type=cpu_instance_type,
+            sagemaker_session=sagemaker_session,
+            debugger_hook_config=debugger_hook_config,
+            checkpoint_local_path="/opt/ml/checkpoints",
+            checkpoint_s3_uri=os.path.join(s3_output_path, "checkpoints"),
+        )
+        mx._prepare_for_training()
+
+        # Debug hook should be enabled
+        assert mx.debugger_hook_config is not None
+
+        # Estimator with checkpointing enabled and instance_count > 1
+        mx = MXNet(
+            entry_point=script_path,
+            role="SageMakerRole",
+            framework_version=mxnet_training_latest_version,
+            py_version=mxnet_training_latest_py_version,
+            instance_count=2,
+            instance_type=cpu_instance_type,
+            sagemaker_session=sagemaker_session,
+            debugger_hook_config=debugger_hook_config,
+            checkpoint_local_path="/opt/ml/checkpoints",
+            checkpoint_s3_uri=os.path.join(s3_output_path, "checkpoints"),
+        )
+        mx._prepare_for_training()
+        # Debug hook should be disabled
+        assert mx.debugger_hook_config is False
+
+        # Estimator with checkpointing enabled and SMDataParallel enabled
+        pt = PyTorch(
+            base_job_name="pytorch-smdataparallel-mnist",
+            entry_point=script_path,
+            role="SageMakerRole",
+            framework_version="1.8.0",
+            py_version="py36",
+            instance_count=1,
+            # For training with p3dn instances use ml.p3dn.24xlarge; with p4d instances use ml.p4d.24xlarge
+            instance_type="ml.p3.16xlarge",
+            sagemaker_session=sagemaker_session,
+            # Training using the SMDataParallel distributed training framework
+            distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
+            checkpoint_local_path="/opt/ml/checkpoints",
+            checkpoint_s3_uri=os.path.join(s3_output_path, "checkpoints"),
+        )
+        pt._prepare_for_training()
+        # Debug hook should be disabled
+        assert pt.debugger_hook_config is False
+
+        # Estimator with checkpointing enabled and SMModelParallel enabled
+        tf = TensorFlow(
+            base_job_name="tf-smdataparallel-mnist",
+            entry_point=script_path,
+            role="SageMakerRole",
+            framework_version="2.4.1",
+            py_version="py36",
+            instance_count=1,
+            # For training with p3dn instances use ml.p3dn.24xlarge; with p4d instances use ml.p4d.24xlarge
+            instance_type="ml.p3.16xlarge",
+            sagemaker_session=sagemaker_session,
+            # Training using the SMModelParallel distributed training framework
+            distribution={"smdistributed": {"modelparallel": {"enabled": True}}},
+            checkpoint_local_path="/opt/ml/checkpoints",
+            checkpoint_s3_uri=os.path.join(s3_output_path, "checkpoints"),
+        )
+        tf._prepare_for_training()
+        # Debug hook should be disabled
+        assert tf.debugger_hook_config is False
+
+        # XGBoost Estimator (not covered by the new check)
+        xg = XGBoost(
+            base_job_name="test_xgboost",
+            entry_point=script_path,
+            role="SageMakerRole",
+            framework_version="1.2-1",
+            py_version="py3",
+            instance_count=2,
+            # For training with p3dn instances use ml.p3dn.24xlarge; with p4d instances use ml.p4d.24xlarge
+            instance_type="ml.p3.16xlarge",
+            sagemaker_session=sagemaker_session,
+            # No distribution or checkpointing configured
+        )
+        xg._prepare_for_training()
+        # Debug hook should be enabled
+        assert xg.debugger_hook_config is not None
+
+
 def test_mxnet_with_rules_and_debugger_hook_config(
     sagemaker_session,
     mxnet_training_latest_version,
