Skip to content

Commit 75c5fe6

Browse files
author
Keshav Chandak
committed
Added register step in Jumpstart
1 parent 070dc34 commit 75c5fe6

File tree

6 files changed

+456
-75
lines changed

6 files changed

+456
-75
lines changed

src/sagemaker/jumpstart/factory/model.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,19 @@
3434
JUMPSTART_DEFAULT_REGION_NAME,
3535
JUMPSTART_LOGGER,
3636
)
37+
from sagemaker.model_metrics import ModelMetrics
38+
from sagemaker.metadata_properties import MetadataProperties
39+
from sagemaker.drift_check_baselines import DriftCheckBaselines
3740
from sagemaker.jumpstart.enums import JumpStartScriptScope
3841
from sagemaker.jumpstart.types import (
3942
JumpStartModelDeployKwargs,
4043
JumpStartModelInitKwargs,
44+
JumpStartModelRegisterKwargs,
4145
)
4246
from sagemaker.jumpstart.utils import (
4347
update_dict_if_key_not_present,
4448
resolve_model_sagemaker_config_field,
49+
verify_model_region_and_return_specs,
4550
)
4651

4752
from sagemaker.model_monitor.data_capture_config import DataCaptureConfig
@@ -506,6 +511,87 @@ def get_deploy_kwargs(
506511
return deploy_kwargs
507512

508513

514+
def get_register_kwargs(
515+
model_id: str,
516+
model_version: Optional[str] = None,
517+
region: Optional[str] = None,
518+
tolerate_deprecated_model: Optional[bool] = None,
519+
tolerate_vulnerable_model: Optional[bool] = None,
520+
sagemaker_session: Optional[Any] = None,
521+
supported_content_types: List[str] = None,
522+
response_types: List[str] = None,
523+
inference_instances: Optional[List[str]] = None,
524+
transform_instances: Optional[List[str]] = None,
525+
model_package_group_name: Optional[str] = None,
526+
image_uri: Optional[str] = None,
527+
model_metrics: Optional[ModelMetrics] = None,
528+
metadata_properties: Optional[MetadataProperties] = None,
529+
approval_status: Optional[str] = None,
530+
description: Optional[str] = None,
531+
drift_check_baselines: Optional[DriftCheckBaselines] = None,
532+
customer_metadata_properties: Optional[Dict[str, str]] = None,
533+
validation_specification: Optional[str] = None,
534+
domain: Optional[str] = None,
535+
task: Optional[str] = None,
536+
sample_payload_url: Optional[str] = None,
537+
framework: Optional[str] = None,
538+
framework_version: Optional[str] = None,
539+
nearest_model_name: Optional[str] = None,
540+
data_input_configuration: Optional[str] = None,
541+
skip_model_validation: Optional[str] = None,
542+
) -> JumpStartModelRegisterKwargs:
543+
"""Returns kwargs required to call `register` on `sagemaker.estimator.Model` object."""
544+
545+
register_kwargs = JumpStartModelRegisterKwargs(
546+
model_id=model_id,
547+
model_version=model_version,
548+
region=region,
549+
tolerate_deprecated_model=tolerate_deprecated_model,
550+
tolerate_vulnerable_model=tolerate_vulnerable_model,
551+
sagemaker_session=sagemaker_session,
552+
content_types=supported_content_types,
553+
response_types=response_types,
554+
inference_instances=inference_instances,
555+
transform_instances=transform_instances,
556+
model_package_group_name=model_package_group_name,
557+
image_uri=image_uri,
558+
model_metrics=model_metrics,
559+
metadata_properties=metadata_properties,
560+
approval_status=approval_status,
561+
description=description,
562+
drift_check_baselines=drift_check_baselines,
563+
customer_metadata_properties=customer_metadata_properties,
564+
validation_specification=validation_specification,
565+
domain=domain,
566+
task=task,
567+
sample_payload_url=sample_payload_url,
568+
framework=framework,
569+
framework_version=framework_version,
570+
nearest_model_name=nearest_model_name,
571+
data_input_configuration=data_input_configuration,
572+
skip_model_validation=skip_model_validation,
573+
)
574+
575+
model_specs = verify_model_region_and_return_specs(
576+
model_id=model_id,
577+
version=model_version,
578+
region=region,
579+
scope=JumpStartScriptScope.INFERENCE,
580+
sagemaker_session=sagemaker_session,
581+
tolerate_deprecated_model=tolerate_deprecated_model,
582+
tolerate_vulnerable_model=tolerate_vulnerable_model,
583+
)
584+
585+
register_kwargs.content_types = (
586+
register_kwargs.content_types or model_specs.predictor_specs.supported_content_types
587+
)
588+
register_kwargs.response_types = (
589+
register_kwargs.response_types or model_specs.predictor_specs.supported_accept_types
590+
)
591+
592+
return register_kwargs
593+
594+
509595
def get_init_kwargs(
510596
model_id: str,
511597
model_from_estimator: bool = False,

src/sagemaker/jumpstart/model.py

Lines changed: 131 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
"""This module stores JumpStart implementation of Model class."""
1414

1515
from __future__ import absolute_import
16-
import re
1716

1817
from typing import Dict, List, Optional, Union
1918
from sagemaker import payloads
@@ -28,16 +27,23 @@
2827
get_default_predictor,
2928
get_deploy_kwargs,
3029
get_init_kwargs,
30+
get_register_kwargs,
3131
)
3232
from sagemaker.jumpstart.types import JumpStartSerializablePayload
3333
from sagemaker.jumpstart.utils import is_valid_model_id
3434
from sagemaker.utils import stringify_object
35-
from sagemaker.model import MODEL_PACKAGE_ARN_PATTERN, Model
35+
from sagemaker.model import (
36+
Model,
37+
ModelPackage,
38+
)
3639
from sagemaker.model_monitor.data_capture_config import DataCaptureConfig
3740
from sagemaker.predictor import PredictorBase
3841
from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig
3942
from sagemaker.session import Session
4043
from sagemaker.workflow.entities import PipelineVariable
44+
from sagemaker.model_metrics import ModelMetrics
45+
from sagemaker.metadata_properties import MetadataProperties
46+
from sagemaker.drift_check_baselines import DriftCheckBaselines
4147

4248

4349
class JumpStartModel(Model):
@@ -390,29 +396,21 @@ def _create_sagemaker_model(
390396
# inference endpoint.
391397
if self.model_package_arn and not self._model_data_is_set:
392398
# When a ModelPackageArn is provided we just create the Model
393-
match = re.match(MODEL_PACKAGE_ARN_PATTERN, self.model_package_arn)
394-
if match:
395-
model_package_name = match.group(3)
396-
else:
397-
# model_package_arn can be just the name if your account owns the Model Package
398-
model_package_name = self.model_package_arn
399-
container_def = {"ModelPackageName": self.model_package_arn}
400-
401-
if self.env != {}:
402-
container_def["Environment"] = self.env
403-
404-
if self.name is None:
405-
self._base_name = model_package_name
406-
407-
self._set_model_name_if_needed()
408-
409-
self.sagemaker_session.create_model(
410-
self.name,
411-
self.role,
412-
container_def,
399+
model_package = ModelPackage(
400+
role=self.role,
401+
model_data=self.model_data,
402+
model_package_arn=self.model_package_arn,
403+
sagemaker_session=self.sagemaker_session,
404+
predictor_cls=self.predictor_cls,
413405
vpc_config=self.vpc_config,
414-
enable_network_isolation=self.enable_network_isolation(),
406+
)
407+
model_package.name = self.name
408+
model_package._create_sagemaker_model(
409+
instance_type=instance_type,
410+
accelerator_type=accelerator_type,
415411
tags=tags,
412+
serverless_inference_config=serverless_inference_config,
413+
**kwargs,
416414
)
417415
else:
418416
super(JumpStartModel, self)._create_sagemaker_model(
@@ -565,6 +563,116 @@ def deploy(
565563
# If a predictor class was passed, do not mutate predictor
566564
return predictor
567565

566+
def register(
567+
self,
568+
content_types: List[Union[str, PipelineVariable]] = None,
569+
response_types: List[Union[str, PipelineVariable]] = None,
570+
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
571+
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
572+
model_package_group_name: Optional[Union[str, PipelineVariable]] = None,
573+
image_uri: Optional[Union[str, PipelineVariable]] = None,
574+
model_metrics: Optional[ModelMetrics] = None,
575+
metadata_properties: Optional[MetadataProperties] = None,
576+
approval_status: Optional[Union[str, PipelineVariable]] = None,
577+
description: Optional[str] = None,
578+
drift_check_baselines: Optional[DriftCheckBaselines] = None,
579+
customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
580+
validation_specification: Optional[Union[str, PipelineVariable]] = None,
581+
domain: Optional[Union[str, PipelineVariable]] = None,
582+
task: Optional[Union[str, PipelineVariable]] = None,
583+
sample_payload_url: Optional[Union[str, PipelineVariable]] = None,
584+
framework: Optional[Union[str, PipelineVariable]] = None,
585+
framework_version: Optional[Union[str, PipelineVariable]] = None,
586+
nearest_model_name: Optional[Union[str, PipelineVariable]] = None,
587+
data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
588+
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
589+
):
590+
"""Creates a model package for creating SageMaker models or listing on Marketplace.
591+
592+
Args:
593+
content_types (list[str] or list[PipelineVariable]): The supported MIME types
594+
for the input data.
595+
response_types (list[str] or list[PipelineVariable]): The supported MIME types
596+
for the output data.
597+
inference_instances (list[str] or list[PipelineVariable]): A list of the instance
598+
types that are used to generate inferences in real-time (default: None).
599+
transform_instances (list[str] or list[PipelineVariable]): A list of the instance types
600+
on which a transformation job can be run or on which an endpoint can be deployed
601+
(default: None).
602+
model_package_group_name (str or PipelineVariable): Model Package Group name,
603+
exclusive to `model_package_name`, using `model_package_group_name` makes the
604+
Model Package versioned. Defaults to ``None``.
605+
image_uri (str or PipelineVariable): Inference image URI for the container. Model class'
606+
self.image will be used if it is None. Defaults to ``None``.
607+
model_metrics (ModelMetrics): ModelMetrics object. Defaults to ``None``.
608+
metadata_properties (MetadataProperties): MetadataProperties object.
609+
Defaults to ``None``.
610+
approval_status (str or PipelineVariable): Model Approval Status, values can be
611+
"Approved", "Rejected", or "PendingManualApproval". Defaults to
612+
``PendingManualApproval``.
613+
description (str): Model Package description. Defaults to ``None``.
614+
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
615+
customer_metadata_properties (dict[str, str] or dict[str, PipelineVariable]):
616+
A dictionary of key-value paired metadata properties (default: None).
617+
domain (str or PipelineVariable): Domain values can be "COMPUTER_VISION",
618+
"NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None).
619+
sample_payload_url (str or PipelineVariable): The S3 path where the sample payload
620+
is stored (default: None).
621+
task (str or PipelineVariable): Task values which are supported by Inference Recommender
622+
are "FILL_MASK", "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION",
623+
"IMAGE_SEGMENTATION", "CLASSIFICATION", "REGRESSION", "OTHER" (default: None).
624+
framework (str or PipelineVariable): Machine learning framework of the model package
625+
container image (default: None).
626+
framework_version (str or PipelineVariable): Framework version of the Model Package
627+
Container Image (default: None).
628+
nearest_model_name (str or PipelineVariable): Name of a pre-trained machine learning
629+
model benchmarked by Amazon SageMaker Inference Recommender (default: None).
630+
data_input_configuration (str or PipelineVariable): Input object for the model
631+
(default: None).
632+
skip_model_validation (str or PipelineVariable): Indicates if you want to skip model
633+
validation. Values can be "All" or "None" (default: None).
634+
635+
Returns:
636+
A `sagemaker.model.ModelPackage` instance.
637+
"""
638+
639+
register_kwargs = get_register_kwargs(
640+
model_id=self.model_id,
641+
model_version=self.model_version,
642+
region=self.region,
643+
tolerate_deprecated_model=self.tolerate_deprecated_model,
644+
tolerate_vulnerable_model=self.tolerate_vulnerable_model,
645+
sagemaker_session=self.sagemaker_session,
646+
supported_content_types=content_types,
647+
response_types=response_types,
648+
inference_instances=inference_instances,
649+
transform_instances=transform_instances,
650+
model_package_group_name=model_package_group_name,
651+
image_uri=image_uri,
652+
model_metrics=model_metrics,
653+
metadata_properties=metadata_properties,
654+
approval_status=approval_status,
655+
description=description,
656+
drift_check_baselines=drift_check_baselines,
657+
customer_metadata_properties=customer_metadata_properties,
658+
validation_specification=validation_specification,
659+
domain=domain,
660+
task=task,
661+
sample_payload_url=sample_payload_url,
662+
framework=framework,
663+
framework_version=framework_version,
664+
nearest_model_name=nearest_model_name,
665+
data_input_configuration=data_input_configuration,
666+
skip_model_validation=skip_model_validation,
667+
)
668+
669+
model_package = super(JumpStartModel, self).register(**register_kwargs.to_kwargs_dict())
670+
self.model_package_arn = model_package.model_package_arn
671+
672+
model_package.deploy = self.deploy
673+
674+
return model_package
675+
568676
def __str__(self) -> str:
569677
"""Overriding str(*) method to make more human-readable."""
570678
return stringify_object(self)

0 commit comments

Comments
 (0)