Skip to content

Commit 347d6b2

Browse files
committed
refactor + add log + update test
1 parent 336475f commit 347d6b2

File tree

2 files changed

+15
-24
lines changed

2 files changed

+15
-24
lines changed

src/sagemaker/estimator.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2148,10 +2148,6 @@ def __init__(
21482148
self.checkpoint_s3_uri = checkpoint_s3_uri
21492149
self.checkpoint_local_path = checkpoint_local_path
21502150

2151-
# Disable debugger if checkpointing is enabled by the customer
2152-
self.debugger_hook_config = \
2153-
self.debugger_hook_config if checkpoint_s3_uri is None \
2154-
and checkpoint_local_path is None else False
21552151
self.enable_sagemaker_metrics = enable_sagemaker_metrics
21562152

21572153
def _prepare_for_training(self, job_name=None):
@@ -2204,7 +2200,6 @@ def _prepare_for_training(self, job_name=None):
22042200
self._hyperparameters[CONTAINER_LOG_LEVEL_PARAM_NAME] = self.container_log_level
22052201
self._hyperparameters[JOB_NAME_PARAM_NAME] = self._current_job_name
22062202
self._hyperparameters[SAGEMAKER_REGION_PARAM_NAME] = self.sagemaker_session.boto_region_name
2207-
22082203
self._validate_and_set_debugger_configs()
22092204

22102205
def _validate_and_set_debugger_configs(self):
@@ -2214,7 +2209,14 @@ def _validate_and_set_debugger_configs(self):
22142209
):
22152210
self.debugger_hook_config = DebuggerHookConfig(s3_output_path=self.output_path)
22162211
elif not self.debugger_hook_config:
2217-
self.debugger_hook_config = None
2212+
self.debugger_hook_config = False
2213+
2214+
# Disable debugger if checkpointing is enabled by the customer
2215+
if self.checkpoint_s3_uri and self.checkpoint_local_path and self.debugger_hook_config:
2216+
logger.info(
2217+
"SM Debug Does Not Currently Support Training Jobs With Checkpointing Enabled"
2218+
)
2219+
self.debugger_hook_config = False
22182220

22192221
def _stage_user_code_in_s3(self):
22202222
"""Upload the user training script to s3 and return the location.

tests/integ/test_debugger.py

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -358,11 +358,11 @@ def test_debug_hook_disabled_with_checkpointing(
358358
cpu_instance_type,
359359
):
360360
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
361-
s3_output_path = os.path.join("s3://", sagemaker_session.default_bucket(), str(uuid.uuid4()))
361+
s3_output_path = os.path.join(
362+
"s3://", sagemaker_session.default_bucket(), str(uuid.uuid4())
363+
)
362364
debugger_hook_config = DebuggerHookConfig(
363-
s3_output_path=os.path.join(
364-
s3_output_path, "tensors"
365-
)
365+
s3_output_path=os.path.join(s3_output_path, "tensors")
366366
)
367367

368368
script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist_gluon.py")
@@ -378,21 +378,10 @@ def test_debug_hook_disabled_with_checkpointing(
378378
sagemaker_session=sagemaker_session,
379379
debugger_hook_config=debugger_hook_config,
380380
checkpoint_local_path="/opt/ml/checkpoints",
381-
checkpoint_s3_uri=os.path.join(s3_output_path, "checkpoints")
382-
383-
)
384-
385-
train_input = mx.sagemaker_session.upload_data(
386-
path=os.path.join(data_path, "train"), key_prefix="integ-test-data/mxnet_mnist/train"
381+
checkpoint_s3_uri=os.path.join(s3_output_path, "checkpoints"),
387382
)
388-
test_input = mx.sagemaker_session.upload_data(
389-
path=os.path.join(data_path, "test"), key_prefix="integ-test-data/mxnet_mnist/test"
390-
)
391-
392-
mx.fit({"train": train_input, "test": test_input})
393-
394-
job_description = mx.latest_training_job.describe()
395-
assert "DebugHookConfig" not in job_description
383+
mx._prepare_for_training()
384+
assert mx.debugger_hook_config is False
396385

397386

398387
def test_mxnet_with_rules_and_debugger_hook_config(

0 commit comments

Comments
 (0)