Skip to content

Commit 3045259

Browse files
committed
change: Replace update_remote_config with 2 helper methods for enable and disable respectively
1 parent d106d52 commit 3045259

File tree

2 files changed

+65
-16
lines changed

2 files changed

+65
-16
lines changed

src/sagemaker/estimator.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -769,7 +769,7 @@ def __init__(
769769

770770
self.tensorboard_app = TensorBoardApp(region=self.sagemaker_session.boto_region_name)
771771

772-
self.enable_remote_debug = enable_remote_debug
772+
self._enable_remote_debug = enable_remote_debug
773773

774774
@abstractmethod
775775
def training_image_uri(self):
@@ -2291,21 +2291,31 @@ def update_profiler(
22912291

22922292
_TrainingJob.update(self, profiler_rule_configs, profiler_config_request_dict)
22932293

2294-
def update_remote_debug(self, enable_remote_debug: bool):
2295-
"""Update training jobs to enable remote debug.
2294+
def get_remote_debug_config(self):
2295+
"""dict: Return the configuration of RemoteDebug"""
2296+
return (
2297+
None
2298+
if self._enable_remote_debug is None
2299+
else {"EnableRemoteDebug": self._enable_remote_debug}
2300+
)
22962301

2297-
This method updates the ``enable_remote_debug`` parameter
2298-
and enables or disables remote debug for a training job
2302+
def enable_remote_debug(self):
2303+
"""Enable remote debug for a training job."""
2304+
self._update_remote_debug(True)
22992305

2300-
Args:
2301-
enable_remote_debug (bool):
2302-
Specifies whether RemoteDebug is to be enabled for the training job
2306+
def disable_remote_debug(self):
2307+
"""Disable remote debug for a training job."""
2308+
self._update_remote_debug(False)
2309+
2310+
def _update_remote_debug(self, enable_remote_debug: bool):
2311+
"""Update to enable or disable remote debug for a training job.
2312+
2313+
This method updates the ``_enable_remote_debug`` parameter
2314+
and enables or disables remote debug for a training job
23032315
"""
23042316
self._ensure_latest_training_job()
2305-
self.enable_remote_debug = enable_remote_debug
2306-
_TrainingJob.update(
2307-
self, remote_debug_config={"EnableRemoteDebug": self.enable_remote_debug}
2308-
)
2317+
_TrainingJob.update(self, remote_debug_config={"EnableRemoteDebug": enable_remote_debug})
2318+
self._enable_remote_debug = enable_remote_debug
23092319

23102320
def get_app_url(
23112321
self,
@@ -2535,8 +2545,8 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
25352545
if estimator.profiler_config:
25362546
train_args["profiler_config"] = estimator.profiler_config._to_request_dict()
25372547

2538-
if estimator.enable_remote_debug is not None:
2539-
train_args["remote_debug_config"] = {"EnableRemoteDebug": estimator.enable_remote_debug}
2548+
if estimator.get_remote_debug_config() is not None:
2549+
train_args["remote_debug_config"] = estimator.get_remote_debug_config()
25402550

25412551
return train_args
25422552

tests/unit/test_estimator.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2009,9 +2009,47 @@ def test_framework_with_remote_debug_config(sagemaker_session):
20092009
sagemaker_session.train.assert_called_once()
20102010
_, args = sagemaker_session.train.call_args
20112011
assert args["remote_debug_config"]["EnableRemoteDebug"]
2012+
assert f.get_remote_debug_config()["EnableRemoteDebug"]
20122013

20132014

2014-
def test_framework_update_remote_debug(sagemaker_session):
2015+
def test_framework_without_remote_debug_config(sagemaker_session):
2016+
f = DummyFramework(
2017+
entry_point=SCRIPT_PATH,
2018+
role=ROLE,
2019+
sagemaker_session=sagemaker_session,
2020+
instance_groups=[
2021+
InstanceGroup("group1", "ml.c4.xlarge", 1),
2022+
InstanceGroup("group2", "ml.m4.xlarge", 2),
2023+
],
2024+
)
2025+
f.fit("s3://mydata")
2026+
sagemaker_session.train.assert_called_once()
2027+
_, args = sagemaker_session.train.call_args
2028+
assert args.get("remote_debug_config") is None
2029+
assert f.get_remote_debug_config() is None
2030+
2031+
2032+
def test_framework_enable_remote_debug(sagemaker_session):
2033+
f = DummyFramework(
2034+
entry_point=SCRIPT_PATH,
2035+
role=ROLE,
2036+
sagemaker_session=sagemaker_session,
2037+
instance_count=INSTANCE_COUNT,
2038+
instance_type=INSTANCE_TYPE,
2039+
)
2040+
f.fit("s3://mydata")
2041+
f.enable_remote_debug()
2042+
2043+
sagemaker_session.update_training_job.assert_called_once()
2044+
_, args = sagemaker_session.update_training_job.call_args
2045+
assert args["remote_debug_config"] == {
2046+
"EnableRemoteDebug": True,
2047+
}
2048+
assert f.get_remote_debug_config()["EnableRemoteDebug"]
2049+
assert len(args) == 2
2050+
2051+
2052+
def test_framework_disable_remote_debug(sagemaker_session):
20152053
f = DummyFramework(
20162054
entry_point=SCRIPT_PATH,
20172055
role=ROLE,
@@ -2021,13 +2059,14 @@ def test_framework_update_remote_debug(sagemaker_session):
20212059
enable_remote_debug=True,
20222060
)
20232061
f.fit("s3://mydata")
2024-
f.update_remote_debug(False)
2062+
f.disable_remote_debug()
20252063

20262064
sagemaker_session.update_training_job.assert_called_once()
20272065
_, args = sagemaker_session.update_training_job.call_args
20282066
assert args["remote_debug_config"] == {
20292067
"EnableRemoteDebug": False,
20302068
}
2069+
assert not f.get_remote_debug_config()["EnableRemoteDebug"]
20312070
assert len(args) == 2
20322071

20332072

0 commit comments

Comments
 (0)