Skip to content

Commit 72396cc

Browse files
author
Keshav Chandak
committed
Added register step in Jumpstart
1 parent 744724b commit 72396cc

File tree

6 files changed

+471
-74
lines changed

6 files changed

+471
-74
lines changed

src/sagemaker/jumpstart/factory/model.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,19 @@
3434
JUMPSTART_DEFAULT_REGION_NAME,
3535
JUMPSTART_LOGGER,
3636
)
37+
from sagemaker.model_metrics import ModelMetrics
38+
from sagemaker.metadata_properties import MetadataProperties
39+
from sagemaker.drift_check_baselines import DriftCheckBaselines
3740
from sagemaker.jumpstart.enums import JumpStartScriptScope
3841
from sagemaker.jumpstart.types import (
3942
JumpStartModelDeployKwargs,
4043
JumpStartModelInitKwargs,
44+
JumpStartModelRegisterKwargs,
4145
)
4246
from sagemaker.jumpstart.utils import (
4347
update_dict_if_key_not_present,
4448
resolve_model_sagemaker_config_field,
49+
verify_model_region_and_return_specs,
4550
)
4651

4752
from sagemaker.model_monitor.data_capture_config import DataCaptureConfig
@@ -505,6 +510,87 @@ def get_deploy_kwargs(
505510
return deploy_kwargs
506511

507512

513+
def get_register_kwargs(
514+
model_id: str,
515+
model_version: Optional[str] = None,
516+
region: Optional[str] = None,
517+
tolerate_deprecated_model: Optional[bool] = None,
518+
tolerate_vulnerable_model: Optional[bool] = None,
519+
sagemaker_session: Optional[Any] = None,
520+
supported_content_types: List[str] = None,
521+
response_types: List[str] = None,
522+
inference_instances: Optional[List[str]] = None,
523+
transform_instances: Optional[List[str]] = None,
524+
model_package_group_name: Optional[str] = None,
525+
image_uri: Optional[str] = None,
526+
model_metrics: Optional[ModelMetrics] = None,
527+
metadata_properties: Optional[MetadataProperties] = None,
528+
approval_status: Optional[str] = None,
529+
description: Optional[str] = None,
530+
drift_check_baselines: Optional[DriftCheckBaselines] = None,
531+
customer_metadata_properties: Optional[Dict[str, str]] = None,
532+
validation_specification: Optional[str] = None,
533+
domain: Optional[str] = None,
534+
task: Optional[str] = None,
535+
sample_payload_url: Optional[str] = None,
536+
framework: Optional[str] = None,
537+
framework_version: Optional[str] = None,
538+
nearest_model_name: Optional[str] = None,
539+
data_input_configuration: Optional[str] = None,
540+
skip_model_validation: Optional[str] = None,
541+
) -> JumpStartModelRegisterKwargs:
542+
"""Returns kwargs required to call `register` on `sagemaker.estimator.Model` object."""
543+
544+
register_kwargs = JumpStartModelRegisterKwargs(
545+
model_id=model_id,
546+
model_version=model_version,
547+
region=region,
548+
tolerate_deprecated_model=tolerate_deprecated_model,
549+
tolerate_vulnerable_model=tolerate_vulnerable_model,
550+
sagemaker_session=sagemaker_session,
551+
content_types=supported_content_types,
552+
response_types=response_types,
553+
inference_instances=inference_instances,
554+
transform_instances=transform_instances,
555+
model_package_group_name=model_package_group_name,
556+
image_uri=image_uri,
557+
model_metrics=model_metrics,
558+
metadata_properties=metadata_properties,
559+
approval_status=approval_status,
560+
description=description,
561+
drift_check_baselines=drift_check_baselines,
562+
customer_metadata_properties=customer_metadata_properties,
563+
validation_specification=validation_specification,
564+
domain=domain,
565+
task=task,
566+
sample_payload_url=sample_payload_url,
567+
framework=framework,
568+
framework_version=framework_version,
569+
nearest_model_name=nearest_model_name,
570+
data_input_configuration=data_input_configuration,
571+
skip_model_validation=skip_model_validation,
572+
)
573+
574+
model_specs = verify_model_region_and_return_specs(
575+
model_id=model_id,
576+
version=model_version,
577+
region=region,
578+
scope=JumpStartScriptScope.INFERENCE,
579+
sagemaker_session=sagemaker_session,
580+
tolerate_deprecated_model=tolerate_deprecated_model,
581+
tolerate_vulnerable_model=tolerate_vulnerable_model,
582+
)
583+
584+
register_kwargs.content_types = (
585+
register_kwargs.content_types or model_specs.predictor_specs.supported_content_types
586+
)
587+
register_kwargs.response_types = (
588+
register_kwargs.response_types or model_specs.predictor_specs.supported_accept_types
589+
)
590+
591+
return register_kwargs
592+
593+
508594
def get_init_kwargs(
509595
model_id: str,
510596
model_from_estimator: bool = False,

src/sagemaker/jumpstart/model.py

Lines changed: 146 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@
1313
"""This module stores JumpStart implementation of Model class."""
1414

1515
from __future__ import absolute_import
16+
import copy
17+
from functools import partial
1618
import re
19+
import types
1720

1821
from typing import Dict, List, Optional, Union
1922
from sagemaker import payloads
@@ -28,16 +31,26 @@
2831
get_default_predictor,
2932
get_deploy_kwargs,
3033
get_init_kwargs,
34+
get_register_kwargs,
3135
)
3236
from sagemaker.jumpstart.types import JumpStartSerializablePayload
3337
from sagemaker.jumpstart.utils import is_valid_model_id
3438
from sagemaker.utils import stringify_object
35-
from sagemaker.model import MODEL_PACKAGE_ARN_PATTERN, Model
39+
from sagemaker.model import (
40+
MODEL_PACKAGE_ARN_PATTERN,
41+
Model,
42+
MODEL_PACKAGE_VERSIONED_ARN_PATTERN,
43+
ModelPackage,
44+
)
3645
from sagemaker.model_monitor.data_capture_config import DataCaptureConfig
3746
from sagemaker.predictor import PredictorBase
3847
from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig
3948
from sagemaker.session import Session
4049
from sagemaker.workflow.entities import PipelineVariable
50+
from sagemaker.model_metrics import ModelMetrics
51+
from sagemaker.metadata_properties import MetadataProperties
52+
from sagemaker.drift_check_baselines import DriftCheckBaselines
53+
from sagemaker.model_card.schema_constraints import ModelApprovalStatusEnum
4154

4255

4356
class JumpStartModel(Model):
@@ -390,29 +403,21 @@ def _create_sagemaker_model(
390403
# inference endpoint.
391404
if self.model_package_arn and not self._model_data_is_set:
392405
# When a ModelPackageArn is provided we just create the Model
393-
match = re.match(MODEL_PACKAGE_ARN_PATTERN, self.model_package_arn)
394-
if match:
395-
model_package_name = match.group(3)
396-
else:
397-
# model_package_arn can be just the name if your account owns the Model Package
398-
model_package_name = self.model_package_arn
399-
container_def = {"ModelPackageName": self.model_package_arn}
400-
401-
if self.env != {}:
402-
container_def["Environment"] = self.env
403-
404-
if self.name is None:
405-
self._base_name = model_package_name
406-
407-
self._set_model_name_if_needed()
408-
409-
self.sagemaker_session.create_model(
410-
self.name,
411-
self.role,
412-
container_def,
406+
model_package = ModelPackage(
407+
role=self.role,
408+
model_data=self.model_data,
409+
model_package_arn=self.model_package_arn,
410+
sagemaker_session=self.sagemaker_session,
411+
predictor_cls=self.predictor_cls,
413412
vpc_config=self.vpc_config,
414-
enable_network_isolation=self.enable_network_isolation(),
413+
)
414+
model_package.name = self.name
415+
model_package._create_sagemaker_model(
416+
instance_type=instance_type,
417+
accelerator_type=accelerator_type,
415418
tags=tags,
419+
serverless_inference_config=serverless_inference_config,
420+
**kwargs,
416421
)
417422
else:
418423
super(JumpStartModel, self)._create_sagemaker_model(
@@ -565,6 +570,125 @@ def deploy(
565570
# If a predictor class was passed, do not mutate predictor
566571
return predictor
567572

573+
def register(
574+
self,
575+
content_types: List[Union[str, PipelineVariable]] = None,
576+
response_types: List[Union[str, PipelineVariable]] = None,
577+
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
578+
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
579+
model_package_name: Optional[Union[str, PipelineVariable]] = None,
580+
model_package_group_name: Optional[Union[str, PipelineVariable]] = None,
581+
image_uri: Optional[Union[str, PipelineVariable]] = None,
582+
model_metrics: Optional[ModelMetrics] = None,
583+
metadata_properties: Optional[MetadataProperties] = None,
584+
marketplace_cert: bool = False,
585+
approval_status: Optional[Union[str, PipelineVariable]] = None,
586+
description: Optional[str] = None,
587+
drift_check_baselines: Optional[DriftCheckBaselines] = None,
588+
customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
589+
validation_specification: Optional[Union[str, PipelineVariable]] = None,
590+
domain: Optional[Union[str, PipelineVariable]] = None,
591+
task: Optional[Union[str, PipelineVariable]] = None,
592+
sample_payload_url: Optional[Union[str, PipelineVariable]] = None,
593+
framework: Optional[Union[str, PipelineVariable]] = None,
594+
framework_version: Optional[Union[str, PipelineVariable]] = None,
595+
nearest_model_name: Optional[Union[str, PipelineVariable]] = None,
596+
data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
597+
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
598+
):
599+
"""Creates a model package for creating SageMaker models or listing on Marketplace.
600+
601+
Args:
602+
content_types (list[str] or list[PipelineVariable]): The supported MIME types
603+
for the input data.
604+
response_types (list[str] or list[PipelineVariable]): The supported MIME types
605+
for the output data.
606+
inference_instances (list[str] or list[PipelineVariable]): A list of the instance
607+
types that are used to generate inferences in real-time (default: None).
608+
transform_instances (list[str] or list[PipelineVariable]): A list of the instance types
609+
on which a transformation job can be run or on which an endpoint can be deployed
610+
(default: None).
611+
model_package_name (str or PipelineVariable): Model Package name, exclusive to
612+
`model_package_group_name`, using `model_package_name` makes the Model Package
613+
un-versioned. Defaults to ``None``.
614+
model_package_group_name (str or PipelineVariable): Model Package Group name,
615+
exclusive to `model_package_name`, using `model_package_group_name` makes the
616+
Model Package versioned. Defaults to ``None``.
617+
image_uri (str or PipelineVariable): Inference image URI for the container. Model class'
618+
self.image will be used if it is None. Defaults to ``None``.
619+
model_metrics (ModelMetrics): ModelMetrics object. Defaults to ``None``.
620+
metadata_properties (MetadataProperties): MetadataProperties object.
621+
Defaults to ``None``.
622+
marketplace_cert (bool): A boolean value indicating if the Model Package is certified
623+
for AWS Marketplace. Defaults to ``False``.
624+
approval_status (str or PipelineVariable): Model Approval Status, values can be
625+
"Approved", "Rejected", or "PendingManualApproval". Defaults to
626+
``PendingManualApproval``.
627+
description (str): Model Package description. Defaults to ``None``.
628+
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
629+
customer_metadata_properties (dict[str, str] or dict[str, PipelineVariable]):
630+
A dictionary of key-value paired metadata properties (default: None).
631+
domain (str or PipelineVariable): Domain values can be "COMPUTER_VISION",
632+
"NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None).
633+
sample_payload_url (str or PipelineVariable): The S3 path where the sample payload
634+
is stored (default: None).
635+
task (str or PipelineVariable): Task values which are supported by Inference Recommender
636+
are "FILL_MASK", "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION",
637+
"IMAGE_SEGMENTATION", "CLASSIFICATION", "REGRESSION", "OTHER" (default: None).
638+
framework (str or PipelineVariable): Machine learning framework of the model package
639+
container image (default: None).
640+
framework_version (str or PipelineVariable): Framework version of the Model Package
641+
Container Image (default: None).
642+
nearest_model_name (str or PipelineVariable): Name of a pre-trained machine learning
643+
benchmarked by Amazon SageMaker Inference Recommender (default: None).
644+
data_input_configuration (str or PipelineVariable): Input object for the model
645+
(default: None).
646+
skip_model_validation (str or PipelineVariable): Indicates if you want to skip model
647+
validation. Values can be "All" or "None" (default: None).
648+
649+
Returns:
650+
A `sagemaker.model.ModelPackage` instance.
651+
"""
652+
653+
register_kwargs = get_register_kwargs(
654+
model_id=self.model_id,
655+
model_version=self.model_version,
656+
region=self.region,
657+
tolerate_deprecated_model=self.tolerate_deprecated_model,
658+
tolerate_vulnerable_model=self.tolerate_vulnerable_model,
659+
sagemaker_session=self.sagemaker_session,
660+
supported_content_types=content_types,
661+
response_types=response_types,
662+
inference_instances=inference_instances,
663+
transform_instances=transform_instances,
664+
model_package_name=model_package_name,
665+
model_package_group_name=model_package_group_name,
666+
image_uri=image_uri,
667+
model_metrics=model_metrics,
668+
metadata_properties=metadata_properties,
669+
marketplace_cert=marketplace_cert,
670+
approval_status=approval_status,
671+
description=description,
672+
drift_check_baselines=drift_check_baselines,
673+
customer_metadata_properties=customer_metadata_properties,
674+
validation_specification=validation_specification,
675+
domain=domain,
676+
task=task,
677+
sample_payload_url=sample_payload_url,
678+
framework=framework,
679+
framework_version=framework_version,
680+
nearest_model_name=nearest_model_name,
681+
data_input_configuration=data_input_configuration,
682+
skip_model_validation=skip_model_validation,
683+
)
684+
685+
model_package = super(JumpStartModel, self).register(**register_kwargs.to_kwargs_dict())
686+
self.model_package_arn = model_package.model_package_arn
687+
688+
model_package.deploy = self.deploy
689+
690+
return model_package
691+
568692
def __str__(self) -> str:
569693
"""Overriding str(*) method to make more human-readable."""
570694
return stringify_object(self)

0 commit comments

Comments
 (0)