Skip to content

feat: infer framework and version #3247

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 102 additions & 0 deletions src/sagemaker/chainer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,108 @@ def __init__(

self.model_server_workers = model_server_workers

def register(
    self,
    content_types,
    response_types,
    inference_instances,
    transform_instances,
    model_package_name=None,
    model_package_group_name=None,
    image_uri=None,
    model_metrics=None,
    metadata_properties=None,
    marketplace_cert=False,
    approval_status=None,
    description=None,
    drift_check_baselines=None,
    customer_metadata_properties=None,
    domain=None,
    sample_payload_url=None,
    task=None,
    framework=None,
    framework_version=None,
    nearest_model_name=None,
    data_input_configuration=None,
):
    """Creates a model package for creating SageMaker models or listing on Marketplace.

    The serving image is resolved before delegating to the parent class: an explicit
    ``image_uri`` wins, then any image already set on the model, and finally the image
    returned by ``self.serving_image_uri`` for the first inference instance type.

    Args:
        content_types (list): The supported MIME types for the input data.
        response_types (list): The supported MIME types for the output data.
        inference_instances (list): A list of the instance types that are used to
            generate inferences in real-time. The first entry, when present, is used
            to initialise the session and to look up the serving image.
        transform_instances (list): A list of the instance types on which a transformation
            job can be run or on which an endpoint can be deployed.
        model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
            using `model_package_name` makes the Model Package un-versioned (default: None).
        model_package_group_name (str): Model Package Group name, exclusive to
            `model_package_name`, using `model_package_group_name` makes the Model Package
            versioned (default: None).
        image_uri (str): Inference image uri for the container. Model class' self.image will
            be used if it is None (default: None).
        model_metrics (ModelMetrics): ModelMetrics object (default: None).
        metadata_properties (MetadataProperties): MetadataProperties (default: None).
        marketplace_cert (bool): A boolean value indicating if the Model Package is certified
            for AWS Marketplace (default: False).
        approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
            or "PendingManualApproval" (default: "PendingManualApproval").
        description (str): Model Package description (default: None).
        drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
        customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
            metadata properties (default: None).
        domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
            "MACHINE_LEARNING" (default: None).
        sample_payload_url (str): The S3 path where the sample payload is stored
            (default: None).
        task (str): Task values which are supported by Inference Recommender are "FILL_MASK",
            "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION",
            "CLASSIFICATION", "REGRESSION", "OTHER" (default: None).
        framework (str): Machine learning framework of the model package container image.
            When omitted, the model's own ``_framework_name`` is used. The value is
            upper-cased before being passed on (default: None).
        framework_version (str): Framework version of the Model Package Container Image.
            When omitted, the model's own ``framework_version`` is used (default: None).
        nearest_model_name (str): Name of a pre-trained machine learning model benchmarked by
            Amazon SageMaker Inference Recommender (default: None).
        data_input_configuration (str): Input object for the model (default: None).

    Returns:
        str: A string of SageMaker Model Package ARN.
    """
    # Tolerate an empty or missing list instead of failing on ``[0]``; the
    # instance type is only a hint for session setup and image lookup.
    instance_type = inference_instances[0] if inference_instances else None
    self._init_sagemaker_session_if_does_not_exist(instance_type)

    if image_uri:
        self.image_uri = image_uri
    if not self.image_uri:
        # No image given and none set on the model: derive it from the
        # session's region and the chosen instance type.
        self.image_uri = self.serving_image_uri(
            region_name=self.sagemaker_session.boto_session.region_name,
            instance_type=instance_type,
        )
    return super(ChainerModel, self).register(
        content_types,
        response_types,
        inference_instances,
        transform_instances,
        model_package_name,
        model_package_group_name,
        image_uri,
        model_metrics,
        metadata_properties,
        marketplace_cert,
        approval_status,
        description,
        drift_check_baselines=drift_check_baselines,
        customer_metadata_properties=customer_metadata_properties,
        domain=domain,
        sample_payload_url=sample_payload_url,
        task=task,
        # Infer framework details from the model when the caller gave none.
        framework=(framework or self._framework_name).upper(),
        framework_version=framework_version or self.framework_version,
        nearest_model_name=nearest_model_name,
        data_input_configuration=data_input_configuration,
    )

def prepare_container_def(
self, instance_type=None, accelerator_type=None, serverless_inference_config=None
):
Expand Down
20 changes: 18 additions & 2 deletions src/sagemaker/huggingface/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,14 @@ def _validate_pt_tf_versions(pytorch_version, tensorflow_version, image_uri):
)


def fetch_framework_and_framework_version(tensorflow_version, pytorch_version):
    """Select the framework name and version backing a HuggingFace model.

    TensorFlow takes precedence: whenever ``tensorflow_version`` is not None it is
    chosen, regardless of ``pytorch_version``. Otherwise PyTorch is reported with
    whatever ``pytorch_version`` holds (which may itself be None).

    Args:
        tensorflow_version (str): TensorFlow version, or None if not used.
        pytorch_version (str): PyTorch version, or None if not used.

    Returns:
        tuple: ``(framework_name, framework_version)`` where ``framework_name`` is
        either ``"tensorflow"`` or ``"pytorch"``.
    """
    if tensorflow_version is None:
        return ("pytorch", pytorch_version)
    return ("tensorflow", tensorflow_version)


class HuggingFaceModel(FrameworkModel):
"""A Hugging Face SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``."""

