Skip to content

Commit 7f33995

Browse files
authored
fix: Missing JumpStart estimator args (#3978)
1 parent 6c1a3a1 commit 7f33995

File tree

4 files changed

+30
-4
lines changed

4 files changed

+30
-4
lines changed

src/sagemaker/jumpstart/estimator.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,9 @@ def __init__(
104104
instance_groups: Optional[List[InstanceGroup]] = None,
105105
training_repository_access_mode: Optional[Union[str, PipelineVariable]] = None,
106106
training_repository_credentials_provider_arn: Optional[Union[str, PipelineVariable]] = None,
107+
container_entry_point: Optional[List[str]] = None,
108+
container_arguments: Optional[List[str]] = None,
109+
disable_output_compression: Optional[bool] = None,
107110
):
108111
"""Initializes a ``JumpStartEstimator``.
109112
@@ -484,6 +487,13 @@ def __init__(
484487
private Docker registry where your training image is hosted (Default: None).
485488
When it's set to None, SageMaker will not do authentication before pulling the image
486489
in the private Docker registry. (Default: None).
490+
container_entry_point (Optional[List[str]]): The entrypoint script for a Docker
491+
container used to run a training job. This script takes precedence over
492+
the default train processing instructions.
493+
container_arguments (Optional[List[str]]): The arguments for a container used to run
494+
a training job.
495+
disable_output_compression (Optional[bool]): When set to true, Model is uploaded
496+
to Amazon S3 without compression after training finishes.
487497
488498
Raises:
489499
ValueError: If the model ID is not recognized by JumpStart.
@@ -553,6 +563,9 @@ def _is_valid_model_id_hook():
553563
training_repository_credentials_provider_arn
554564
),
555565
image_uri=image_uri,
566+
container_entry_point=container_entry_point,
567+
container_arguments=container_arguments,
568+
disable_output_compression=disable_output_compression,
556569
)
557570

558571
self.model_id = estimator_init_kwargs.model_id

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,9 @@ def get_init_kwargs(
116116
instance_groups: Optional[List[InstanceGroup]] = None,
117117
training_repository_access_mode: Optional[Union[str, PipelineVariable]] = None,
118118
training_repository_credentials_provider_arn: Optional[Union[str, PipelineVariable]] = None,
119+
container_entry_point: Optional[List[str]] = None,
120+
container_arguments: Optional[List[str]] = None,
121+
disable_output_compression: Optional[bool] = None,
119122
) -> JumpStartEstimatorInitKwargs:
120123
"""Returns kwargs required to instantiate `sagemaker.estimator.Estimator` object."""
121124

@@ -168,6 +171,9 @@ def get_init_kwargs(
168171
training_repository_access_mode=training_repository_access_mode,
169172
training_repository_credentials_provider_arn=training_repository_credentials_provider_arn,
170173
image_uri=image_uri,
174+
container_entry_point=container_entry_point,
175+
container_arguments=container_arguments,
176+
disable_output_compression=disable_output_compression,
171177
)
172178

173179
estimator_init_kwargs = _add_model_version_to_kwargs(estimator_init_kwargs)

src/sagemaker/jumpstart/types.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -777,6 +777,9 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs):
777777
"training_repository_credentials_provider_arn",
778778
"tolerate_deprecated_model",
779779
"tolerate_vulnerable_model",
780+
"container_entry_point",
781+
"container_arguments",
782+
"disable_output_compression",
780783
]
781784

782785
SERIALIZATION_EXCLUSION_SET = {
@@ -837,6 +840,9 @@ def __init__(
837840
training_repository_credentials_provider_arn: Optional[Union[str, Any]] = None,
838841
tolerate_vulnerable_model: Optional[bool] = None,
839842
tolerate_deprecated_model: Optional[bool] = None,
843+
container_entry_point: Optional[List[str]] = None,
844+
container_arguments: Optional[List[str]] = None,
845+
disable_output_compression: Optional[bool] = None,
840846
) -> None:
841847
"""Instantiates JumpStartEstimatorInitKwargs object."""
842848

@@ -890,6 +896,9 @@ def __init__(
890896
)
891897
self.tolerate_vulnerable_model = tolerate_vulnerable_model
892898
self.tolerate_deprecated_model = tolerate_deprecated_model
899+
self.container_entry_point = container_entry_point
900+
self.container_arguments = container_arguments
901+
self.disable_output_compression = disable_output_compression
893902

894903

895904
class JumpStartEstimatorFitKwargs(JumpStartKwargs):

tests/unit/sagemaker/jumpstart/estimator/test_estimator.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -520,11 +520,9 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self):
520520

521521
"""If you add arguments to <Estimator constructor>, this test will fail.
522522
Please add the new argument to the skip set below,
523-
and cut a ticket sev-3 to JumpStart team: AWS > SageMaker > JumpStart"""
523+
and reach out to JumpStart team."""
524524

525-
init_args_to_skip: Set[str] = set(
526-
["container_entry_point", "container_arguments", "disable_output_compression", "kwargs"]
527-
)
525+
init_args_to_skip: Set[str] = set(["kwargs"])
528526
fit_args_to_skip: Set[str] = set()
529527
deploy_args_to_skip: Set[str] = set(["kwargs"])
530528

0 commit comments

Comments
 (0)