
Commit 4ab76aa

update TF child function

1 parent 582b64d commit 4ab76aa

File tree

src/sagemaker/estimator.py
src/sagemaker/tensorflow/estimator.py
tests/integ/test_debugger.py

3 files changed: +23 -6 lines

src/sagemaker/estimator.py

Lines changed: 1 addition & 0 deletions

@@ -2219,6 +2219,7 @@ def _validate_and_set_debugger_configs(self):
         ):
             self.debugger_hook_config = DebuggerHookConfig(s3_output_path=self.output_path)
         elif not self.debugger_hook_config:
+            # set hook config to False if _region_supports_debugger is False
             self.debugger_hook_config = False
 
         # Disable debugger if checkpointing is enabled by the customer
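For context, a minimal sketch of how this branch of the base _validate_and_set_debugger_configs reads after the one-line addition, reassembled from the hunk above. The condition above the closing-parenthesis context line is not shown in this hunk; it is assumed to be the same region-support check that the TensorFlow code removed below performs, so the exact attribute names here are assumptions, not a quote of the file.

# Sketch only: the if-condition is assumed from the comment added in this
# commit and from the code removed in the TensorFlow subclass below.
if self.debugger_hook_config is None and _region_supports_debugger(
    self.sagemaker_session.boto_session.region_name  # assumed attribute path
):
    # Region supports Debugger: default the hook output to the job's output path.
    self.debugger_hook_config = DebuggerHookConfig(s3_output_path=self.output_path)
elif not self.debugger_hook_config:
    # set hook config to False if _region_supports_debugger is False
    self.debugger_hook_config = False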

src/sagemaker/tensorflow/estimator.py

Lines changed: 1 addition & 6 deletions

@@ -18,7 +18,6 @@
 from packaging import version
 
 from sagemaker import image_uris, s3, utils
-from sagemaker.debugger import DebuggerHookConfig
 from sagemaker.deprecations import renamed_kwargs
 from sagemaker.estimator import Framework
 import sagemaker.fw_utils as fw
@@ -347,6 +346,7 @@ def _validate_and_set_debugger_configs(self):
 
         Else, set default HookConfig
         """
+        super(TensorFlow, self)._validate_and_set_debugger_configs()
         ps_enabled = "parameter_server" in self.distribution and self.distribution[
             "parameter_server"
         ].get("enabled", False)
@@ -358,11 +358,6 @@ def _validate_and_set_debugger_configs(self):
             )
             self.debugger_hook_config = None
             self.debugger_rule_configs = None
-        elif self.debugger_hook_config is None and fw._region_supports_debugger(
-            self.sagemaker_session.boto_session.region_name
-        ):
-            # Set defaults for debugging.
-            self.debugger_hook_config = DebuggerHookConfig(s3_output_path=self.output_path)
 
     def transformer(
         self,
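Taken together, the TensorFlow override now delegates the default DebuggerHookConfig handling to the base class via super() and keeps only its parameter-server special case. Below is a minimal sketch of the method as it reads after this commit, reassembled from the two hunks above; the opening of the docstring and the lines between the hunks are not visible in the diff, so they are elided or summarized in comments.

def _validate_and_set_debugger_configs(self):
    """... (docstring opening not shown in the diff)

    Else, set default HookConfig
    """
    # New in this commit: the base Estimator now owns the default-hook logic,
    # so the subclass no longer imports or constructs DebuggerHookConfig itself.
    super(TensorFlow, self)._validate_and_set_debugger_configs()
    ps_enabled = "parameter_server" in self.distribution and self.distribution[
        "parameter_server"
    ].get("enabled", False)
    if ps_enabled:
        # The lines between the two hunks are not shown; they are assumed to log
        # that Debugger is not supported with parameter-server distribution.
        self.debugger_hook_config = None
        self.debugger_rule_configs = None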

tests/integ/test_debugger.py

Lines changed: 21 additions & 0 deletions

@@ -25,6 +25,7 @@
 )
 from sagemaker.mxnet.estimator import MXNet
 from sagemaker.pytorch.estimator import PyTorch
+from sagemaker.tensorflow.estimator import TensorFlow
 from sagemaker.xgboost.estimator import XGBoost
 from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
 from tests.integ.retry import retries
@@ -424,6 +425,26 @@ def test_debug_hook_disabled_with_checkpointing(
     # Debug Hook should be disabled
     assert pt.debugger_hook_config is False
 
+    # Estimator with checkpointing enabled and SMModelParallel Enabled
+    tf = TensorFlow(
+        base_job_name="tf-smdataparallel-mnist",
+        entry_point=script_path,
+        role="SageMakerRole",
+        framework_version="2.3.1",
+        py_version="py36",
+        instance_count=1,
+        # For training with p3dn instance use - ml.p3dn.24xlarge, with p4dn instance use - ml.p4d.24xlarge
+        instance_type="ml.p3.16xlarge",
+        sagemaker_session=sagemaker_session,
+        # Training using SMDataParallel Distributed Training Framework
+        distribution={"smdistributed": {"modelparallel": {"enabled": True}}},
+        checkpoint_local_path="/opt/ml/checkpoints",
+        checkpoint_s3_uri=os.path.join(s3_output_path, "checkpoints"),
+    )
+    tf._prepare_for_training()
+    # Debug Hook should be disabled
+    assert tf.debugger_hook_config is False
+
     # Estimator with checkpointing enabled with Xgboost Estimator
     xg = XGBoost(
         base_job_name="test_xgboost",
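The new assertion mirrors the existing PyTorch and XGBoost checks in the same test: tf._prepare_for_training() runs the estimator's debugger validation, and because checkpoint_s3_uri is set, the base class's "disable debugger if checkpointing is enabled" path (visible as a context comment in the estimator.py hunk above) sets debugger_hook_config to False. With the super() call added in the TensorFlow override, that base-class behavior now reaches TensorFlow jobs as well, which is what this test exercises.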
