Skip to content

Commit d106d52

Browse files
committed
feature: support remote debug for sagemaker training job
1 parent 8c2012b commit d106d52

File tree

4 files changed

+150
-3
lines changed

4 files changed

+150
-3
lines changed

src/sagemaker/estimator.py

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ def __init__(
178178
container_entry_point: Optional[List[str]] = None,
179179
container_arguments: Optional[List[str]] = None,
180180
disable_output_compression: bool = False,
181+
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
181182
**kwargs,
182183
):
183184
"""Initialize an ``EstimatorBase`` instance.
@@ -540,6 +541,8 @@ def __init__(
540541
to Amazon S3 without compression after training finishes.
541542
enable_infra_check (bool or PipelineVariable): Optional.
542543
Specifies whether it is running Sagemaker built-in infra check jobs.
544+
enable_remote_debug (bool or PipelineVariable): Optional.
545+
Specifies whether RemoteDebug is enabled for the training job.
543546
"""
544547
instance_count = renamed_kwargs(
545548
"train_instance_count", "instance_count", instance_count, kwargs
@@ -766,6 +769,8 @@ def __init__(
766769

767770
self.tensorboard_app = TensorBoardApp(region=self.sagemaker_session.boto_region_name)
768771

772+
self.enable_remote_debug = enable_remote_debug
773+
769774
@abstractmethod
770775
def training_image_uri(self):
771776
"""Return the Docker image to use for training.
@@ -1947,6 +1952,11 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
19471952
max_wait = job_details.get("StoppingCondition", {}).get("MaxWaitTimeInSeconds")
19481953
if max_wait:
19491954
init_params["max_wait"] = max_wait
1955+
1956+
if "RemoteDebugConfig" in job_details:
1957+
init_params["enable_remote_debug"] = job_details["RemoteDebugConfig"].get(
1958+
"EnableRemoteDebug"
1959+
)
19501960
return init_params
19511961

19521962
def _get_instance_type(self):
@@ -2281,6 +2291,22 @@ def update_profiler(
22812291

22822292
_TrainingJob.update(self, profiler_rule_configs, profiler_config_request_dict)
22832293

2294+
def update_remote_debug(self, enable_remote_debug: bool):
    """Enable or disable remote debug for the latest training job.

    Calls ``self._ensure_latest_training_job()``, then sends an update for
    the training job carrying the requested ``EnableRemoteDebug`` value.
    The ``enable_remote_debug`` attribute of this estimator is updated
    only after the update call has returned without raising.

    Args:
        enable_remote_debug (bool):
            Specifies whether RemoteDebug is to be enabled for the training job.
    """
    self._ensure_latest_training_job()
    _TrainingJob.update(
        self, remote_debug_config={"EnableRemoteDebug": enable_remote_debug}
    )
    # Record the new value only once the update request went through, so a
    # failed request does not leave the estimator reporting a state that the
    # training job never received.
    self.enable_remote_debug = enable_remote_debug
2309+
22842310
def get_app_url(
22852311
self,
22862312
app_type,
@@ -2509,6 +2535,9 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
25092535
if estimator.profiler_config:
25102536
train_args["profiler_config"] = estimator.profiler_config._to_request_dict()
25112537

2538+
if estimator.enable_remote_debug is not None:
2539+
train_args["remote_debug_config"] = {"EnableRemoteDebug": estimator.enable_remote_debug}
2540+
25122541
return train_args
25132542

25142543
@classmethod
@@ -2538,7 +2567,12 @@ def _is_local_channel(cls, input_uri):
25382567

25392568
@classmethod
25402569
def update(
2541-
cls, estimator, profiler_rule_configs=None, profiler_config=None, resource_config=None
2570+
cls,
2571+
estimator,
2572+
profiler_rule_configs=None,
2573+
profiler_config=None,
2574+
resource_config=None,
2575+
remote_debug_config=None,
25422576
):
25432577
"""Update a running Amazon SageMaker training job.
25442578
@@ -2551,20 +2585,31 @@ def update(
25512585
resource_config (dict): Configuration of the resources for the training job. You can
25522586
update the keep-alive period if the warm pool status is `Available`. No other fields
25532587
can be updated. (default: None).
2588+
remote_debug_config (dict): Configuration for RemoteDebug. (default: ``None``)
2589+
The dict can contain 'EnableRemoteDebug'(bool).
2590+
For example,
2591+
2592+
.. code:: python
2593+
2594+
remote_debug_config = {
2595+
"EnableRemoteDebug": True,
2596+
}
25542597
25552598
Returns:
25562599
sagemaker.estimator._TrainingJob: Constructed object that captures
25572600
all information about the updated training job.
25582601
"""
25592602
update_args = cls._get_update_args(
2560-
estimator, profiler_rule_configs, profiler_config, resource_config
2603+
estimator, profiler_rule_configs, profiler_config, resource_config, remote_debug_config
25612604
)
25622605
estimator.sagemaker_session.update_training_job(**update_args)
25632606

25642607
return estimator.latest_training_job
25652608

25662609
@classmethod
2567-
def _get_update_args(cls, estimator, profiler_rule_configs, profiler_config, resource_config):
2610+
def _get_update_args(
2611+
cls, estimator, profiler_rule_configs, profiler_config, resource_config, remote_debug_config
2612+
):
25682613
"""Constructs a dict of arguments for updating an Amazon SageMaker training job.
25692614
25702615
Args:
@@ -2585,6 +2630,7 @@ def _get_update_args(cls, estimator, profiler_rule_configs, profiler_config, res
25852630
update_args.update(build_dict("profiler_rule_configs", profiler_rule_configs))
25862631
update_args.update(build_dict("profiler_config", profiler_config))
25872632
update_args.update(build_dict("resource_config", resource_config))
2633+
update_args.update(build_dict("remote_debug_config", remote_debug_config))
25882634

25892635
return update_args
25902636

@@ -2683,6 +2729,7 @@ def __init__(
26832729
container_arguments: Optional[List[str]] = None,
26842730
disable_output_compression: bool = False,
26852731
enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
2732+
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
26862733
**kwargs,
26872734
):
26882735
"""Initialize an ``Estimator`` instance.
@@ -3044,6 +3091,8 @@ def __init__(
30443091
to Amazon S3 without compression after training finishes.
30453092
enable_infra_check (bool or PipelineVariable): Optional.
30463093
Specifies whether it is running Sagemaker built-in infra check jobs.
3094+
enable_remote_debug (bool or PipelineVariable): Optional.
3095+
Specifies whether RemoteDebug is enabled for the training job.
30473096
"""
30483097
self.image_uri = image_uri
30493098
self._hyperparameters = hyperparameters.copy() if hyperparameters else {}
@@ -3095,6 +3144,7 @@ def __init__(
30953144
container_entry_point=container_entry_point,
30963145
container_arguments=container_arguments,
30973146
disable_output_compression=disable_output_compression,
3147+
enable_remote_debug=enable_remote_debug,
30983148
**kwargs,
30993149
)
31003150

src/sagemaker/session.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -741,6 +741,7 @@ def train( # noqa: C901
741741
profiler_config=None,
742742
environment: Optional[Dict[str, str]] = None,
743743
retry_strategy=None,
744+
remote_debug_config=None,
744745
):
745746
"""Create an Amazon SageMaker training job.
746747
@@ -851,6 +852,15 @@ def train( # noqa: C901
851852
configurations.
852853
profiler_config (dict): Configuration for how profiling information is emitted
853854
with SageMaker Profiler. (default: ``None``).
855+
remote_debug_config(dict): Configuration for RemoteDebug. (default: ``None``)
856+
The dict can contain 'EnableRemoteDebug'(bool).
857+
For example,
858+
859+
.. code:: python
860+
861+
remote_debug_config = {
862+
"EnableRemoteDebug": True,
863+
}
854864
environment (dict[str, str]) : Environment variables to be set for
855865
use during training job (default: ``None``)
856866
retry_strategy(dict): Defines RetryStrategy for InternalServerFailures.
@@ -943,6 +953,7 @@ def train( # noqa: C901
943953
enable_sagemaker_metrics=enable_sagemaker_metrics,
944954
profiler_rule_configs=profiler_rule_configs,
945955
profiler_config=inferred_profiler_config,
956+
remote_debug_config=remote_debug_config,
946957
environment=environment,
947958
retry_strategy=retry_strategy,
948959
)
@@ -985,6 +996,7 @@ def _get_train_request( # noqa: C901
985996
enable_sagemaker_metrics=None,
986997
profiler_rule_configs=None,
987998
profiler_config=None,
999+
remote_debug_config=None,
9881000
environment=None,
9891001
retry_strategy=None,
9901002
):
@@ -1096,6 +1108,15 @@ def _get_train_request( # noqa: C901
10961108
profiler_rule_configs (list[dict]): A list of profiler rule configurations.
10971109
profiler_config(dict): Configuration for how profiling information is emitted with
10981110
SageMaker Profiler. (default: ``None``).
1111+
remote_debug_config(dict): Configuration for RemoteDebug. (default: ``None``)
1112+
The dict can contain 'EnableRemoteDebug'(bool).
1113+
For example,
1114+
1115+
.. code:: python
1116+
1117+
remote_debug_config = {
1118+
"EnableRemoteDebug": True,
1119+
}
10991120
environment (dict[str, str]) : Environment variables to be set for
11001121
use during training job (default: ``None``)
11011122
retry_strategy(dict): Defines RetryStrategy for InternalServerFailures.
@@ -1199,6 +1220,9 @@ def _get_train_request( # noqa: C901
11991220
if profiler_config is not None:
12001221
train_request["ProfilerConfig"] = profiler_config
12011222

1223+
if remote_debug_config is not None:
1224+
train_request["RemoteDebugConfig"] = remote_debug_config
1225+
12021226
if retry_strategy is not None:
12031227
train_request["RetryStrategy"] = retry_strategy
12041228

@@ -1210,6 +1234,7 @@ def update_training_job(
12101234
profiler_rule_configs=None,
12111235
profiler_config=None,
12121236
resource_config=None,
1237+
remote_debug_config=None,
12131238
):
12141239
"""Calls the UpdateTrainingJob API for the given job name and returns the response.
12151240
@@ -1221,6 +1246,15 @@ def update_training_job(
12211246
resource_config (dict): Configuration of the resources for the training job. You can
12221247
update the keep-alive period if the warm pool status is `Available`. No other fields
12231248
can be updated. (default: ``None``).
1249+
remote_debug_config(dict): Configuration for RemoteDebug. (default: ``None``)
1250+
The dict can contain 'EnableRemoteDebug'(bool).
1251+
For example,
1252+
1253+
.. code:: python
1254+
1255+
remote_debug_config = {
1256+
"EnableRemoteDebug": True,
1257+
}
12241258
"""
12251259
# No injections from sagemaker_config because the UpdateTrainingJob API's resource_config
12261260
# object accepts fewer parameters than the CreateTrainingJob API, and none that the
@@ -1233,6 +1267,7 @@ def update_training_job(
12331267
profiler_rule_configs=profiler_rule_configs,
12341268
profiler_config=inferred_profiler_config,
12351269
resource_config=resource_config,
1270+
remote_debug_config=remote_debug_config,
12361271
)
12371272
LOGGER.info("Updating training job with name %s", job_name)
12381273
LOGGER.debug("Update request: %s", json.dumps(update_training_job_request, indent=4))
@@ -1244,6 +1279,7 @@ def _get_update_training_job_request(
12441279
profiler_rule_configs=None,
12451280
profiler_config=None,
12461281
resource_config=None,
1282+
remote_debug_config=None,
12471283
):
12481284
"""Constructs a request compatible for updating an Amazon SageMaker training job.
12491285
@@ -1255,6 +1291,15 @@ def _get_update_training_job_request(
12551291
resource_config (dict): Configuration of the resources for the training job. You can
12561292
update the keep-alive period if the warm pool status is `Available`. No other fields
12571293
can be updated. (default: ``None``).
1294+
remote_debug_config(dict): Configuration for RemoteDebug. (default: ``None``)
1295+
The dict can contain 'EnableRemoteDebug'(bool).
1296+
For example,
1297+
1298+
.. code:: python
1299+
1300+
remote_debug_config = {
1301+
"EnableRemoteDebug": True,
1302+
}
12581303
12591304
Returns:
12601305
Dict: an update training request dict
@@ -1272,6 +1317,9 @@ def _get_update_training_job_request(
12721317
if resource_config is not None:
12731318
update_training_job_request["ResourceConfig"] = resource_config
12741319

1320+
if remote_debug_config is not None:
1321+
update_training_job_request["RemoteDebugConfig"] = remote_debug_config
1322+
12751323
return update_training_job_request
12761324

12771325
def process(

tests/unit/test_estimator.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1994,6 +1994,43 @@ def test_sagemaker_model_custom_channel_name(sagemaker_session):
19941994
]
19951995

19961996

1997+
def test_framework_with_remote_debug_config(sagemaker_session):
    """Fitting with ``enable_remote_debug=True`` passes a truthy flag to ``train``."""
    instance_groups = [
        InstanceGroup("group1", "ml.c4.xlarge", 1),
        InstanceGroup("group2", "ml.m4.xlarge", 2),
    ]
    framework = DummyFramework(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        instance_groups=instance_groups,
        enable_remote_debug=True,
    )
    framework.fit("s3://mydata")

    sagemaker_session.train.assert_called_once()
    # call_args is an (args, kwargs) pair; only the keyword arguments matter here.
    train_kwargs = sagemaker_session.train.call_args[1]
    assert train_kwargs["remote_debug_config"]["EnableRemoteDebug"]
2012+
2013+
2014+
def test_framework_update_remote_debug(sagemaker_session):
    """``update_remote_debug(False)`` triggers one update carrying the disabled flag."""
    framework = DummyFramework(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE,
        enable_remote_debug=True,
    )
    framework.fit("s3://mydata")
    framework.update_remote_debug(False)

    update_mock = sagemaker_session.update_training_job
    update_mock.assert_called_once()
    # call_args is an (args, kwargs) pair; only the keyword arguments matter here.
    update_kwargs = update_mock.call_args[1]
    assert update_kwargs["remote_debug_config"] == {"EnableRemoteDebug": False}
    # Exactly two keyword arguments are expected; presumably the job name
    # alongside remote_debug_config -- TODO confirm against _get_update_args.
    assert len(update_kwargs) == 2
2032+
2033+
19972034
@patch("time.strftime", return_value=TIMESTAMP)
19982035
def test_custom_code_bucket(time, sagemaker_session):
19992036
code_bucket = "codebucket"

tests/unit/test_session.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1876,6 +1876,15 @@ def test_update_training_job_with_sagemaker_config_injection(sagemaker_session):
18761876
)
18771877

18781878

1879+
def test_update_training_job_with_remote_debug_config(sagemaker_session):
    """``update_training_job`` forwards ``RemoteDebugConfig`` in the first client call."""
    sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_TRAINING_JOB
    sagemaker_session.update_training_job(
        job_name="MyTestJob",
        remote_debug_config={"EnableRemoteDebug": False},
    )

    # Each recorded call is a (name, args, kwargs) triple; take the kwargs.
    first_call = sagemaker_session.sagemaker_client.method_calls[0]
    request_kwargs = first_call[2]
    assert not request_kwargs["RemoteDebugConfig"]["EnableRemoteDebug"]
1886+
1887+
18791888
def test_train_with_sagemaker_config_injection(sagemaker_session):
18801889
sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_TRAINING_JOB
18811890

@@ -2128,6 +2137,7 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session):
21282137
}
21292138
CONTAINER_ENTRY_POINT = ["bin/bash", "test.sh"]
21302139
CONTAINER_ARGUMENTS = ["--arg1", "value1", "--arg2", "value2"]
2140+
remote_debug_config = {"EnableRemoteDebug": True}
21312141

21322142
sagemaker_session.train(
21332143
image_uri=IMAGE,
@@ -2152,6 +2162,7 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session):
21522162
training_image_config=TRAINING_IMAGE_CONFIG,
21532163
container_entry_point=CONTAINER_ENTRY_POINT,
21542164
container_arguments=CONTAINER_ARGUMENTS,
2165+
remote_debug_config=remote_debug_config,
21552166
)
21562167

21572168
_, _, actual_train_args = sagemaker_session.sagemaker_client.method_calls[0]
@@ -2174,6 +2185,7 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session):
21742185
actual_train_args["AlgorithmSpecification"]["ContainerEntrypoint"] == CONTAINER_ENTRY_POINT
21752186
)
21762187
assert actual_train_args["AlgorithmSpecification"]["ContainerArguments"] == CONTAINER_ARGUMENTS
2188+
assert actual_train_args["RemoteDebugConfig"]["EnableRemoteDebug"]
21772189

21782190

21792191
def test_create_transform_job_with_sagemaker_config_injection(sagemaker_session):

0 commit comments

Comments
 (0)