Skip to content

Commit 09b6b05

Browse files
committed
add: auto infer framework and version for inference recommender
1 parent 08f40e2 commit 09b6b05

File tree

15 files changed

+524
-57
lines changed

15 files changed

+524
-57
lines changed

src/sagemaker/chainer/model.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ def register(
150150
model_package_group_name=None,
151151
image_uri=None,
152152
model_metrics=None,
153+
metadata_properties=None,
153154
marketplace_cert=False,
154155
approval_status=None,
155156
description=None,
@@ -180,11 +181,29 @@ def register(
180181
image_uri (str): Inference image uri for the container. Model class' self.image will
181182
be used if it is None (default: None).
182183
model_metrics (ModelMetrics): ModelMetrics object (default: None).
184+
metadata_properties (MetadataProperties): MetadataProperties (default: None).
183185
marketplace_cert (bool): A boolean value indicating if the Model Package is certified
184186
for AWS Marketplace (default: False).
185187
approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
186188
or "PendingManualApproval" (default: "PendingManualApproval").
187189
description (str): Model Package description (default: None).
190+
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
191+
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
192+
metadata properties (default: None).
193+
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
194+
"MACHINE_LEARNING" (default: None).
195+
sample_payload_url (str): The S3 path where the sample payload is stored
196+
(default: None).
197+
task (str): Task values which are supported by Inference Recommender are "FILL_MASK",
198+
"IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION",
199+
"CLASSIFICATION", "REGRESSION", "OTHER" (default: None).
200+
framework (str): Machine learning framework of the model package container image
201+
(default: None).
202+
framework_version (str): Framework version of the Model Package Container Image
203+
(default: None).
204+
nearest_model_name (str): Name of a pre-trained machine learning benchmarked by
205+
Amazon SageMaker Inference Recommender (default: None).
206+
data_input_configuration (str): Input object for the model (default: None).
188207
189208
Returns:
190209
str: A string of SageMaker Model Package ARN.
@@ -208,9 +227,10 @@ def register(
208227
model_package_group_name,
209228
image_uri,
210229
model_metrics,
230+
metadata_properties,
211231
marketplace_cert,
212232
approval_status,
213-
description,
233+
description,
214234
drift_check_baselines=drift_check_baselines,
215235
customer_metadata_properties=customer_metadata_properties,
216236
domain=domain,

src/sagemaker/huggingface/model.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,13 @@ def _validate_pt_tf_versions(pytorch_version, tensorflow_version, image_uri):
8686

8787

8888
def fetch_framework_and_framework_version(tensorflow_version, pytorch_version):
89+
"""Function to check the framework used in HuggingFace class"""
90+
8991
if tensorflow_version is not None: # pylint: disable=no-member
9092
return ("tensorflow", tensorflow_version) # pylint: disable=no-member
9193
else:
9294
return ("pytorch", pytorch_version) # pylint: disable=no-member
93-
95+
9496

9597
class HuggingFaceModel(FrameworkModel):
9698
"""A Hugging Face SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``."""
@@ -296,7 +298,6 @@ def deploy(
296298
serverless_inference_config,
297299
)
298300

299-
300301
def register(
301302
self,
302303
content_types,
@@ -395,8 +396,14 @@ def register(
395396
domain=domain,
396397
sample_payload_url=sample_payload_url,
397398
task=task,
398-
framework=framework or fetch_framework_and_framework_version()[0],
399-
framework_version=framework_version or fetch_framework_and_framework_version()[1],
399+
framework=framework
400+
or fetch_framework_and_framework_version(self.tensorflow_version, self.pytorch_version)[
401+
0
402+
],
403+
framework_version=framework_version
404+
or fetch_framework_and_framework_version(self.tensorflow_version, self.pytorch_version)[
405+
1
406+
],
400407
nearest_model_name=nearest_model_name,
401408
data_input_configuration=data_input_configuration,
402409
)

src/sagemaker/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -374,12 +374,12 @@ def register(
374374

375375
if model_package_group_name is not None:
376376
container_def = self.prepare_container_def()
377-
update_container_with_inference_params(
377+
container_def = update_container_with_inference_params(
378378
framework=framework,
379379
framework_version=framework_version,
380380
nearest_model_name=nearest_model_name,
381381
data_input_configuration=data_input_configuration,
382-
container_obj=container_def,
382+
container_def=container_def,
383383
)
384384
else:
385385
container_def = {

src/sagemaker/pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -340,12 +340,12 @@ def register(
340340
container_def = self.pipeline_container_def(
341341
inference_instances[0] if inference_instances else None
342342
)
343-
update_container_with_inference_params(
343+
container_def = update_container_with_inference_params(
344344
framework=framework,
345345
framework_version=framework_version,
346346
nearest_model_name=nearest_model_name,
347347
data_input_configuration=data_input_configuration,
348-
container_list=container_def,
348+
container_def=container_def,
349349
)
350350
else:
351351
container_def = [

src/sagemaker/utils.py

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -739,7 +739,7 @@ def update_container_with_inference_params(
739739
framework_version=None,
740740
nearest_model_name=None,
741741
data_input_configuration=None,
742-
container_obj=None,
742+
container_def=None,
743743
container_list=None,
744744
):
745745
"""Function to check if inference recommender parameters exist and update container.
@@ -752,28 +752,30 @@ def update_container_with_inference_params(
752752
nearest_model_name (str): Name of a pre-trained machine learning benchmarked by
753753
Amazon SageMaker Inference Recommender (default: None).
754754
data_input_configuration (str): Input object for the model (default: None).
755-
container_obj (dict): object to be updated.
755+
container_def (dict): object to be updated.
756756
container_list (list): list to be updated.
757757
758758
Returns:
759759
dict: dict with inference recommender params
760760
"""
761761

762-
if framework is not None and framework_version is not None and nearest_model_name is not None:
763-
if container_list is not None:
764-
for obj in container_list:
765-
construct_container_object(
766-
obj, data_input_configuration, framework, framework_version, nearest_model_name
767-
)
768-
if container_obj is not None:
762+
if container_list is not None:
763+
for obj in container_list:
769764
construct_container_object(
770-
container_obj,
771-
data_input_configuration,
772-
framework,
773-
framework_version,
774-
nearest_model_name,
765+
obj, data_input_configuration, framework, framework_version, nearest_model_name
775766
)
776767

768+
if container_def is not None:
769+
construct_container_object(
770+
container_def,
771+
data_input_configuration,
772+
framework,
773+
framework_version,
774+
nearest_model_name,
775+
)
776+
777+
return container_list or container_def
778+
777779

778780
def construct_container_object(
779781
obj, data_input_configuration, framework, framework_version, nearest_model_name
@@ -788,20 +790,32 @@ def construct_container_object(
788790
nearest_model_name (str): Name of a pre-trained machine learning benchmarked by
789791
Amazon SageMaker Inference Recommender (default: None).
790792
data_input_configuration (str): Input object for the model (default: None).
791-
container_obj (dict): object to be updated.
792-
container_list (list): list to be updated.
793+
obj (dict): object to be updated.
793794
794795
Returns:
795796
dict: container object
796797
"""
797798

798-
obj.update(
799-
{
800-
"Framework": framework,
801-
"FrameworkVersion": framework_version,
802-
"NearestModelName": nearest_model_name,
803-
}
804-
)
799+
if framework is not None:
800+
obj.update(
801+
{
802+
"Framework": framework,
803+
}
804+
)
805+
806+
if framework_version is not None:
807+
obj.update(
808+
{
809+
"FrameworkVersion": framework_version,
810+
}
811+
)
812+
813+
if nearest_model_name is not None:
814+
obj.update(
815+
{
816+
"NearestModelName": nearest_model_name,
817+
}
818+
)
805819

806820
if data_input_configuration is not None:
807821
obj.update(
@@ -811,3 +825,5 @@ def construct_container_object(
811825
},
812826
}
813827
)
828+
829+
return obj

src/sagemaker/workflow/step_collections.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ def __init__(
250250
)
251251
]
252252

253-
update_container_with_inference_params(
253+
self.container_def_list = update_container_with_inference_params(
254254
framework=framework,
255255
framework_version=framework_version,
256256
nearest_model_name=nearest_model_name,

src/sagemaker/xgboost/model.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,12 +128,13 @@ def register(
128128
self,
129129
content_types,
130130
response_types,
131-
inference_instances,
132-
transform_instances,
131+
inference_instances=None,
132+
transform_instances=None,
133133
model_package_name=None,
134134
model_package_group_name=None,
135135
image_uri=None,
136136
model_metrics=None,
137+
metadata_properties=None,
137138
marketplace_cert=False,
138139
approval_status=None,
139140
description=None,
@@ -164,11 +165,29 @@ def register(
164165
image_uri (str): Inference image uri for the container. Model class' self.image will
165166
be used if it is None (default: None).
166167
model_metrics (ModelMetrics): ModelMetrics object (default: None).
168+
metadata_properties (MetadataProperties): MetadataProperties (default: None).
167169
marketplace_cert (bool): A boolean value indicating if the Model Package is certified
168170
for AWS Marketplace (default: False).
169171
approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
170172
or "PendingManualApproval" (default: "PendingManualApproval").
171173
description (str): Model Package description (default: None).
174+
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
175+
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
176+
metadata properties (default: None).
177+
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
178+
"MACHINE_LEARNING" (default: None).
179+
sample_payload_url (str): The S3 path where the sample payload is stored
180+
(default: None).
181+
task (str): Task values which are supported by Inference Recommender are "FILL_MASK",
182+
"IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION",
183+
"CLASSIFICATION", "REGRESSION", "OTHER" (default: None).
184+
framework (str): Machine learning framework of the model package container image
185+
(default: None).
186+
framework_version (str): Framework version of the Model Package Container Image
187+
(default: None).
188+
nearest_model_name (str): Name of a pre-trained machine learning benchmarked by
189+
Amazon SageMaker Inference Recommender (default: None).
190+
data_input_configuration (str): Input object for the model (default: None).
172191
173192
Returns:
174193
str: A string of SageMaker Model Package ARN.
@@ -192,6 +211,7 @@ def register(
192211
model_package_group_name,
193212
image_uri,
194213
model_metrics,
214+
metadata_properties,
195215
marketplace_cert,
196216
approval_status,
197217
description,

tests/unit/sagemaker/tensorflow/test_tfs.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import mock
2020
import pytest
21-
from mock import Mock, patch
21+
from mock import Mock, patch, ANY
2222

2323
from sagemaker.serializers import CSVSerializer, IdentitySerializer
2424
from sagemaker.tensorflow import TensorFlow, TensorFlowModel, TensorFlowPredictor
@@ -454,3 +454,51 @@ def mock_response(expected_response, sagemaker_session, content_type=JSON_CONTEN
454454
"ContentType": content_type,
455455
"Body": io.BytesIO(expected_response),
456456
}
457+
458+
459+
def test_register_tfs_model_auto_infer_framework(sagemaker_session, tensorflow_inference_version):
460+
model_package_group_name = "test-tfs-register-model"
461+
content_types = ["application/json"]
462+
response_types = ["application/json"]
463+
inference_instances = ["ml.m4.xlarge"]
464+
transform_instances = ["ml.m4.xlarge"]
465+
image_uri = "fakeimage"
466+
467+
tfs_model = TensorFlowModel(
468+
"s3://some/data.tar.gz",
469+
role=ROLE,
470+
framework_version=tensorflow_inference_version,
471+
sagemaker_session=sagemaker_session,
472+
)
473+
474+
tfs_model.register(
475+
content_types,
476+
response_types,
477+
inference_instances,
478+
transform_instances,
479+
model_package_group_name=model_package_group_name,
480+
marketplace_cert=True,
481+
image_uri=image_uri,
482+
)
483+
484+
expected_create_model_package_request = {
485+
"containers": [
486+
{
487+
"Image": image_uri,
488+
"Environment": ANY,
489+
"ModelDataUrl": ANY,
490+
"Framework": "tensorflow",
491+
"FrameworkVersion": tensorflow_inference_version,
492+
},
493+
],
494+
"content_types": content_types,
495+
"response_types": response_types,
496+
"inference_instances": inference_instances,
497+
"transform_instances": transform_instances,
498+
"model_package_group_name": model_package_group_name,
499+
"marketplace_cert": True,
500+
}
501+
502+
sagemaker_session.create_model_package_from_containers.assert_called_with(
503+
**expected_create_model_package_request
504+
)

0 commit comments

Comments
 (0)