Skip to content

Commit d1dd8ec

Browse files
evakravixixinyu-aws
authored andcommitted
chore: add jumpstart support for remote debug
1 parent fedd74e commit d1dd8ec

File tree

4 files changed

+10
-1
lines changed

4 files changed

+10
-1
lines changed

src/sagemaker/jumpstart/estimator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def __init__(
106106
container_entry_point: Optional[List[str]] = None,
107107
container_arguments: Optional[List[str]] = None,
108108
disable_output_compression: Optional[bool] = None,
109+
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
109110
):
110111
"""Initializes a ``JumpStartEstimator``.
111112
@@ -495,6 +496,8 @@ def __init__(
495496
a training job.
496497
disable_output_compression (Optional[bool]): When set to true, Model is uploaded
497498
to Amazon S3 without compression after training finishes.
499+
enable_remote_debug (bool or PipelineVariable): Optional.
500+
Specifies whether RemoteDebug is enabled for the training job
498501
499502
Raises:
500503
ValueError: If the model ID is not recognized by JumpStart.
@@ -569,6 +572,7 @@ def _is_valid_model_id_hook():
569572
container_arguments=container_arguments,
570573
disable_output_compression=disable_output_compression,
571574
enable_infra_check=enable_infra_check,
575+
enable_remote_debug=enable_remote_debug,
572576
)
573577

574578
self.model_id = estimator_init_kwargs.model_id

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ def get_init_kwargs(
127127
container_arguments: Optional[List[str]] = None,
128128
disable_output_compression: Optional[bool] = None,
129129
enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
130+
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
130131
) -> JumpStartEstimatorInitKwargs:
131132
"""Returns kwargs required to instantiate `sagemaker.estimator.Estimator` object."""
132133

@@ -183,6 +184,7 @@ def get_init_kwargs(
183184
container_arguments=container_arguments,
184185
disable_output_compression=disable_output_compression,
185186
enable_infra_check=enable_infra_check,
187+
enable_remote_debug=enable_remote_debug,
186188
)
187189

188190
estimator_init_kwargs = _add_model_version_to_kwargs(estimator_init_kwargs)

src/sagemaker/jumpstart/types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1280,6 +1280,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs):
12801280
"container_arguments",
12811281
"disable_output_compression",
12821282
"enable_infra_check",
1283+
"enable_remote_debug",
12831284
]
12841285

12851286
SERIALIZATION_EXCLUSION_SET = {
@@ -1344,6 +1345,7 @@ def __init__(
13441345
container_arguments: Optional[List[str]] = None,
13451346
disable_output_compression: Optional[bool] = None,
13461347
enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
1348+
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
13471349
) -> None:
13481350
"""Instantiates JumpStartEstimatorInitKwargs object."""
13491351

@@ -1401,6 +1403,7 @@ def __init__(
14011403
self.container_arguments = container_arguments
14021404
self.disable_output_compression = disable_output_compression
14031405
self.enable_infra_check = enable_infra_check
1406+
self.enable_remote_debug = enable_remote_debug
14041407

14051408

14061409
class JumpStartEstimatorFitKwargs(JumpStartKwargs):

tests/unit/sagemaker/jumpstart/estimator/test_estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -977,7 +977,7 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self):
977977
Please add the new argument to the skip set below,
978978
and reach out to JumpStart team."""
979979

980-
init_args_to_skip: Set[str] = set(["kwargs", "enable_remote_debug"])
980+
init_args_to_skip: Set[str] = set(["kwargs"])
981981
fit_args_to_skip: Set[str] = set()
982982
deploy_args_to_skip: Set[str] = set(["kwargs"])
983983

0 commit comments

Comments
 (0)