|
13 | 13 | """This module stores JumpStart implementation of Model class."""
|
14 | 14 |
|
15 | 15 | from __future__ import absolute_import
|
16 |
| -import re |
17 | 16 |
|
18 | 17 | from typing import Dict, List, Optional, Union
|
19 | 18 | from sagemaker import payloads
|
|
28 | 27 | get_default_predictor,
|
29 | 28 | get_deploy_kwargs,
|
30 | 29 | get_init_kwargs,
|
| 30 | + get_register_kwargs, |
31 | 31 | )
|
32 | 32 | from sagemaker.jumpstart.types import JumpStartSerializablePayload
|
33 | 33 | from sagemaker.jumpstart.utils import is_valid_model_id
|
34 | 34 | from sagemaker.utils import stringify_object
|
35 |
| -from sagemaker.model import MODEL_PACKAGE_ARN_PATTERN, Model |
| 35 | +from sagemaker.model import ( |
| 36 | + Model, |
| 37 | + ModelPackage, |
| 38 | +) |
36 | 39 | from sagemaker.model_monitor.data_capture_config import DataCaptureConfig
|
37 | 40 | from sagemaker.predictor import PredictorBase
|
38 | 41 | from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig
|
39 | 42 | from sagemaker.session import Session
|
40 | 43 | from sagemaker.workflow.entities import PipelineVariable
|
| 44 | +from sagemaker.model_metrics import ModelMetrics |
| 45 | +from sagemaker.metadata_properties import MetadataProperties |
| 46 | +from sagemaker.drift_check_baselines import DriftCheckBaselines |
41 | 47 |
|
42 | 48 |
|
43 | 49 | class JumpStartModel(Model):
|
@@ -390,29 +396,21 @@ def _create_sagemaker_model(
|
390 | 396 | # inference endpoint.
|
391 | 397 | if self.model_package_arn and not self._model_data_is_set:
|
392 | 398 | # When a ModelPackageArn is provided we just create the Model
|
393 |
| - match = re.match(MODEL_PACKAGE_ARN_PATTERN, self.model_package_arn) |
394 |
| - if match: |
395 |
| - model_package_name = match.group(3) |
396 |
| - else: |
397 |
| - # model_package_arn can be just the name if your account owns the Model Package |
398 |
| - model_package_name = self.model_package_arn |
399 |
| - container_def = {"ModelPackageName": self.model_package_arn} |
400 |
| - |
401 |
| - if self.env != {}: |
402 |
| - container_def["Environment"] = self.env |
403 |
| - |
404 |
| - if self.name is None: |
405 |
| - self._base_name = model_package_name |
406 |
| - |
407 |
| - self._set_model_name_if_needed() |
408 |
| - |
409 |
| - self.sagemaker_session.create_model( |
410 |
| - self.name, |
411 |
| - self.role, |
412 |
| - container_def, |
| 399 | + model_package = ModelPackage( |
| 400 | + role=self.role, |
| 401 | + model_data=self.model_data, |
| 402 | + model_package_arn=self.model_package_arn, |
| 403 | + sagemaker_session=self.sagemaker_session, |
| 404 | + predictor_cls=self.predictor_cls, |
413 | 405 | vpc_config=self.vpc_config,
|
414 |
| - enable_network_isolation=self.enable_network_isolation(), |
| 406 | + ) |
| 407 | + model_package.name = self.name |
| 408 | + model_package._create_sagemaker_model( |
| 409 | + instance_type=instance_type, |
| 410 | + accelerator_type=accelerator_type, |
415 | 411 | tags=tags,
|
| 412 | + serverless_inference_config=serverless_inference_config, |
| 413 | + **kwargs, |
416 | 414 | )
|
417 | 415 | else:
|
418 | 416 | super(JumpStartModel, self)._create_sagemaker_model(
|
@@ -565,6 +563,116 @@ def deploy(
|
565 | 563 | # If a predictor class was passed, do not mutate predictor
|
566 | 564 | return predictor
|
567 | 565 |
|
| 566 | + def register( |
| 567 | + self, |
| 568 | + content_types: List[Union[str, PipelineVariable]] = None, |
| 569 | + response_types: List[Union[str, PipelineVariable]] = None, |
| 570 | + inference_instances: Optional[List[Union[str, PipelineVariable]]] = None, |
| 571 | + transform_instances: Optional[List[Union[str, PipelineVariable]]] = None, |
| 572 | + model_package_group_name: Optional[Union[str, PipelineVariable]] = None, |
| 573 | + image_uri: Optional[Union[str, PipelineVariable]] = None, |
| 574 | + model_metrics: Optional[ModelMetrics] = None, |
| 575 | + metadata_properties: Optional[MetadataProperties] = None, |
| 576 | + approval_status: Optional[Union[str, PipelineVariable]] = None, |
| 577 | + description: Optional[str] = None, |
| 578 | + drift_check_baselines: Optional[DriftCheckBaselines] = None, |
| 579 | + customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None, |
| 580 | + validation_specification: Optional[Union[str, PipelineVariable]] = None, |
| 581 | + domain: Optional[Union[str, PipelineVariable]] = None, |
| 582 | + task: Optional[Union[str, PipelineVariable]] = None, |
| 583 | + sample_payload_url: Optional[Union[str, PipelineVariable]] = None, |
| 584 | + framework: Optional[Union[str, PipelineVariable]] = None, |
| 585 | + framework_version: Optional[Union[str, PipelineVariable]] = None, |
| 586 | + nearest_model_name: Optional[Union[str, PipelineVariable]] = None, |
| 587 | + data_input_configuration: Optional[Union[str, PipelineVariable]] = None, |
| 588 | + skip_model_validation: Optional[Union[str, PipelineVariable]] = None, |
| 589 | + ): |
| 590 | + """Creates a model package for creating SageMaker models or listing on Marketplace. |
| 591 | +
|
| 592 | + Args: |
| 593 | + content_types (list[str] or list[PipelineVariable]): The supported MIME types |
| 594 | + for the input data. |
| 595 | + response_types (list[str] or list[PipelineVariable]): The supported MIME types |
| 596 | + for the output data. |
| 597 | + inference_instances (list[str] or list[PipelineVariable]): A list of the instance |
| 598 | + types that are used to generate inferences in real-time (default: None). |
| 599 | + transform_instances (list[str] or list[PipelineVariable]): A list of the instance types |
| 600 | + on which a transformation job can be run or on which an endpoint can be deployed |
| 601 | + (default: None). |
| 602 | + model_package_group_name (str or PipelineVariable): Model Package Group name, |
| 603 | + exclusive to `model_package_name`, using `model_package_group_name` makes the |
| 604 | + Model Package versioned. Defaults to ``None``. |
| 605 | + image_uri (str or PipelineVariable): Inference image URI for the container. Model class' |
| 606 | + self.image will be used if it is None. Defaults to ``None``. |
| 607 | + model_metrics (ModelMetrics): ModelMetrics object. Defaults to ``None``. |
| 608 | + metadata_properties (MetadataProperties): MetadataProperties object. |
| 609 | + Defaults to ``None``. |
| 610 | + approval_status (str or PipelineVariable): Model Approval Status, values can be |
| 611 | + "Approved", "Rejected", or "PendingManualApproval". Defaults to |
| 612 | + ``PendingManualApproval``. |
| 613 | + description (str): Model Package description. Defaults to ``None``. |
| 614 | + drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). |
| 615 | + customer_metadata_properties (dict[str, str] or dict[str, PipelineVariable]): |
| 616 | + A dictionary of key-value paired metadata properties (default: None). |
| 617 | + domain (str or PipelineVariable): Domain values can be "COMPUTER_VISION", |
| 618 | + "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). |
| 619 | + sample_payload_url (str or PipelineVariable): The S3 path where the sample payload |
| 620 | + is stored (default: None). |
| 621 | + task (str or PipelineVariable): Task values which are supported by Inference Recommender |
| 622 | + are "FILL_MASK", "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", |
| 623 | + "IMAGE_SEGMENTATION", "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). |
| 624 | + framework (str or PipelineVariable): Machine learning framework of the model package |
| 625 | + container image (default: None). |
| 626 | + framework_version (str or PipelineVariable): Framework version of the Model Package |
| 627 | + Container Image (default: None). |
| 628 | + nearest_model_name (str or PipelineVariable): Name of a pre-trained machine learning |
| 629 | + benchmarked by Amazon SageMaker Inference Recommender (default: None). |
| 630 | + data_input_configuration (str or PipelineVariable): Input object for the model |
| 631 | + (default: None). |
| 632 | + skip_model_validation (str or PipelineVariable): Indicates if you want to skip model |
| 633 | + validation. Values can be "All" or "None" (default: None). |
| 634 | +
|
| 635 | + Returns: |
| 636 | + A `sagemaker.model.ModelPackage` instance. |
| 637 | + """ |
| 638 | + |
| 639 | + register_kwargs = get_register_kwargs( |
| 640 | + model_id=self.model_id, |
| 641 | + model_version=self.model_version, |
| 642 | + region=self.region, |
| 643 | + tolerate_deprecated_model=self.tolerate_deprecated_model, |
| 644 | + tolerate_vulnerable_model=self.tolerate_vulnerable_model, |
| 645 | + sagemaker_session=self.sagemaker_session, |
| 646 | + supported_content_types=content_types, |
| 647 | + response_types=response_types, |
| 648 | + inference_instances=inference_instances, |
| 649 | + transform_instances=transform_instances, |
| 650 | + model_package_group_name=model_package_group_name, |
| 651 | + image_uri=image_uri, |
| 652 | + model_metrics=model_metrics, |
| 653 | + metadata_properties=metadata_properties, |
| 654 | + approval_status=approval_status, |
| 655 | + description=description, |
| 656 | + drift_check_baselines=drift_check_baselines, |
| 657 | + customer_metadata_properties=customer_metadata_properties, |
| 658 | + validation_specification=validation_specification, |
| 659 | + domain=domain, |
| 660 | + task=task, |
| 661 | + sample_payload_url=sample_payload_url, |
| 662 | + framework=framework, |
| 663 | + framework_version=framework_version, |
| 664 | + nearest_model_name=nearest_model_name, |
| 665 | + data_input_configuration=data_input_configuration, |
| 666 | + skip_model_validation=skip_model_validation, |
| 667 | + ) |
| 668 | + |
| 669 | + model_package = super(JumpStartModel, self).register(**register_kwargs.to_kwargs_dict()) |
| 670 | + self.model_package_arn = model_package.model_package_arn |
| 671 | + |
| 672 | + model_package.deploy = self.deploy |
| 673 | + |
| 674 | + return model_package |
| 675 | + |
568 | 676 | def __str__(self) -> str:
|
569 | 677 | """Overriding str(*) method to make more human-readable."""
|
570 | 678 | return stringify_object(self)
|
0 commit comments