File tree Expand file tree Collapse file tree 4 files changed +10
-1
lines changed
tests/unit/sagemaker/jumpstart/estimator Expand file tree Collapse file tree 4 files changed +10
-1
lines changed Original file line number Diff line number Diff line change @@ -106,6 +106,7 @@ def __init__(
106
106
container_entry_point : Optional [List [str ]] = None ,
107
107
container_arguments : Optional [List [str ]] = None ,
108
108
disable_output_compression : Optional [bool ] = None ,
109
+ enable_remote_debug : Optional [Union [bool , PipelineVariable ]] = None ,
109
110
):
110
111
"""Initializes a ``JumpStartEstimator``.
111
112
@@ -495,6 +496,8 @@ def __init__(
495
496
a training job.
496
497
disable_output_compression (Optional[bool]): When set to true, Model is uploaded
497
498
to Amazon S3 without compression after training finishes.
499
+ enable_remote_debug (bool or PipelineVariable): Optional.
500
+ Specifies whether RemoteDebug is enabled for the training job
498
501
499
502
Raises:
500
503
ValueError: If the model ID is not recognized by JumpStart.
@@ -569,6 +572,7 @@ def _is_valid_model_id_hook():
569
572
container_arguments = container_arguments ,
570
573
disable_output_compression = disable_output_compression ,
571
574
enable_infra_check = enable_infra_check ,
575
+ enable_remote_debug = enable_remote_debug ,
572
576
)
573
577
574
578
self .model_id = estimator_init_kwargs .model_id
Original file line number Diff line number Diff line change @@ -127,6 +127,7 @@ def get_init_kwargs(
127
127
container_arguments : Optional [List [str ]] = None ,
128
128
disable_output_compression : Optional [bool ] = None ,
129
129
enable_infra_check : Optional [Union [bool , PipelineVariable ]] = None ,
130
+ enable_remote_debug : Optional [Union [bool , PipelineVariable ]] = None ,
130
131
) -> JumpStartEstimatorInitKwargs :
131
132
"""Returns kwargs required to instantiate `sagemaker.estimator.Estimator` object."""
132
133
@@ -183,6 +184,7 @@ def get_init_kwargs(
183
184
container_arguments = container_arguments ,
184
185
disable_output_compression = disable_output_compression ,
185
186
enable_infra_check = enable_infra_check ,
187
+ enable_remote_debug = enable_remote_debug ,
186
188
)
187
189
188
190
estimator_init_kwargs = _add_model_version_to_kwargs (estimator_init_kwargs )
Original file line number Diff line number Diff line change @@ -1280,6 +1280,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs):
1280
1280
"container_arguments" ,
1281
1281
"disable_output_compression" ,
1282
1282
"enable_infra_check" ,
1283
+ "enable_remote_debug" ,
1283
1284
]
1284
1285
1285
1286
SERIALIZATION_EXCLUSION_SET = {
@@ -1344,6 +1345,7 @@ def __init__(
1344
1345
container_arguments : Optional [List [str ]] = None ,
1345
1346
disable_output_compression : Optional [bool ] = None ,
1346
1347
enable_infra_check : Optional [Union [bool , PipelineVariable ]] = None ,
1348
+ enable_remote_debug : Optional [Union [bool , PipelineVariable ]] = None ,
1347
1349
) -> None :
1348
1350
"""Instantiates JumpStartEstimatorInitKwargs object."""
1349
1351
@@ -1401,6 +1403,7 @@ def __init__(
1401
1403
self .container_arguments = container_arguments
1402
1404
self .disable_output_compression = disable_output_compression
1403
1405
self .enable_infra_check = enable_infra_check
1406
+ self .enable_remote_debug = enable_remote_debug
1404
1407
1405
1408
1406
1409
class JumpStartEstimatorFitKwargs (JumpStartKwargs ):
Original file line number Diff line number Diff line change @@ -977,7 +977,7 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self):
977
977
Please add the new argument to the skip set below,
978
978
and reach out to JumpStart team."""
979
979
980
- init_args_to_skip : Set [str ] = set (["kwargs" , "enable_remote_debug" ])
980
+ init_args_to_skip : Set [str ] = set (["kwargs" ])
981
981
fit_args_to_skip : Set [str ] = set ()
982
982
deploy_args_to_skip : Set [str ] = set (["kwargs" ])
983
983
You can’t perform that action at this time.
0 commit comments