Skip to content

feat: Added register step in Jumpstart model #4212

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions src/sagemaker/jumpstart/factory/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,19 @@
JUMPSTART_DEFAULT_REGION_NAME,
JUMPSTART_LOGGER,
)
from sagemaker.model_metrics import ModelMetrics
from sagemaker.metadata_properties import MetadataProperties
from sagemaker.drift_check_baselines import DriftCheckBaselines
from sagemaker.jumpstart.enums import JumpStartScriptScope
from sagemaker.jumpstart.types import (
JumpStartModelDeployKwargs,
JumpStartModelInitKwargs,
JumpStartModelRegisterKwargs,
)
from sagemaker.jumpstart.utils import (
update_dict_if_key_not_present,
resolve_model_sagemaker_config_field,
verify_model_region_and_return_specs,
)

from sagemaker.model_monitor.data_capture_config import DataCaptureConfig
Expand Down Expand Up @@ -507,6 +512,87 @@ def get_deploy_kwargs(
return deploy_kwargs


def get_register_kwargs(
model_id: str,
model_version: Optional[str] = None,
region: Optional[str] = None,
tolerate_deprecated_model: Optional[bool] = None,
tolerate_vulnerable_model: Optional[bool] = None,
sagemaker_session: Optional[Any] = None,
supported_content_types: List[str] = None,
response_types: List[str] = None,
inference_instances: Optional[List[str]] = None,
transform_instances: Optional[List[str]] = None,
model_package_group_name: Optional[str] = None,
image_uri: Optional[str] = None,
model_metrics: Optional[ModelMetrics] = None,
metadata_properties: Optional[MetadataProperties] = None,
approval_status: Optional[str] = None,
description: Optional[str] = None,
drift_check_baselines: Optional[DriftCheckBaselines] = None,
customer_metadata_properties: Optional[Dict[str, str]] = None,
validation_specification: Optional[str] = None,
domain: Optional[str] = None,
task: Optional[str] = None,
sample_payload_url: Optional[str] = None,
framework: Optional[str] = None,
framework_version: Optional[str] = None,
nearest_model_name: Optional[str] = None,
data_input_configuration: Optional[str] = None,
skip_model_validation: Optional[str] = None,
) -> JumpStartModelRegisterKwargs:
"""Returns kwargs required to call `register` on `sagemaker.estimator.Model` object."""

register_kwargs = JumpStartModelRegisterKwargs(
model_id=model_id,
model_version=model_version,
region=region,
tolerate_deprecated_model=tolerate_deprecated_model,
tolerate_vulnerable_model=tolerate_vulnerable_model,
sagemaker_session=sagemaker_session,
content_types=supported_content_types,
response_types=response_types,
inference_instances=inference_instances,
transform_instances=transform_instances,
model_package_group_name=model_package_group_name,
image_uri=image_uri,
model_metrics=model_metrics,
metadata_properties=metadata_properties,
approval_status=approval_status,
description=description,
drift_check_baselines=drift_check_baselines,
customer_metadata_properties=customer_metadata_properties,
validation_specification=validation_specification,
domain=domain,
task=task,
sample_payload_url=sample_payload_url,
framework=framework,
framework_version=framework_version,
nearest_model_name=nearest_model_name,
data_input_configuration=data_input_configuration,
skip_model_validation=skip_model_validation,
)

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
region=region,
scope=JumpStartScriptScope.INFERENCE,
sagemaker_session=sagemaker_session,
tolerate_deprecated_model=tolerate_deprecated_model,
tolerate_vulnerable_model=tolerate_vulnerable_model,
)

register_kwargs.content_types = (
register_kwargs.content_types or model_specs.predictor_specs.supported_content_types
)
register_kwargs.response_types = (
register_kwargs.response_types or model_specs.predictor_specs.supported_accept_types
)
Comment on lines +586 to +591
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To confirm, all we're really doing is just setting the content type and accept types from the metadata


return register_kwargs


