Skip to content

Commit 55ddded

Browse files
authored
Merge branch 'master' into master
2 parents 5c34b9d + ddd06bb commit 55ddded

File tree

20 files changed

+412
-42
lines changed

20 files changed

+412
-42
lines changed

src/sagemaker/chainer/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,8 @@ def __init__(
148148

149149
def register(
150150
self,
151-
content_types: List[Union[str, PipelineVariable]],
152-
response_types: List[Union[str, PipelineVariable]],
151+
content_types: List[Union[str, PipelineVariable]] = None,
152+
response_types: List[Union[str, PipelineVariable]] = None,
153153
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
154154
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
155155
model_package_name: Optional[Union[str, PipelineVariable]] = None,

src/sagemaker/estimator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1665,8 +1665,8 @@ def deploy(
16651665

16661666
def register(
16671667
self,
1668-
content_types,
1669-
response_types,
1668+
content_types=None,
1669+
response_types=None,
16701670
inference_instances=None,
16711671
transform_instances=None,
16721672
image_uri=None,

src/sagemaker/huggingface/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -332,8 +332,8 @@ def deploy(
332332

333333
def register(
334334
self,
335-
content_types: List[Union[str, PipelineVariable]],
336-
response_types: List[Union[str, PipelineVariable]],
335+
content_types: List[Union[str, PipelineVariable]] = None,
336+
response_types: List[Union[str, PipelineVariable]] = None,
337337
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
338338
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
339339
model_package_name: Optional[Union[str, PipelineVariable]] = None,

src/sagemaker/jumpstart/factory/model.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
"""This module stores JumpStart Model factory methods."""
1414
from __future__ import absolute_import
15+
import json
1516

1617

1718
from typing import Any, Dict, List, Optional, Union
@@ -206,9 +207,7 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel
206207
def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
207208
"""Sets model data based on default or override, returns full kwargs."""
208209

209-
model_data = kwargs.model_data
210-
211-
kwargs.model_data = model_data or model_uris.retrieve(
210+
model_data: Union[str, dict] = kwargs.model_data or model_uris.retrieve(
212211
model_scope=JumpStartScriptScope.INFERENCE,
213212
model_id=kwargs.model_id,
214213
model_version=kwargs.model_version,
@@ -218,6 +217,25 @@ def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode
218217
sagemaker_session=kwargs.sagemaker_session,
219218
)
220219

220+
if isinstance(model_data, str) and model_data.startswith("s3://") and model_data.endswith("/"):
221+
old_model_data_str = model_data
222+
model_data = {
223+
"S3DataSource": {
224+
"S3Uri": model_data,
225+
"S3DataType": "S3Prefix",
226+
"CompressionType": "None",
227+
}
228+
}
229+
if kwargs.model_data:
230+
JUMPSTART_LOGGER.info(
231+
"S3 prefix model_data detected for JumpStartModel: '%s'. "
232+
"Converting to S3DataSource dictionary: '%s'.",
233+
old_model_data_str,
234+
json.dumps(model_data),
235+
)
236+
237+
kwargs.model_data = model_data
238+
221239
return kwargs
222240

223241

@@ -496,7 +514,7 @@ def get_init_kwargs(
496514
instance_type: Optional[str] = None,
497515
region: Optional[str] = None,
498516
image_uri: Optional[Union[str, PipelineVariable]] = None,
499-
model_data: Optional[Union[str, PipelineVariable]] = None,
517+
model_data: Optional[Union[str, PipelineVariable, dict]] = None,
500518
role: Optional[str] = None,
501519
predictor_cls: Optional[callable] = None,
502520
env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,

src/sagemaker/jumpstart/model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def __init__(
5353
region: Optional[str] = None,
5454
instance_type: Optional[str] = None,
5555
image_uri: Optional[Union[str, PipelineVariable]] = None,
56-
model_data: Optional[Union[str, PipelineVariable]] = None,
56+
model_data: Optional[Union[str, PipelineVariable, dict]] = None,
5757
role: Optional[str] = None,
5858
predictor_cls: Optional[callable] = None,
5959
env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
@@ -95,8 +95,8 @@ def __init__(
9595
instance_type (Optional[str]): The EC2 instance type to use when provisioning a hosting
9696
endpoint. (Default: None).
9797
image_uri (Optional[Union[str, PipelineVariable]]): A Docker image URI. (Default: None).
98-
model_data (Optional[Union[str, PipelineVariable]]): The S3 location of a SageMaker
99-
model data ``.tar.gz`` file. (Default: None).
98+
model_data (Optional[Union[str, PipelineVariable, dict]]): Location
99+
of SageMaker model data. (Default: None).
100100
role (Optional[str]): An AWS IAM role (either name or full ARN). The Amazon
101101
SageMaker training jobs and APIs that create Amazon SageMaker
102102
endpoints use this role to access training data and model

src/sagemaker/jumpstart/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -752,7 +752,7 @@ def __init__(
752752
region: Optional[str] = None,
753753
instance_type: Optional[str] = None,
754754
image_uri: Optional[Union[str, Any]] = None,
755-
model_data: Optional[Union[str, Any]] = None,
755+
model_data: Optional[Union[str, Any, dict]] = None,
756756
role: Optional[str] = None,
757757
predictor_cls: Optional[callable] = None,
758758
env: Optional[Dict[str, Union[str, Any]]] = None,

src/sagemaker/model.py

Lines changed: 69 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
ENDPOINT_CONFIG_ASYNC_KMS_KEY_ID_PATH,
4444
load_sagemaker_config,
4545
)
46+
from sagemaker.model_card.schema_constraints import ModelApprovalStatusEnum
4647
from sagemaker.session import Session
4748
from sagemaker.model_metrics import ModelMetrics
4849
from sagemaker.deprecations import removed_kwargs
@@ -374,12 +375,14 @@ def __init__(
374375
self.dependencies = updates["dependencies"]
375376
self.uploaded_code = None
376377
self.repacked_model_data = None
378+
self.content_types = None
379+
self.response_types = None
377380

378381
@runnable_by_pipeline
379382
def register(
380383
self,
381-
content_types: List[Union[str, PipelineVariable]],
382-
response_types: List[Union[str, PipelineVariable]],
384+
content_types: List[Union[str, PipelineVariable]] = None,
385+
response_types: List[Union[str, PipelineVariable]] = None,
383386
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
384387
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
385388
model_package_name: Optional[Union[str, PipelineVariable]] = None,
@@ -456,16 +459,33 @@ def register(
456459
in case the Model instance is built with
457460
:class:`~sagemaker.workflow.pipeline_context.PipelineSession`
458461
"""
459-
if self.model_data is None:
460-
raise ValueError("SageMaker Model Package cannot be created without model data.")
461462
if isinstance(self.model_data, dict):
462463
raise ValueError(
463464
"SageMaker Model Package currently cannot be created with ModelDataSource."
464465
)
465466

467+
if content_types is not None:
468+
self.content_types = content_types
469+
470+
if response_types is not None:
471+
self.response_types = response_types
472+
473+
if self.content_types is None:
474+
raise ValueError("The supported MIME types for the input data is not set")
475+
476+
if self.response_types is None:
477+
raise ValueError("The supported MIME types for the output data is not set")
478+
466479
if image_uri is not None:
467480
self.image_uri = image_uri
468481

482+
if model_package_group_name is None and model_package_name is None:
483+
# If model package group and model package name is not set
484+
# then register to auto-generated model package group
485+
model_package_group_name = utils.base_name_from_image(
486+
self.image_uri, default_base_name=ModelPackage.__name__
487+
)
488+
469489
if model_package_group_name is not None:
470490
container_def = self.prepare_container_def()
471491
container_def = update_container_with_inference_params(
@@ -478,12 +498,14 @@ def register(
478498
else:
479499
container_def = {
480500
"Image": self.image_uri,
481-
"ModelDataUrl": self.model_data,
482501
}
483502

503+
if self.model_data is not None:
504+
container_def["ModelDataUrl"] = self.model_data
505+
484506
model_pkg_args = sagemaker.get_model_package_args(
485-
content_types,
486-
response_types,
507+
self.content_types,
508+
self.response_types,
487509
inference_instances=inference_instances,
488510
transform_instances=transform_instances,
489511
model_package_name=model_package_name,
@@ -511,6 +533,7 @@ def register(
511533
role=self.role,
512534
model_data=self.model_data,
513535
model_package_arn=model_package.get("ModelPackageArn"),
536+
sagemaker_session=self.sagemaker_session,
514537
)
515538

516539
@runnable_by_pipeline
@@ -1751,6 +1774,7 @@ def __init__(
17511774

17521775
# works for MODEL_PACKAGE_ARN with or without version info.
17531776
MODEL_PACKAGE_ARN_PATTERN = r"arn:aws:sagemaker:(.*?):(.*?):model-package/(.*?)(?:/(\d+))?$"
1777+
MODEL_PACKAGE_VERSIONED_ARN_PATTERN = r"arn:aws:sagemaker:(.*?):(.*?):model-package/(.*?)/(\d+)$"
17541778

17551779

17561780
class ModelPackage(Model):
@@ -1885,6 +1909,18 @@ def _create_sagemaker_model(self, *args, **kwargs): # pylint: disable=unused-ar
18851909
self._ensure_base_name_if_needed(model_package_name)
18861910
self._set_model_name_if_needed()
18871911

1912+
# Quering the approval status for the model package
1913+
# Approving the versioned model package in case it is not approved
1914+
model_package_desc = self.sagemaker_session.sagemaker_client.describe_model_package(
1915+
ModelPackageName=self.model_package_arn or model_package_name
1916+
)
1917+
if self.model_package_arn is None:
1918+
self.model_package_arn = model_package_desc["ModelPackageArn"]
1919+
if re.match(MODEL_PACKAGE_VERSIONED_ARN_PATTERN, self.model_package_arn):
1920+
approval_status = model_package_desc.get("ModelApprovalStatus", "")
1921+
if approval_status != ModelApprovalStatusEnum.APPROVED:
1922+
self.update_approval_status(approval_status=ModelApprovalStatusEnum.APPROVED)
1923+
18881924
self.sagemaker_session.create_model(
18891925
self.name,
18901926
self.role,
@@ -1898,3 +1934,29 @@ def _ensure_base_name_if_needed(self, base_name):
18981934
"""Set the base name if there is no model name provided."""
18991935
if self.name is None:
19001936
self._base_name = base_name
1937+
1938+
def update_approval_status(self, approval_status, approval_description=None):
1939+
"""Update the approval status for the model package
1940+
1941+
Args:
1942+
approval_status (str or PipelineVariable): Model Approval Status, values can be
1943+
"Approved", "Rejected", or "PendingManualApproval".
1944+
approval_description (str): Optional. Description for the approval status of the model
1945+
(default: None).
1946+
"""
1947+
1948+
# Models can lazy-init sagemaker_session until deploy() is called to support
1949+
# LocalMode so we must make sure we have an actual session
1950+
sagemaker_session = self.sagemaker_session or sagemaker.Session()
1951+
if self.model_package_arn is None:
1952+
raise ValueError("model_package_arn is required to update the status.")
1953+
1954+
update_approval_args = {
1955+
"ModelPackageArn": self.model_package_arn,
1956+
"ModelApprovalStatus": approval_status,
1957+
}
1958+
1959+
if approval_description is not None:
1960+
update_approval_args["ApprovalDescription"] = approval_description
1961+
1962+
sagemaker_session.sagemaker_client.update_model_package(**update_approval_args)

src/sagemaker/mxnet/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,8 @@ def __init__(
150150

151151
def register(
152152
self,
153-
content_types: List[Union[str, PipelineVariable]],
154-
response_types: List[Union[str, PipelineVariable]],
153+
content_types: List[Union[str, PipelineVariable]] = None,
154+
response_types: List[Union[str, PipelineVariable]] = None,
155155
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
156156
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
157157
model_package_name: Optional[Union[str, PipelineVariable]] = None,

src/sagemaker/pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -335,8 +335,8 @@ def _create_sagemaker_pipeline_model(self, instance_type):
335335
@runnable_by_pipeline
336336
def register(
337337
self,
338-
content_types: List[Union[str, PipelineVariable]],
339-
response_types: List[Union[str, PipelineVariable]],
338+
content_types: List[Union[str, PipelineVariable]] = None,
339+
response_types: List[Union[str, PipelineVariable]] = None,
340340
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
341341
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
342342
model_package_name: Optional[Union[str, PipelineVariable]] = None,

src/sagemaker/pytorch/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,8 @@ def __init__(
152152

153153
def register(
154154
self,
155-
content_types: List[Union[str, PipelineVariable]],
156-
response_types: List[Union[str, PipelineVariable]],
155+
content_types: List[Union[str, PipelineVariable]] = None,
156+
response_types: List[Union[str, PipelineVariable]] = None,
157157
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
158158
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
159159
model_package_name: Optional[Union[str, PipelineVariable]] = None,

src/sagemaker/session.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5830,8 +5830,8 @@ def wait_for_inference_recommendations_job(
58305830

58315831

58325832
def get_model_package_args(
5833-
content_types,
5834-
response_types,
5833+
content_types=None,
5834+
response_types=None,
58355835
inference_instances=None,
58365836
transform_instances=None,
58375837
model_package_name=None,
@@ -5899,19 +5899,23 @@ def get_model_package_args(
58995899
else:
59005900
container = {
59015901
"Image": image_uri,
5902-
"ModelDataUrl": model_data,
59035902
}
5903+
if model_data is not None:
5904+
container["ModelDataUrl"] = model_data
5905+
59045906
containers = [container]
59055907

59065908
model_package_args = {
59075909
"containers": containers,
5908-
"content_types": content_types,
5909-
"response_types": response_types,
59105910
"inference_instances": inference_instances,
59115911
"transform_instances": transform_instances,
59125912
"marketplace_cert": marketplace_cert,
59135913
}
59145914

5915+
if content_types is not None:
5916+
model_package_args["content_types"] = content_types
5917+
if response_types is not None:
5918+
model_package_args["response_types"] = response_types
59155919
if model_package_name is not None:
59165920
model_package_args["model_package_name"] = model_package_name
59175921
if model_package_group_name is not None:

src/sagemaker/sklearn/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,8 @@ def __init__(
145145

146146
def register(
147147
self,
148-
content_types: List[Union[str, PipelineVariable]],
149-
response_types: List[Union[str, PipelineVariable]],
148+
content_types: List[Union[str, PipelineVariable]] = None,
149+
response_types: List[Union[str, PipelineVariable]] = None,
150150
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
151151
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
152152
model_package_name: Optional[Union[str, PipelineVariable]] = None,

src/sagemaker/tensorflow/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,8 +207,8 @@ def __init__(
207207

208208
def register(
209209
self,
210-
content_types: List[Union[str, PipelineVariable]],
211-
response_types: List[Union[str, PipelineVariable]],
210+
content_types: List[Union[str, PipelineVariable]] = None,
211+
response_types: List[Union[str, PipelineVariable]] = None,
212212
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
213213
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
214214
model_package_name: Optional[Union[str, PipelineVariable]] = None,

src/sagemaker/workflow/_utils.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -443,9 +443,6 @@ def arguments(self) -> RequestType:
443443
model = self.estimator.create_model(**self.kwargs)
444444
self.image_uri = model.image_uri
445445

446-
if self.model_data is None:
447-
self.model_data = model.model_data
448-
449446
# reset placeholder
450447
self.estimator.output_path = output_path
451448

src/sagemaker/xgboost/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,8 @@ def __init__(
133133

134134
def register(
135135
self,
136-
content_types: List[Union[str, PipelineVariable]],
137-
response_types: List[Union[str, PipelineVariable]],
136+
content_types: List[Union[str, PipelineVariable]] = None,
137+
response_types: List[Union[str, PipelineVariable]] = None,
138138
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
139139
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
140140
model_package_name: Optional[Union[str, PipelineVariable]] = None,

0 commit comments

Comments
 (0)