Skip to content

Commit 34f6fb1

Browse files
author
Keshav Chandak
committed
Added update for model package
1 parent 0e664c8 commit 34f6fb1

File tree

4 files changed

+293
-13
lines changed

4 files changed

+293
-13
lines changed

src/sagemaker/model.py

Lines changed: 95 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
)
7676
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
7777
from sagemaker.enums import EndpointType
78+
from sagemaker.session import get_add_model_package_inference_args
7879

7980
LOGGER = logging.getLogger("sagemaker")
8081

@@ -485,12 +486,6 @@ def register(
485486
if response_types is not None:
486487
self.response_types = response_types
487488

488-
if self.content_types is None:
489-
raise ValueError("The supported MIME types for the input data is not set")
490-
491-
if self.response_types is None:
492-
raise ValueError("The supported MIME types for the output data is not set")
493-
494489
if image_uri is not None:
495490
self.image_uri = image_uri
496491

@@ -2152,7 +2147,7 @@ def update_approval_status(self, approval_status, approval_description=None):
21522147
"""Update the approval status for the model package
21532148
21542149
Args:
2155-
approval_status (str or PipelineVariable): Model Approval Status, values can be
2150+
approval_status (str): Model Approval Status, values can be
21562151
"Approved", "Rejected", or "PendingManualApproval".
21572152
approval_description (str): Optional. Description for the approval status of the model
21582153
(default: None).
@@ -2173,3 +2168,96 @@ def update_approval_status(self, approval_status, approval_description=None):
21732168
update_approval_args["ApprovalDescription"] = approval_description
21742169

21752170
sagemaker_session.sagemaker_client.update_model_package(**update_approval_args)
2171+
2172+
def update_customer_metadata(self, customer_metadata_properties: Dict[str, str]):
2173+
"""Updating customer metadata properties for the model package
2174+
2175+
Args:
2176+
customer_metadata_properties (dict[str, str]):
2177+
A dictionary of key-value paired metadata properties (default: None).
2178+
"""
2179+
2180+
update_metadata_args = {
2181+
"ModelPackageArn": self.model_package_arn,
2182+
"CustomerMetadataProperties": customer_metadata_properties,
2183+
}
2184+
2185+
sagemaker_session = self.sagemaker_session or sagemaker.Session()
2186+
sagemaker_session.sagemaker_client.update_model_package(**update_metadata_args)
2187+
2188+
def remove_customer_metadata_properties(
2189+
self, customer_metadata_properties_to_remove: List[str]
2190+
):
2191+
"""Removes the specified keys from customer metadata properties
2192+
2193+
Args:
2194+
customer_metadata_properties (list[str, str]):
2195+
list of keys of customer metadata properties to remove.
2196+
"""
2197+
2198+
delete_metadata_args = {
2199+
"ModelPackageArn": self.model_package_arn,
2200+
"CustomerMetadataPropertiesToRemove": customer_metadata_properties_to_remove,
2201+
}
2202+
2203+
sagemaker_session = self.sagemaker_session or sagemaker.Session()
2204+
sagemaker_session.sagemaker_client.update_model_package(**delete_metadata_args)
2205+
2206+
def add_inference_specification(
2207+
self,
2208+
name: str,
2209+
containers: Dict = None,
2210+
image_uris: List[str] = None,
2211+
description: str = None,
2212+
content_types: List[str] = None,
2213+
response_types: List[str] = None,
2214+
inference_instances: List[str] = None,
2215+
transform_instances: List[str] = None,
2216+
):
2217+
"""Additional inference specification to be added for the model package
2218+
2219+
Args:
2220+
name (str): Name to identify the additional inference specification
2221+
containers (dict): The Amazon ECR registry path of the Docker image
2222+
that contains the inference code.
2223+
image_uris (List[str]): The ECR path where inference code is stored.
2224+
description (str): Description for the additional inference specification
2225+
content_types (list[str]): The supported MIME types
2226+
for the input data.
2227+
response_types (list[str]): The supported MIME types
2228+
for the output data.
2229+
inference_instances (list[str]): A list of the instance
2230+
types that are used to generate inferences in real-time (default: None).
2231+
transform_instances (list[str]): A list of the instance
2232+
types on which a transformation job can be run or on which an endpoint can be
2233+
deployed (default: None).
2234+
2235+
"""
2236+
sagemaker_session = self.sagemaker_session or sagemaker.Session()
2237+
if containers is not None and image_uris is not None:
2238+
raise ValueError("Cannot have both containers and image_uris.")
2239+
if containers is None and image_uris is None:
2240+
raise ValueError("Should have either containers or image_uris for inference.")
2241+
container_def = []
2242+
if image_uris:
2243+
for uri in image_uris:
2244+
container_def.append(
2245+
{
2246+
"Image": uri,
2247+
}
2248+
)
2249+
else:
2250+
container_def = containers
2251+
2252+
model_package_update_args = get_add_model_package_inference_args(
2253+
model_package_arn=self.model_package_arn,
2254+
name=name,
2255+
containers=container_def,
2256+
content_types=content_types,
2257+
description=description,
2258+
response_types=response_types,
2259+
inference_instances=inference_instances,
2260+
transform_instances=transform_instances,
2261+
)
2262+
2263+
sagemaker_session.sagemaker_client.update_model_package(**model_package_update_args)

src/sagemaker/session.py

Lines changed: 82 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6550,15 +6550,21 @@ def get_create_model_package_request(
65506550
if task is not None:
65516551
request_dict["Task"] = task
65526552
if containers is not None:
6553-
if not all([content_types, response_types]):
6554-
raise ValueError(
6555-
"content_types and response_types " "must be provided if containers is present."
6556-
)
65576553
inference_specification = {
65586554
"Containers": containers,
6559-
"SupportedContentTypes": content_types,
6560-
"SupportedResponseMIMETypes": response_types,
65616555
}
6556+
if content_types is not None:
6557+
inference_specification.update(
6558+
{
6559+
"SupportedContentTypes": content_types,
6560+
}
6561+
)
6562+
if response_types is not None:
6563+
inference_specification.update(
6564+
{
6565+
"SupportedResponseMIMETypes": response_types,
6566+
}
6567+
)
65626568
if model_package_group_name is not None:
65636569
if inference_instances is not None:
65646570
inference_specification.update(
@@ -6591,6 +6597,76 @@ def get_create_model_package_request(
65916597
return request_dict
65926598

65936599

6600+
def get_add_model_package_inference_args(
6601+
model_package_arn,
6602+
name,
6603+
containers=None,
6604+
content_types=None,
6605+
response_types=None,
6606+
inference_instances=None,
6607+
transform_instances=None,
6608+
description=None,
6609+
):
6610+
"""Get request dictionary for UpdateModelPackage API for additional inference.
6611+
6612+
Args:
6613+
model_package_arn (str): Arn for the model package.
6614+
name (str): Name to identify the additional inference specification
6615+
containers (dict): The Amazon ECR registry path of the Docker image
6616+
that contains the inference code.
6617+
image_uris (List[str]): The ECR path where inference code is stored.
6618+
description (str): Description for the additional inference specification
6619+
content_types (list[str]): The supported MIME types
6620+
for the input data.
6621+
response_types (list[str]): The supported MIME types
6622+
for the output data.
6623+
inference_instances (list[str]): A list of the instance
6624+
types that are used to generate inferences in real-time (default: None).
6625+
transform_instances (list[str]): A list of the instance
6626+
types on which a transformation job can be run or on which an endpoint can be
6627+
deployed (default: None).
6628+
"""
6629+
6630+
request_dict = {}
6631+
if containers is not None:
6632+
inference_specification = {
6633+
"Containers": containers,
6634+
}
6635+
6636+
if name is not None:
6637+
inference_specification.update({"Name": name})
6638+
6639+
if description is not None:
6640+
inference_specification.update({"Description": description})
6641+
if content_types is not None:
6642+
inference_specification.update(
6643+
{
6644+
"SupportedContentTypes": content_types,
6645+
}
6646+
)
6647+
if response_types is not None:
6648+
inference_specification.update(
6649+
{
6650+
"SupportedResponseMIMETypes": response_types,
6651+
}
6652+
)
6653+
if inference_instances is not None:
6654+
inference_specification.update(
6655+
{
6656+
"SupportedRealtimeInferenceInstanceTypes": inference_instances,
6657+
}
6658+
)
6659+
if transform_instances is not None:
6660+
inference_specification.update(
6661+
{
6662+
"SupportedTransformInstanceTypes": transform_instances,
6663+
}
6664+
)
6665+
request_dict["AdditionalInferenceSpecificationsToAdd"] = [inference_specification]
6666+
request_dict.update({"ModelPackageArn": model_package_arn})
6667+
return request_dict
6668+
6669+
65946670
def update_args(args: Dict[str, Any], **kwargs):
65956671
"""Updates the request arguments dict with the value if populated.
65966672

tests/integ/test_model_package.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from sagemaker.utils import unique_name_from_base
1818
from tests.integ import DATA_DIR
1919
from sagemaker.xgboost import XGBoostModel
20+
from sagemaker import image_uris
2021

2122
_XGBOOST_PATH = os.path.join(DATA_DIR, "xgboost_abalone")
2223

@@ -61,3 +62,45 @@ def test_update_approval_model_package(sagemaker_session):
6162
sagemaker_session.sagemaker_client.delete_model_package_group(
6263
ModelPackageGroupName=model_group_name
6364
)
65+
66+
67+
def test_inference_specification_addition(sagemaker_session):
68+
69+
model_group_name = unique_name_from_base("test-model-group")
70+
71+
sagemaker_session.sagemaker_client.create_model_package_group(
72+
ModelPackageGroupName=model_group_name
73+
)
74+
75+
xgb_model_data_s3 = sagemaker_session.upload_data(
76+
path=os.path.join(_XGBOOST_PATH, "xgb_model.tar.gz"),
77+
key_prefix="integ-test-data/xgboost/model",
78+
)
79+
model = XGBoostModel(
80+
model_data=xgb_model_data_s3, framework_version="1.3-1", sagemaker_session=sagemaker_session
81+
)
82+
83+
model_package = model.register(
84+
content_types=["text/csv"],
85+
response_types=["text/csv"],
86+
inference_instances=["ml.m5.large"],
87+
transform_instances=["ml.m5.large"],
88+
model_package_group_name=model_group_name,
89+
)
90+
91+
xgb_image = image_uris.retrieve(
92+
"xgboost", sagemaker_session.boto_region_name, version="1", image_scope="inference"
93+
)
94+
model_package.add_inference_specification(image_uris=[xgb_image], name="Inference")
95+
desc_model_package = sagemaker_session.sagemaker_client.describe_model_package(
96+
ModelPackageName=model_package.model_package_arn
97+
)
98+
assert len(desc_model_package["AdditionalInferenceSpecifications"]) == 1
99+
assert desc_model_package["AdditionalInferenceSpecifications"][0]["Name"] == "Inference"
100+
101+
sagemaker_session.sagemaker_client.delete_model_package(
102+
ModelPackageName=model_package.model_package_arn
103+
)
104+
sagemaker_session.sagemaker_client.delete_model_package_group(
105+
ModelPackageGroupName=model_group_name
106+
)

tests/unit/sagemaker/model/test_model_package.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,3 +326,76 @@ def test_model_package_auto_approve_on_deploy(update_approval_status, sagemaker_
326326
update_approval_status.call_args_list[0][1]["approval_status"]
327327
== ModelApprovalStatusEnum.APPROVED
328328
)
329+
330+
331+
def test_update_customer_metadata(sagemaker_session):
332+
model_package = ModelPackage(
333+
role="role",
334+
model_package_arn=MODEL_PACKAGE_VERSIONED_ARN,
335+
sagemaker_session=sagemaker_session,
336+
)
337+
338+
customer_metadata_to_update = {
339+
"Key": "Value",
340+
}
341+
model_package.update_customer_metadata(customer_metadata_properties=customer_metadata_to_update)
342+
343+
sagemaker_session.sagemaker_client.update_model_package.assert_called_with(
344+
ModelPackageArn=MODEL_PACKAGE_VERSIONED_ARN,
345+
CustomerMetadataProperties=customer_metadata_to_update,
346+
)
347+
348+
349+
def test_remove_customer_metadata(sagemaker_session):
350+
model_package = ModelPackage(
351+
role="role",
352+
model_package_arn=MODEL_PACKAGE_VERSIONED_ARN,
353+
sagemaker_session=sagemaker_session,
354+
)
355+
356+
customer_metadata_to_remove = ["Key"]
357+
358+
model_package.remove_customer_metadata_properties(
359+
customer_metadata_properties_to_remove=customer_metadata_to_remove
360+
)
361+
362+
sagemaker_session.sagemaker_client.update_model_package.assert_called_with(
363+
ModelPackageArn=MODEL_PACKAGE_VERSIONED_ARN,
364+
CustomerMetadataPropertiesToRemove=customer_metadata_to_remove,
365+
)
366+
367+
368+
def test_add_inference_specification(sagemaker_session):
369+
model_package = ModelPackage(
370+
role="role",
371+
model_package_arn=MODEL_PACKAGE_VERSIONED_ARN,
372+
sagemaker_session=sagemaker_session,
373+
)
374+
375+
image_uris = ["image_uri"]
376+
377+
containers = [{"Image": "image_uri"}]
378+
379+
try:
380+
model_package.add_inference_specification(
381+
image_uris=image_uris, name="Inference", containers=containers
382+
)
383+
except ValueError as ve:
384+
assert "Cannot have both containers and image_uris." in str(ve)
385+
386+
try:
387+
model_package.add_inference_specification(name="Inference")
388+
except ValueError as ve:
389+
assert "Should have either containers or image_uris for inference." in str(ve)
390+
391+
model_package.add_inference_specification(image_uris=image_uris, name="Inference")
392+
393+
sagemaker_session.sagemaker_client.update_model_package.assert_called_with(
394+
ModelPackageArn=MODEL_PACKAGE_VERSIONED_ARN,
395+
AdditionalInferenceSpecificationsToAdd=[
396+
{
397+
"Containers": [{"Image": "image_uri"}],
398+
"Name": "Inference",
399+
}
400+
],
401+
)

0 commit comments

Comments
 (0)