def get_init_kwargs(
model_id: str,
model_from_estimator: bool = False,
Expand Down
172 changes: 148 additions & 24 deletions src/sagemaker/jumpstart/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
"""This module stores JumpStart implementation of Model class."""

from __future__ import absolute_import
import re

from typing import Dict, List, Optional, Union
from sagemaker import payloads
Expand All @@ -28,16 +27,23 @@
get_default_predictor,
get_deploy_kwargs,
get_init_kwargs,
get_register_kwargs,
)
from sagemaker.jumpstart.types import JumpStartSerializablePayload
from sagemaker.jumpstart.utils import is_valid_model_id
from sagemaker.utils import stringify_object
from sagemaker.model import MODEL_PACKAGE_ARN_PATTERN, Model
from sagemaker.model import (
Model,
ModelPackage,
)
from sagemaker.model_monitor.data_capture_config import DataCaptureConfig
from sagemaker.predictor import PredictorBase
from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig
from sagemaker.session import Session
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.model_metrics import ModelMetrics
from sagemaker.metadata_properties import MetadataProperties
from sagemaker.drift_check_baselines import DriftCheckBaselines


class JumpStartModel(Model):
Expand Down Expand Up @@ -309,11 +315,12 @@ def _is_valid_model_id_hook():
self.tolerate_vulnerable_model = model_init_kwargs.tolerate_vulnerable_model
self.tolerate_deprecated_model = model_init_kwargs.tolerate_deprecated_model
self.region = model_init_kwargs.region
self.model_package_arn = model_init_kwargs.model_package_arn
self.sagemaker_session = model_init_kwargs.sagemaker_session

super(JumpStartModel, self).__init__(**model_init_kwargs.to_kwargs_dict())

self.model_package_arn = model_init_kwargs.model_package_arn

