Skip to content

Commit 336475f

Browse files
committed
disable debugger when checkpointing is enabled
1 parent 2e694a7 commit 336475f

File tree

2 files changed

+49
-0
lines changed

2 files changed

+49
-0
lines changed

src/sagemaker/estimator.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2147,6 +2147,11 @@ def __init__(
21472147
self._hyperparameters = hyperparameters or {}
21482148
self.checkpoint_s3_uri = checkpoint_s3_uri
21492149
self.checkpoint_local_path = checkpoint_local_path
2150+
2151+
# Disable debugger if checkpointing is enabled by the customer
2152+
self.debugger_hook_config = \
2153+
self.debugger_hook_config if checkpoint_s3_uri is None \
2154+
and checkpoint_local_path is None else False
21502155
self.enable_sagemaker_metrics = enable_sagemaker_metrics
21512156

21522157
def _prepare_for_training(self, job_name=None):

tests/integ/test_debugger.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,50 @@ def test_mxnet_with_debugger_hook_config(
351351
_wait_and_assert_that_no_rule_jobs_errored(training_job=mx.latest_training_job)
352352

353353

354+
def test_debug_hook_disabled_with_checkpointing(
355+
sagemaker_session,
356+
mxnet_training_latest_version,
357+
mxnet_training_latest_py_version,
358+
cpu_instance_type,
359+
):
360+
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
361+
s3_output_path = os.path.join("s3://", sagemaker_session.default_bucket(), str(uuid.uuid4()))
362+
debugger_hook_config = DebuggerHookConfig(
363+
s3_output_path=os.path.join(
364+
s3_output_path, "tensors"
365+
)
366+
)
367+
368+
script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist_gluon.py")
369+
data_path = os.path.join(DATA_DIR, "mxnet_mnist")
370+
371+
mx = MXNet(
372+
entry_point=script_path,
373+
role="SageMakerRole",
374+
framework_version=mxnet_training_latest_version,
375+
py_version=mxnet_training_latest_py_version,
376+
instance_count=1,
377+
instance_type=cpu_instance_type,
378+
sagemaker_session=sagemaker_session,
379+
debugger_hook_config=debugger_hook_config,
380+
checkpoint_local_path="/opt/ml/checkpoints",
381+
checkpoint_s3_uri=os.path.join(s3_output_path, "checkpoints")
382+
383+
)
384+
385+
train_input = mx.sagemaker_session.upload_data(
386+
path=os.path.join(data_path, "train"), key_prefix="integ-test-data/mxnet_mnist/train"
387+
)
388+
test_input = mx.sagemaker_session.upload_data(
389+
path=os.path.join(data_path, "test"), key_prefix="integ-test-data/mxnet_mnist/test"
390+
)
391+
392+
mx.fit({"train": train_input, "test": test_input})
393+
394+
job_description = mx.latest_training_job.describe()
395+
assert "DebugHookConfig" not in job_description
396+
397+
354398
def test_mxnet_with_rules_and_debugger_hook_config(
355399
sagemaker_session,
356400
mxnet_training_latest_version,

0 commit comments

Comments
 (0)