Commit 2b0944e

ignore xgboost estimator
1 parent 2fa1856 commit 2b0944e

2 files changed: +32 -12 lines changed

src/sagemaker/estimator.py

Lines changed: 10 additions & 9 deletions

@@ -2223,15 +2223,16 @@ def _validate_and_set_debugger_configs(self):
 
         # Disable debugger if checkpointing is enabled by the customer
         if self.checkpoint_s3_uri and self.checkpoint_local_path and self.debugger_hook_config:
-            if self.instance_count > 1 \
-                or (
-                    hasattr(self, "distribution") and self.distribution is not None  # pylint: disable=no-member
-            ):
-                logger.info(
-                    "SMDebug Does Not Currently Support \
-                        Distributed Training Jobs With Checkpointing Enabled"
-                )
-                self.debugger_hook_config = False
+            if self._framework_name in {"mxnet", "pytorch", "tensorflow"}:
+                if self.instance_count > 1 or (
+                    hasattr(self, "distribution")
+                    and self.distribution is not None  # pylint: disable=no-member
+                ):
+                    logger.info(
+                        "SMDebug Does Not Currently Support \
+                            Distributed Training Jobs With Checkpointing Enabled"
+                    )
+                    self.debugger_hook_config = False
 
     def _stage_user_code_in_s3(self):
         """Upload the user training script to s3 and return the location.

tests/integ/test_debugger.py

Lines changed: 22 additions & 3 deletions

@@ -25,6 +25,7 @@
 )
 from sagemaker.mxnet.estimator import MXNet
 from sagemaker.pytorch.estimator import PyTorch
+from sagemaker.xgboost.estimator import XGBoost
 from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
 from tests.integ.retry import retries
 from tests.integ.timeout import timeout
@@ -400,7 +401,7 @@ def test_debug_hook_disabled_with_checkpointing(
         checkpoint_s3_uri=os.path.join(s3_output_path, "checkpoints"),
     )
     mx._prepare_for_training()
-    # Debug Hook should be enabled
+    # Debug Hook should be disabled
     assert mx.debugger_hook_config is False
 
     # Estimator with checkpointing enabled and Model Parallel Enabled
@@ -416,12 +417,30 @@ def test_debug_hook_disabled_with_checkpointing(
         sagemaker_session=sagemaker_session,
         # Training using SMDataParallel Distributed Training Framework
         distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
-        debugger_hook_config=False,
+        checkpoint_local_path="/opt/ml/checkpoints",
+        checkpoint_s3_uri=os.path.join(s3_output_path, "checkpoints"),
     )
     pt._prepare_for_training()
-    # Debug Hook should be enabled
+    # Debug Hook should be disabled
     assert pt.debugger_hook_config is False
 
+    # Estimator with checkpointing enabled with Xgboost Estimator
+    xg = XGBoost(
+        base_job_name="test_xgboost",
+        entry_point=script_path,
+        role="SageMakerRole",
+        framework_version="1.2-1",
+        py_version="py3",
+        instance_count=2,
+        # For training with p3dn instance use - ml.p3dn.24xlarge, with p4dn instance use - ml.p4d.24xlarge
+        instance_type="ml.p3.16xlarge",
+        sagemaker_session=sagemaker_session,
+        # Training using SMDataParallel Distributed Training Framework
+    )
+    xg._prepare_for_training()
+    # Debug Hook should be enabled
+    assert xg.debugger_hook_config is not None
+
 
 def test_mxnet_with_rules_and_debugger_hook_config(
     sagemaker_session,
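
The same expectations could also be pinned down at unit level, without launching training jobs. The following sketch reuses the hypothetical _should_disable_debugger helper from above; it is an illustration, not part of the actual test suite.

# Hypothetical unit-level check mirroring the integration test's expectations.
def test_should_disable_debugger_sketch():
    # MXNet/PyTorch with checkpointing and distributed training: hook disabled.
    assert _should_disable_debugger("mxnet", 2, None, True, True)
    assert _should_disable_debugger(
        "pytorch", 2, {"smdistributed": {"dataparallel": {"enabled": True}}}, True, True
    )
    # XGBoost is ignored by the check, so the hook stays enabled.
    assert not _should_disable_debugger("xgboost", 2, None, True, True)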
