
Commit fd9dd85 (parent: 2e694a7)

disable debugger when checkpointing is enabled

File tree: 2 files changed (+86, −1 lines)


src/sagemaker/estimator.py

Lines changed: 14 additions & 1 deletion
@@ -2209,7 +2209,20 @@ def _validate_and_set_debugger_configs(self):
         ):
             self.debugger_hook_config = DebuggerHookConfig(s3_output_path=self.output_path)
         elif not self.debugger_hook_config:
-            self.debugger_hook_config = None
+            self.debugger_hook_config = False
+
+        # Disable debugger if checkpointing is enabled by the customer
+        _should_disable_debugger = False
+        if self.checkpoint_s3_uri and self.checkpoint_local_path and self.debugger_hook_config:
+            if self.instance_count > 1:
+                _should_disable_debugger = True
+            if hasattr(self, "distribution") and self.distribution is not None:
+                _should_disable_debugger = True
+        if _should_disable_debugger:
+            logger.info(
+                "SMDebug Does Not Currently Support Distributed Training Jobs With Checkpointing Enabled"
+            )
+            self.debugger_hook_config = False
 
     def _stage_user_code_in_s3(self):
         """Upload the user training script to s3 and return the location.

tests/integ/test_debugger.py

Lines changed: 72 additions & 0 deletions
@@ -24,6 +24,7 @@
     TensorBoardOutputConfig,
 )
 from sagemaker.mxnet.estimator import MXNet
+from sagemaker.pytorch.estimator import PyTorch
 from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
 from tests.integ.retry import retries
 from tests.integ.timeout import timeout
@@ -351,6 +352,77 @@ def test_mxnet_with_debugger_hook_config(
     _wait_and_assert_that_no_rule_jobs_errored(training_job=mx.latest_training_job)
 
 
+def test_debug_hook_disabled_with_checkpointing(
+    sagemaker_session,
+    mxnet_training_latest_version,
+    mxnet_training_latest_py_version,
+    cpu_instance_type,
+):
+    with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
+        s3_output_path = os.path.join(
+            "s3://", sagemaker_session.default_bucket(), str(uuid.uuid4())
+        )
+        debugger_hook_config = DebuggerHookConfig(
+            s3_output_path=os.path.join(s3_output_path, "tensors")
+        )
+
+        script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist_gluon.py")
+
+        # Estimator with checkpointing enabled
+        mx = MXNet(
+            entry_point=script_path,
+            role="SageMakerRole",
+            framework_version=mxnet_training_latest_version,
+            py_version=mxnet_training_latest_py_version,
+            instance_count=1,
+            instance_type=cpu_instance_type,
+            sagemaker_session=sagemaker_session,
+            debugger_hook_config=debugger_hook_config,
+            checkpoint_local_path="/opt/ml/checkpoints",
+            checkpoint_s3_uri=os.path.join(s3_output_path, "checkpoints"),
+        )
+        mx._prepare_for_training()
+
+        # Debug Hook should be enabled
+        assert mx.debugger_hook_config is not None
+
+        # Estimator with checkpointing enabled and instance_count > 1
+        mx = MXNet(
+            entry_point=script_path,
+            role="SageMakerRole",
+            framework_version=mxnet_training_latest_version,
+            py_version=mxnet_training_latest_py_version,
+            instance_count=2,
+            instance_type=cpu_instance_type,
+            sagemaker_session=sagemaker_session,
+            debugger_hook_config=debugger_hook_config,
+            checkpoint_local_path="/opt/ml/checkpoints",
+            checkpoint_s3_uri=os.path.join(s3_output_path, "checkpoints"),
+        )
+        mx._prepare_for_training()
+        # Debug Hook should be disabled
+        assert mx.debugger_hook_config is False
+
+        # Estimator with SMDataParallel distribution enabled
+        pt = PyTorch(
+            base_job_name="pytorch-smdataparallel-mnist",
+            entry_point=script_path,
+            role="SageMakerRole",
+            framework_version="1.8.0",
+            py_version="py36",
+            instance_count=1,
+            # For p3dn instances use ml.p3dn.24xlarge; for p4d instances use ml.p4d.24xlarge
+            instance_type="ml.p3.16xlarge",
+            sagemaker_session=sagemaker_session,
+            # Training using SMDataParallel Distributed Training Framework
+            distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
+            debugger_hook_config=False,
+        )
+        pt._prepare_for_training()
+        # Debug Hook should be disabled
+        assert pt.debugger_hook_config is False
+
+
 def test_mxnet_with_rules_and_debugger_hook_config(
     sagemaker_session,
     mxnet_training_latest_version,
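The integration test above exercises real estimators end to end. For illustration only, the same truth table could be covered by a hypothetical parametrized unit test (not part of this commit) that reuses the should_disable_debugger() sketch shown after the estimator.py diff:

import pytest


@pytest.mark.parametrize(
    "instance_count, distribution, expected",
    [
        (1, None, False),  # single instance, no distribution: debugger stays on
        (2, None, True),   # multi-instance checkpointing: debugger disabled
        (1, {"smdistributed": {"dataparallel": {"enabled": True}}}, True),
    ],
)
def test_debugger_gate(instance_count, distribution, expected):
    result = should_disable_debugger(
        checkpoint_s3_uri="s3://bucket/checkpoints",
        checkpoint_local_path="/opt/ml/checkpoints",
        debugger_hook_config=object(),
        instance_count=instance_count,
        distribution=distribution,
    )
    assert result is expected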
