|
13 | 13 | """This module stores JumpStart implementation of Model class."""
|
14 | 14 |
|
15 | 15 | from __future__ import absolute_import
|
| 16 | +import copy |
| 17 | +from functools import partial |
16 | 18 | import re
|
| 19 | +import types |
17 | 20 |
|
18 | 21 | from typing import Dict, List, Optional, Union
|
19 | 22 | from sagemaker import payloads
|
|
28 | 31 | get_default_predictor,
|
29 | 32 | get_deploy_kwargs,
|
30 | 33 | get_init_kwargs,
|
| 34 | + get_register_kwargs, |
31 | 35 | )
|
32 | 36 | from sagemaker.jumpstart.types import JumpStartSerializablePayload
|
33 | 37 | from sagemaker.jumpstart.utils import is_valid_model_id
|
34 | 38 | from sagemaker.utils import stringify_object
|
35 |
| -from sagemaker.model import MODEL_PACKAGE_ARN_PATTERN, Model |
| 39 | +from sagemaker.model import ( |
| 40 | + MODEL_PACKAGE_ARN_PATTERN, |
| 41 | + Model, |
| 42 | + MODEL_PACKAGE_VERSIONED_ARN_PATTERN, |
| 43 | + ModelPackage, |
| 44 | +) |
36 | 45 | from sagemaker.model_monitor.data_capture_config import DataCaptureConfig
|
37 | 46 | from sagemaker.predictor import PredictorBase
|
38 | 47 | from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig
|
39 | 48 | from sagemaker.session import Session
|
40 | 49 | from sagemaker.workflow.entities import PipelineVariable
|
| 50 | +from sagemaker.model_metrics import ModelMetrics |
| 51 | +from sagemaker.metadata_properties import MetadataProperties |
| 52 | +from sagemaker.drift_check_baselines import DriftCheckBaselines |
| 53 | +from sagemaker.model_card.schema_constraints import ModelApprovalStatusEnum |
41 | 54 |
|
42 | 55 |
|
43 | 56 | class JumpStartModel(Model):
|
@@ -390,29 +403,21 @@ def _create_sagemaker_model(
|
390 | 403 | # inference endpoint.
|
391 | 404 | if self.model_package_arn and not self._model_data_is_set:
|
392 | 405 | # When a ModelPackageArn is provided we just create the Model
|
393 |
| - match = re.match(MODEL_PACKAGE_ARN_PATTERN, self.model_package_arn) |
394 |
| - if match: |
395 |
| - model_package_name = match.group(3) |
396 |
| - else: |
397 |
| - # model_package_arn can be just the name if your account owns the Model Package |
398 |
| - model_package_name = self.model_package_arn |
399 |
| - container_def = {"ModelPackageName": self.model_package_arn} |
400 |
| - |
401 |
| - if self.env != {}: |
402 |
| - container_def["Environment"] = self.env |
403 |
| - |
404 |
| - if self.name is None: |
405 |
| - self._base_name = model_package_name |
406 |
| - |
407 |
| - self._set_model_name_if_needed() |
408 |
| - |
409 |
| - self.sagemaker_session.create_model( |
410 |
| - self.name, |
411 |
| - self.role, |
412 |
| - container_def, |
| 406 | + model_package = ModelPackage( |
| 407 | + role=self.role, |
| 408 | + model_data=self.model_data, |
| 409 | + model_package_arn=self.model_package_arn, |
| 410 | + sagemaker_session=self.sagemaker_session, |
| 411 | + predictor_cls=self.predictor_cls, |
413 | 412 | vpc_config=self.vpc_config,
|
414 |
| - enable_network_isolation=self.enable_network_isolation(), |
| 413 | + ) |
| 414 | + model_package.name = self.name |
| 415 | + model_package._create_sagemaker_model( |
| 416 | + instance_type=instance_type, |
| 417 | + accelerator_type=accelerator_type, |
415 | 418 | tags=tags,
|
| 419 | + serverless_inference_config=serverless_inference_config, |
| 420 | + **kwargs, |
416 | 421 | )
|
417 | 422 | else:
|
418 | 423 | super(JumpStartModel, self)._create_sagemaker_model(
|
@@ -565,6 +570,125 @@ def deploy(
|
565 | 570 | # If a predictor class was passed, do not mutate predictor
|
566 | 571 | return predictor
|
567 | 572 |
|
| 573 | + def register( |
| 574 | + self, |
| 575 | + content_types: List[Union[str, PipelineVariable]] = None, |
| 576 | + response_types: List[Union[str, PipelineVariable]] = None, |
| 577 | + inference_instances: Optional[List[Union[str, PipelineVariable]]] = None, |
| 578 | + transform_instances: Optional[List[Union[str, PipelineVariable]]] = None, |
| 579 | + model_package_name: Optional[Union[str, PipelineVariable]] = None, |
| 580 | + model_package_group_name: Optional[Union[str, PipelineVariable]] = None, |
| 581 | + image_uri: Optional[Union[str, PipelineVariable]] = None, |
| 582 | + model_metrics: Optional[ModelMetrics] = None, |
| 583 | + metadata_properties: Optional[MetadataProperties] = None, |
| 584 | + marketplace_cert: bool = False, |
| 585 | + approval_status: Optional[Union[str, PipelineVariable]] = None, |
| 586 | + description: Optional[str] = None, |
| 587 | + drift_check_baselines: Optional[DriftCheckBaselines] = None, |
| 588 | + customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None, |
| 589 | + validation_specification: Optional[Union[str, PipelineVariable]] = None, |
| 590 | + domain: Optional[Union[str, PipelineVariable]] = None, |
| 591 | + task: Optional[Union[str, PipelineVariable]] = None, |
| 592 | + sample_payload_url: Optional[Union[str, PipelineVariable]] = None, |
| 593 | + framework: Optional[Union[str, PipelineVariable]] = None, |
| 594 | + framework_version: Optional[Union[str, PipelineVariable]] = None, |
| 595 | + nearest_model_name: Optional[Union[str, PipelineVariable]] = None, |
| 596 | + data_input_configuration: Optional[Union[str, PipelineVariable]] = None, |
| 597 | + skip_model_validation: Optional[Union[str, PipelineVariable]] = None, |
| 598 | + ): |
| 599 | + """Creates a model package for creating SageMaker models or listing on Marketplace. |
| 600 | +
|
| 601 | + Args: |
| 602 | + content_types (list[str] or list[PipelineVariable]): The supported MIME types |
| 603 | + for the input data. |
| 604 | + response_types (list[str] or list[PipelineVariable]): The supported MIME types |
| 605 | + for the output data. |
| 606 | + inference_instances (list[str] or list[PipelineVariable]): A list of the instance |
| 607 | + types that are used to generate inferences in real-time (default: None). |
| 608 | + transform_instances (list[str] or list[PipelineVariable]): A list of the instance types |
| 609 | + on which a transformation job can be run or on which an endpoint can be deployed |
| 610 | + (default: None). |
| 611 | + model_package_name (str or PipelineVariable): Model Package name, exclusive to |
| 612 | + `model_package_group_name`, using `model_package_name` makes the Model Package |
| 613 | + un-versioned. Defaults to ``None``. |
| 614 | + model_package_group_name (str or PipelineVariable): Model Package Group name, |
| 615 | + exclusive to `model_package_name`, using `model_package_group_name` makes the |
| 616 | + Model Package versioned. Defaults to ``None``. |
| 617 | + image_uri (str or PipelineVariable): Inference image URI for the container. Model class' |
| 618 | + self.image will be used if it is None. Defaults to ``None``. |
| 619 | + model_metrics (ModelMetrics): ModelMetrics object. Defaults to ``None``. |
| 620 | + metadata_properties (MetadataProperties): MetadataProperties object. |
| 621 | + Defaults to ``None``. |
| 622 | + marketplace_cert (bool): A boolean value indicating if the Model Package is certified |
| 623 | + for AWS Marketplace. Defaults to ``False``. |
| 624 | + approval_status (str or PipelineVariable): Model Approval Status, values can be |
| 625 | + "Approved", "Rejected", or "PendingManualApproval". Defaults to |
| 626 | + ``PendingManualApproval``. |
| 627 | + description (str): Model Package description. Defaults to ``None``. |
| 628 | + drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). |
| 629 | + customer_metadata_properties (dict[str, str] or dict[str, PipelineVariable]): |
| 630 | + A dictionary of key-value paired metadata properties (default: None). |
| 631 | + domain (str or PipelineVariable): Domain values can be "COMPUTER_VISION", |
| 632 | + "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). |
| 633 | + sample_payload_url (str or PipelineVariable): The S3 path where the sample payload |
| 634 | + is stored (default: None). |
| 635 | + task (str or PipelineVariable): Task values which are supported by Inference Recommender |
| 636 | + are "FILL_MASK", "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", |
| 637 | + "IMAGE_SEGMENTATION", "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). |
| 638 | + framework (str or PipelineVariable): Machine learning framework of the model package |
| 639 | + container image (default: None). |
| 640 | + framework_version (str or PipelineVariable): Framework version of the Model Package |
| 641 | + Container Image (default: None). |
| 642 | + nearest_model_name (str or PipelineVariable): Name of a pre-trained machine learning |
| 643 | + benchmarked by Amazon SageMaker Inference Recommender (default: None). |
| 644 | + data_input_configuration (str or PipelineVariable): Input object for the model |
| 645 | + (default: None). |
| 646 | + skip_model_validation (str or PipelineVariable): Indicates if you want to skip model |
| 647 | + validation. Values can be "All" or "None" (default: None). |
| 648 | +
|
| 649 | + Returns: |
| 650 | + A `sagemaker.model.ModelPackage` instance. |
| 651 | + """ |
| 652 | + |
| 653 | + register_kwargs = get_register_kwargs( |
| 654 | + model_id=self.model_id, |
| 655 | + model_version=self.model_version, |
| 656 | + region=self.region, |
| 657 | + tolerate_deprecated_model=self.tolerate_deprecated_model, |
| 658 | + tolerate_vulnerable_model=self.tolerate_vulnerable_model, |
| 659 | + sagemaker_session=self.sagemaker_session, |
| 660 | + supported_content_types=content_types, |
| 661 | + response_types=response_types, |
| 662 | + inference_instances=inference_instances, |
| 663 | + transform_instances=transform_instances, |
| 664 | + model_package_name=model_package_name, |
| 665 | + model_package_group_name=model_package_group_name, |
| 666 | + image_uri=image_uri, |
| 667 | + model_metrics=model_metrics, |
| 668 | + metadata_properties=metadata_properties, |
| 669 | + marketplace_cert=marketplace_cert, |
| 670 | + approval_status=approval_status, |
| 671 | + description=description, |
| 672 | + drift_check_baselines=drift_check_baselines, |
| 673 | + customer_metadata_properties=customer_metadata_properties, |
| 674 | + validation_specification=validation_specification, |
| 675 | + domain=domain, |
| 676 | + task=task, |
| 677 | + sample_payload_url=sample_payload_url, |
| 678 | + framework=framework, |
| 679 | + framework_version=framework_version, |
| 680 | + nearest_model_name=nearest_model_name, |
| 681 | + data_input_configuration=data_input_configuration, |
| 682 | + skip_model_validation=skip_model_validation, |
| 683 | + ) |
| 684 | + |
| 685 | + model_package = super(JumpStartModel, self).register(**register_kwargs.to_kwargs_dict()) |
| 686 | + self.model_package_arn = model_package.model_package_arn |
| 687 | + |
| 688 | + model_package.deploy = self.deploy |
| 689 | + |
| 690 | + return model_package |
| 691 | + |
568 | 692 | def __str__(self) -> str:
|
569 | 693 | """Overriding str(*) method to make more human-readable."""
|
570 | 694 | return stringify_object(self)
|
0 commit comments