Expand Down Expand Up @@ -387,8 +395,16 @@ def register(
domain=domain,
sample_payload_url=sample_payload_url,
task=task,
framework=framework,
framework_version=framework_version,
framework=(
framework
or fetch_framework_and_framework_version(
self.tensorflow_version, self.pytorch_version
)[0]
).upper(),
framework_version=framework_version
or fetch_framework_and_framework_version(self.tensorflow_version, self.pytorch_version)[
1
],
nearest_model_name=nearest_model_name,
data_input_configuration=data_input_configuration,
)
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,12 +374,12 @@ def register(

if model_package_group_name is not None:
container_def = self.prepare_container_def()
update_container_with_inference_params(
container_def = update_container_with_inference_params(
framework=framework,
framework_version=framework_version,
nearest_model_name=nearest_model_name,
data_input_configuration=data_input_configuration,
container_obj=container_def,
container_def=container_def,
)
else:
container_def = {
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/mxnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,8 @@ def register(
domain=domain,
sample_payload_url=sample_payload_url,
task=task,
framework=framework,
framework_version=framework_version,
framework=(framework or self._framework_name).upper(),
framework_version=framework_version or self.framework_version,
nearest_model_name=nearest_model_name,
data_input_configuration=data_input_configuration,
)
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,12 +340,12 @@ def register(
container_def = self.pipeline_container_def(
inference_instances[0] if inference_instances else None
)
update_container_with_inference_params(
container_def = update_container_with_inference_params(
framework=framework,
framework_version=framework_version,
nearest_model_name=nearest_model_name,
data_input_configuration=data_input_configuration,
container_list=container_def,
container_def=container_def,
)
else:
container_def = [
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,8 @@ def register(
domain=domain,
sample_payload_url=sample_payload_url,
task=task,
framework=framework,
framework_version=framework_version,
framework=(framework or self._framework_name).upper(),
framework_version=framework_version or self.framework_version,
nearest_model_name=nearest_model_name,
data_input_configuration=data_input_configuration,
)
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/sklearn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,8 +233,8 @@ def register(
domain=domain,
sample_payload_url=sample_payload_url,
task=task,
framework=framework,
framework_version=framework_version,
framework=(framework or self._framework_name).upper(),
framework_version=framework_version or self.framework_version,
nearest_model_name=nearest_model_name,
data_input_configuration=data_input_configuration,
)
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/tensorflow/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,8 +285,8 @@ def register(
domain=domain,
sample_payload_url=sample_payload_url,
task=task,
framework=framework,
framework_version=framework_version,
framework=(framework or self._framework_name).upper(),
framework_version=framework_version or self.framework_version,
nearest_model_name=nearest_model_name,
data_input_configuration=data_input_configuration,
)
Expand Down
62 changes: 39 additions & 23 deletions src/sagemaker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,7 +739,7 @@ def update_container_with_inference_params(
framework_version=None,
nearest_model_name=None,
data_input_configuration=None,
container_obj=None,
container_def=None,
container_list=None,
):
"""Function to check if inference recommender parameters exist and update container.
Expand All @@ -752,28 +752,30 @@ def update_container_with_inference_params(
nearest_model_name (str): Name of a pre-trained machine learning benchmarked by
Amazon SageMaker Inference Recommender (default: None).
data_input_configuration (str): Input object for the model (default: None).
container_obj (dict): object to be updated.
container_def (dict): object to be updated.
container_list (list): list to be updated.

Returns:
dict: dict with inference recommender params
"""

if framework is not None and framework_version is not None and nearest_model_name is not None:
if container_list is not None:
for obj in container_list:
construct_container_object(
obj, data_input_configuration, framework, framework_version, nearest_model_name
)
if container_obj is not None:
if container_list is not None:
for obj in container_list:
construct_container_object(
container_obj,
data_input_configuration,
framework,
framework_version,
nearest_model_name,
obj, data_input_configuration, framework, framework_version, nearest_model_name
)

if container_def is not None:
construct_container_object(
container_def,
data_input_configuration,
framework,
framework_version,
nearest_model_name,
)

return container_list or container_def


def construct_container_object(
obj, data_input_configuration, framework, framework_version, nearest_model_name
Expand All @@ -788,20 +790,32 @@ def construct_container_object(
nearest_model_name (str): Name of a pre-trained machine learning benchmarked by
Amazon SageMaker Inference Recommender (default: None).
data_input_configuration (str): Input object for the model (default: None).
container_obj (dict): object to be updated.
container_list (list): list to be updated.
obj (dict): object to be updated.

Returns:
dict: container object
"""

obj.update(
{
"Framework": framework,
"FrameworkVersion": framework_version,
"NearestModelName": nearest_model_name,
}
)
if framework is not None:
obj.update(
{
"Framework": framework,
}
)

if framework_version is not None:
obj.update(
{
"FrameworkVersion": framework_version,
}
)

if nearest_model_name is not None:
obj.update(
{
"NearestModelName": nearest_model_name,
}
)

if data_input_configuration is not None:
obj.update(
Expand All @@ -811,3 +825,5 @@ def construct_container_object(
},
}
)

return obj
2 changes: 1 addition & 1 deletion src/sagemaker/workflow/step_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def __init__(
)
]

update_container_with_inference_params(
self.container_def_list = update_container_with_inference_params(
framework=framework,
framework_version=framework_version,
nearest_model_name=nearest_model_name,
Expand Down
Loading