def retrieve_all_examples(self) -> Optional[List[JumpStartSerializablePayload]]:
"""Returns all example payloads associated with the model.

Expand Down Expand Up @@ -390,30 +397,29 @@ def _create_sagemaker_model(
# inference endpoint.
if self.model_package_arn and not self._model_data_is_set:
# When a ModelPackageArn is provided we just create the Model
match = re.match(MODEL_PACKAGE_ARN_PATTERN, self.model_package_arn)
if match:
model_package_name = match.group(3)
else:
# model_package_arn can be just the name if your account owns the Model Package
model_package_name = self.model_package_arn
container_def = {"ModelPackageName": self.model_package_arn}

if self.env != {}:
container_def["Environment"] = self.env

if self.name is None:
self._base_name = model_package_name

self._set_model_name_if_needed()

self.sagemaker_session.create_model(
self.name,
self.role,
container_def,
model_package = ModelPackage(
role=self.role,
model_data=self.model_data,
model_package_arn=self.model_package_arn,
sagemaker_session=self.sagemaker_session,
predictor_cls=self.predictor_cls,
vpc_config=self.vpc_config,
enable_network_isolation=self.enable_network_isolation(),
)
if self.name is not None:
model_package.name = self.name
if self.env is not None:
model_package.env = self.env
model_package._create_sagemaker_model(
instance_type=instance_type,
accelerator_type=accelerator_type,
tags=tags,
serverless_inference_config=serverless_inference_config,
**kwargs,
)
if self._base_name is None and model_package._base_name is not None:
self._base_name = model_package._base_name
if self.name is None and model_package.name is not None:
self.name = model_package.name
else:
super(JumpStartModel, self)._create_sagemaker_model(
instance_type=instance_type,
Expand Down Expand Up @@ -565,6 +571,124 @@ def deploy(
# If a predictor class was passed, do not mutate predictor
return predictor

def register(
self,
content_types: List[Union[str, PipelineVariable]] = None,
response_types: List[Union[str, PipelineVariable]] = None,
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
model_package_group_name: Optional[Union[str, PipelineVariable]] = None,
image_uri: Optional[Union[str, PipelineVariable]] = None,
model_metrics: Optional[ModelMetrics] = None,
metadata_properties: Optional[MetadataProperties] = None,
approval_status: Optional[Union[str, PipelineVariable]] = None,
description: Optional[str] = None,
drift_check_baselines: Optional[DriftCheckBaselines] = None,
customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
validation_specification: Optional[Union[str, PipelineVariable]] = None,
domain: Optional[Union[str, PipelineVariable]] = None,
task: Optional[Union[str, PipelineVariable]] = None,
sample_payload_url: Optional[Union[str, PipelineVariable]] = None,
framework: Optional[Union[str, PipelineVariable]] = None,
framework_version: Optional[Union[str, PipelineVariable]] = None,
nearest_model_name: Optional[Union[str, PipelineVariable]] = None,
data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.

Args:
content_types (list[str] or list[PipelineVariable]): The supported MIME types
for the input data.
response_types (list[str] or list[PipelineVariable]): The supported MIME types
for the output data.
inference_instances (list[str] or list[PipelineVariable]): A list of the instance
types that are used to generate inferences in real-time (default: None).
transform_instances (list[str] or list[PipelineVariable]): A list of the instance types
on which a transformation job can be run or on which an endpoint can be deployed
(default: None).
model_package_group_name (str or PipelineVariable): Model Package Group name,
exclusive to `model_package_name`, using `model_package_group_name` makes the
Model Package versioned. Defaults to ``None``.
image_uri (str or PipelineVariable): Inference image URI for the container. Model class'
self.image will be used if it is None. Defaults to ``None``.
model_metrics (ModelMetrics): ModelMetrics object. Defaults to ``None``.
metadata_properties (MetadataProperties): MetadataProperties object.
Defaults to ``None``.
approval_status (str or PipelineVariable): Model Approval Status, values can be
"Approved", "Rejected", or "PendingManualApproval". Defaults to
``PendingManualApproval``.
description (str): Model Package description. Defaults to ``None``.
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
customer_metadata_properties (dict[str, str] or dict[str, PipelineVariable]):
A dictionary of key-value paired metadata properties (default: None).
domain (str or PipelineVariable): Domain values can be "COMPUTER_VISION",
"NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None).
sample_payload_url (str or PipelineVariable): The S3 path where the sample payload
is stored (default: None).
task (str or PipelineVariable): Task values which are supported by Inference Recommender
are "FILL_MASK", "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION",
"IMAGE_SEGMENTATION", "CLASSIFICATION", "REGRESSION", "OTHER" (default: None).
framework (str or PipelineVariable): Machine learning framework of the model package
container image (default: None).
framework_version (str or PipelineVariable): Framework version of the Model Package
Container Image (default: None).
nearest_model_name (str or PipelineVariable): Name of a pre-trained machine learning
benchmarked by Amazon SageMaker Inference Recommender (default: None).
data_input_configuration (str or PipelineVariable): Input object for the model
(default: None).
skip_model_validation (str or PipelineVariable): Indicates if you want to skip model
validation. Values can be "All" or "None" (default: None).

Returns:
A `sagemaker.model.ModelPackage` instance.
"""

register_kwargs = get_register_kwargs(
model_id=self.model_id,
model_version=self.model_version,
region=self.region,
tolerate_deprecated_model=self.tolerate_deprecated_model,
tolerate_vulnerable_model=self.tolerate_vulnerable_model,
sagemaker_session=self.sagemaker_session,
supported_content_types=content_types,
response_types=response_types,
inference_instances=inference_instances,
transform_instances=transform_instances,
model_package_group_name=model_package_group_name,
image_uri=image_uri,
model_metrics=model_metrics,
metadata_properties=metadata_properties,
approval_status=approval_status,
description=description,
drift_check_baselines=drift_check_baselines,
customer_metadata_properties=customer_metadata_properties,
validation_specification=validation_specification,
domain=domain,
task=task,
sample_payload_url=sample_payload_url,
framework=framework,
framework_version=framework_version,
nearest_model_name=nearest_model_name,
data_input_configuration=data_input_configuration,
skip_model_validation=skip_model_validation,
)

model_package = super(JumpStartModel, self).register(**register_kwargs.to_kwargs_dict())

def register_deploy_wrapper(*args, **kwargs):
if self.model_package_arn is not None:
return self.deploy(*args, **kwargs)

self.model_package_arn = model_package.model_package_arn
predictor = self.deploy(*args, **kwargs)
self.model_package_arn = None
return predictor

model_package.deploy = register_deploy_wrapper

return model_package

def __str__(self) -> str:
"""Overriding str(*) method to make more human-readable."""
return stringify_object(self)
Loading