Skip to content

Commit d967df9

Browse files
author
Keshav Chandak
committed
feat: Added register step in Jumpstart
1 parent 070dc34 commit d967df9

File tree

6 files changed

+458
-77
lines changed

6 files changed

+458
-77
lines changed

src/sagemaker/jumpstart/factory/model.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,19 @@
3434
JUMPSTART_DEFAULT_REGION_NAME,
3535
JUMPSTART_LOGGER,
3636
)
37+
from sagemaker.model_metrics import ModelMetrics
38+
from sagemaker.metadata_properties import MetadataProperties
39+
from sagemaker.drift_check_baselines import DriftCheckBaselines
3740
from sagemaker.jumpstart.enums import JumpStartScriptScope
3841
from sagemaker.jumpstart.types import (
3942
JumpStartModelDeployKwargs,
4043
JumpStartModelInitKwargs,
44+
JumpStartModelRegisterKwargs,
4145
)
4246
from sagemaker.jumpstart.utils import (
4347
update_dict_if_key_not_present,
4448
resolve_model_sagemaker_config_field,
49+
verify_model_region_and_return_specs,
4550
)
4651

4752
from sagemaker.model_monitor.data_capture_config import DataCaptureConfig
@@ -506,6 +511,87 @@ def get_deploy_kwargs(
506511
return deploy_kwargs
507512

508513

514+
def get_register_kwargs(
    model_id: str,
    model_version: Optional[str] = None,
    region: Optional[str] = None,
    tolerate_deprecated_model: Optional[bool] = None,
    tolerate_vulnerable_model: Optional[bool] = None,
    sagemaker_session: Optional[Any] = None,
    supported_content_types: List[str] = None,
    response_types: List[str] = None,
    inference_instances: Optional[List[str]] = None,
    transform_instances: Optional[List[str]] = None,
    model_package_group_name: Optional[str] = None,
    image_uri: Optional[str] = None,
    model_metrics: Optional[ModelMetrics] = None,
    metadata_properties: Optional[MetadataProperties] = None,
    approval_status: Optional[str] = None,
    description: Optional[str] = None,
    drift_check_baselines: Optional[DriftCheckBaselines] = None,
    customer_metadata_properties: Optional[Dict[str, str]] = None,
    validation_specification: Optional[str] = None,
    domain: Optional[str] = None,
    task: Optional[str] = None,
    sample_payload_url: Optional[str] = None,
    framework: Optional[str] = None,
    framework_version: Optional[str] = None,
    nearest_model_name: Optional[str] = None,
    data_input_configuration: Optional[str] = None,
    skip_model_validation: Optional[str] = None,
) -> JumpStartModelRegisterKwargs:
    """Assemble the kwargs needed to call `register` on a JumpStart model.

    All arguments are forwarded into a ``JumpStartModelRegisterKwargs`` object
    (``supported_content_types`` is stored under ``content_types``). The model
    specs are then looked up for the inference scope, and any of
    ``content_types`` / ``response_types`` that the caller left unset or empty
    is filled from the specs' ``predictor_specs``.

    Returns:
        JumpStartModelRegisterKwargs: The populated kwargs object.
    """
    kwargs = JumpStartModelRegisterKwargs(
        model_id=model_id,
        model_version=model_version,
        region=region,
        tolerate_deprecated_model=tolerate_deprecated_model,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        sagemaker_session=sagemaker_session,
        content_types=supported_content_types,
        response_types=response_types,
        inference_instances=inference_instances,
        transform_instances=transform_instances,
        model_package_group_name=model_package_group_name,
        image_uri=image_uri,
        model_metrics=model_metrics,
        metadata_properties=metadata_properties,
        approval_status=approval_status,
        description=description,
        drift_check_baselines=drift_check_baselines,
        customer_metadata_properties=customer_metadata_properties,
        validation_specification=validation_specification,
        domain=domain,
        task=task,
        sample_payload_url=sample_payload_url,
        framework=framework,
        framework_version=framework_version,
        nearest_model_name=nearest_model_name,
        data_input_configuration=data_input_configuration,
        skip_model_validation=skip_model_validation,
    )

    # The lookup runs unconditionally, even when both type lists were supplied.
    specs = verify_model_region_and_return_specs(
        model_id=model_id,
        version=model_version,
        region=region,
        scope=JumpStartScriptScope.INFERENCE,
        sagemaker_session=sagemaker_session,
        tolerate_deprecated_model=tolerate_deprecated_model,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
    )

    # Only touch predictor_specs when a fallback is actually required.
    if not kwargs.content_types:
        kwargs.content_types = specs.predictor_specs.supported_content_types
    if not kwargs.response_types:
        kwargs.response_types = specs.predictor_specs.supported_accept_types

    return kwargs
593+
594+
509595
def get_init_kwargs(
510596
model_id: str,
511597
model_from_estimator: bool = False,

src/sagemaker/jumpstart/model.py

Lines changed: 134 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
"""This module stores JumpStart implementation of Model class."""
1414

1515
from __future__ import absolute_import
16-
import re
1716

1817
from typing import Dict, List, Optional, Union
1918
from sagemaker import payloads
@@ -28,16 +27,23 @@
2827
get_default_predictor,
2928
get_deploy_kwargs,
3029
get_init_kwargs,
30+
get_register_kwargs,
3131
)
3232
from sagemaker.jumpstart.types import JumpStartSerializablePayload
3333
from sagemaker.jumpstart.utils import is_valid_model_id
3434
from sagemaker.utils import stringify_object
35-
from sagemaker.model import MODEL_PACKAGE_ARN_PATTERN, Model
35+
from sagemaker.model import (
36+
Model,
37+
ModelPackage,
38+
)
3639
from sagemaker.model_monitor.data_capture_config import DataCaptureConfig
3740
from sagemaker.predictor import PredictorBase
3841
from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig
3942
from sagemaker.session import Session
4043
from sagemaker.workflow.entities import PipelineVariable
44+
from sagemaker.model_metrics import ModelMetrics
45+
from sagemaker.metadata_properties import MetadataProperties
46+
from sagemaker.drift_check_baselines import DriftCheckBaselines
4147

4248

4349
class JumpStartModel(Model):
@@ -309,11 +315,12 @@ def _is_valid_model_id_hook():
309315
self.tolerate_vulnerable_model = model_init_kwargs.tolerate_vulnerable_model
310316
self.tolerate_deprecated_model = model_init_kwargs.tolerate_deprecated_model
311317
self.region = model_init_kwargs.region
312-
self.model_package_arn = model_init_kwargs.model_package_arn
313318
self.sagemaker_session = model_init_kwargs.sagemaker_session
314319

315320
super(JumpStartModel, self).__init__(**model_init_kwargs.to_kwargs_dict())
316321

322+
self.model_package_arn = model_init_kwargs.model_package_arn
323+
317324
def retrieve_all_examples(self) -> Optional[List[JumpStartSerializablePayload]]:
318325
"""Returns all example payloads associated with the model.
319326
@@ -390,29 +397,22 @@ def _create_sagemaker_model(
390397
# inference endpoint.
391398
if self.model_package_arn and not self._model_data_is_set:
392399
# When a ModelPackageArn is provided we just create the Model
393-
match = re.match(MODEL_PACKAGE_ARN_PATTERN, self.model_package_arn)
394-
if match:
395-
model_package_name = match.group(3)
396-
else:
397-
# model_package_arn can be just the name if your account owns the Model Package
398-
model_package_name = self.model_package_arn
399-
container_def = {"ModelPackageName": self.model_package_arn}
400-
401-
if self.env != {}:
402-
container_def["Environment"] = self.env
403-
404-
if self.name is None:
405-
self._base_name = model_package_name
406-
407-
self._set_model_name_if_needed()
408-
409-
self.sagemaker_session.create_model(
410-
self.name,
411-
self.role,
412-
container_def,
400+
model_package = ModelPackage(
401+
role=self.role,
402+
model_data=self.model_data,
403+
model_package_arn=self.model_package_arn,
404+
sagemaker_session=self.sagemaker_session,
405+
predictor_cls=self.predictor_cls,
413406
vpc_config=self.vpc_config,
414-
enable_network_isolation=self.enable_network_isolation(),
407+
)
408+
model_package.name = self.name
409+
model_package.env = self.env
410+
model_package._create_sagemaker_model(
411+
instance_type=instance_type,
412+
accelerator_type=accelerator_type,
415413
tags=tags,
414+
serverless_inference_config=serverless_inference_config,
415+
**kwargs,
416416
)
417417
else:
418418
super(JumpStartModel, self)._create_sagemaker_model(
@@ -565,6 +565,116 @@ def deploy(
565565
# If a predictor class was passed, do not mutate predictor
566566
return predictor
567567

568+
def register(
    self,
    content_types: List[Union[str, PipelineVariable]] = None,
    response_types: List[Union[str, PipelineVariable]] = None,
    inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
    transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
    model_package_group_name: Optional[Union[str, PipelineVariable]] = None,
    image_uri: Optional[Union[str, PipelineVariable]] = None,
    model_metrics: Optional[ModelMetrics] = None,
    metadata_properties: Optional[MetadataProperties] = None,
    approval_status: Optional[Union[str, PipelineVariable]] = None,
    description: Optional[str] = None,
    drift_check_baselines: Optional[DriftCheckBaselines] = None,
    customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
    validation_specification: Optional[Union[str, PipelineVariable]] = None,
    domain: Optional[Union[str, PipelineVariable]] = None,
    task: Optional[Union[str, PipelineVariable]] = None,
    sample_payload_url: Optional[Union[str, PipelineVariable]] = None,
    framework: Optional[Union[str, PipelineVariable]] = None,
    framework_version: Optional[Union[str, PipelineVariable]] = None,
    nearest_model_name: Optional[Union[str, PipelineVariable]] = None,
    data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
    skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
):
    """Creates a model package for creating SageMaker models or listing on Marketplace.

    The arguments are first resolved through ``get_register_kwargs`` together with
    this model's id, version, region, tolerance flags and session, and the result
    is handed to the parent class ``register``.

    Args:
        content_types (list[str] or list[PipelineVariable]): Supported MIME types
            of the input data.
        response_types (list[str] or list[PipelineVariable]): Supported MIME types
            of the output data.
        inference_instances (list[str] or list[PipelineVariable]): Instance types
            used to generate inferences in real-time (default: None).
        transform_instances (list[str] or list[PipelineVariable]): Instance types
            on which a transformation job can be run or on which an endpoint can
            be deployed (default: None).
        model_package_group_name (str or PipelineVariable): Model Package Group
            name, exclusive to `model_package_name`; using
            `model_package_group_name` makes the Model Package versioned.
            Defaults to ``None``.
        image_uri (str or PipelineVariable): Inference image URI for the container.
            Model class' self.image will be used if it is None.
            Defaults to ``None``.
        model_metrics (ModelMetrics): ModelMetrics object. Defaults to ``None``.
        metadata_properties (MetadataProperties): MetadataProperties object.
            Defaults to ``None``.
        approval_status (str or PipelineVariable): Model Approval Status, one of
            "Approved", "Rejected", or "PendingManualApproval". Defaults to
            ``PendingManualApproval``.
        description (str): Model Package description. Defaults to ``None``.
        drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object
            (default: None).
        customer_metadata_properties (dict[str, str] or dict[str, PipelineVariable]):
            A dictionary of key-value paired metadata properties (default: None).
        validation_specification (str or PipelineVariable): Forwarded unchanged
            to the parent ``register`` call (default: None).
        domain (str or PipelineVariable): Domain values can be "COMPUTER_VISION",
            "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None).
        task (str or PipelineVariable): Task values which are supported by
            Inference Recommender are "FILL_MASK", "IMAGE_CLASSIFICATION",
            "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION",
            "CLASSIFICATION", "REGRESSION", "OTHER" (default: None).
        sample_payload_url (str or PipelineVariable): The S3 path where the sample
            payload is stored (default: None).
        framework (str or PipelineVariable): Machine learning framework of the
            model package container image (default: None).
        framework_version (str or PipelineVariable): Framework version of the
            Model Package Container Image (default: None).
        nearest_model_name (str or PipelineVariable): Name of a pre-trained machine
            learning model benchmarked by Amazon SageMaker Inference Recommender
            (default: None).
        data_input_configuration (str or PipelineVariable): Input object for the
            model (default: None).
        skip_model_validation (str or PipelineVariable): Indicates if you want to
            skip model validation. Values can be "All" or "None" (default: None).

    Returns:
        A `sagemaker.model.ModelPackage` instance. Its ``deploy`` attribute is
        replaced with this model's ``deploy`` method, and this model's
        ``model_package_arn`` is set to the ARN of the returned package.
    """

    resolved = get_register_kwargs(
        model_id=self.model_id,
        model_version=self.model_version,
        region=self.region,
        tolerate_deprecated_model=self.tolerate_deprecated_model,
        tolerate_vulnerable_model=self.tolerate_vulnerable_model,
        sagemaker_session=self.sagemaker_session,
        supported_content_types=content_types,
        response_types=response_types,
        inference_instances=inference_instances,
        transform_instances=transform_instances,
        model_package_group_name=model_package_group_name,
        image_uri=image_uri,
        model_metrics=model_metrics,
        metadata_properties=metadata_properties,
        approval_status=approval_status,
        description=description,
        drift_check_baselines=drift_check_baselines,
        customer_metadata_properties=customer_metadata_properties,
        validation_specification=validation_specification,
        domain=domain,
        task=task,
        sample_payload_url=sample_payload_url,
        framework=framework,
        framework_version=framework_version,
        nearest_model_name=nearest_model_name,
        data_input_configuration=data_input_configuration,
        skip_model_validation=skip_model_validation,
    )

    parent_kwargs = resolved.to_kwargs_dict()
    package = super(JumpStartModel, self).register(**parent_kwargs)

    # Remember the registered package on this model.
    self.model_package_arn = package.model_package_arn

    # Calling deploy on the returned package goes through this model's deploy.
    package.deploy = self.deploy

    return package
677+
568678
def __str__(self) -> str:
    """Return the human-readable form of this object built by ``stringify_object``."""
    readable = stringify_object(self)
    return readable

0 commit comments

Comments
 (0)