Skip to content

feature: add 'Domain' property to RegisterModel step #3118

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1280,6 +1280,7 @@ def register(
model_name=None,
drift_check_baselines=None,
customer_metadata_properties=None,
domain=None,
**kwargs,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.
Expand Down Expand Up @@ -1311,6 +1312,8 @@ def register(
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
metadata properties (default: None).
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this true? I thought it was arbitrary string

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is true. It has been verified from the CLI as well.

"MACHINE_LEARNING" (default: None).
**kwargs: Passed to invocation of ``create_model()``. Implementations may customize
``create_model()`` to accept ``**kwargs`` to customize model creation during
deploy. For more, see the implementation docs.
Expand Down Expand Up @@ -1342,6 +1345,7 @@ def register(
description,
drift_check_baselines=drift_check_baselines,
customer_metadata_properties=customer_metadata_properties,
domain=domain,
)

@property
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/huggingface/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ def register(
approval_status=None,
description=None,
drift_check_baselines=None,
domain=None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.

Expand Down Expand Up @@ -331,6 +332,8 @@ def register(
or "PendingManualApproval". Defaults to ``PendingManualApproval``.
description (str): Model Package description. Defaults to ``None``.
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
"MACHINE_LEARNING" (default: None).

Returns:
A `sagemaker.model.ModelPackage` instance.
Expand Down Expand Up @@ -359,6 +362,7 @@ def register(
approval_status,
description,
drift_check_baselines=drift_check_baselines,
domain=domain,
)

def prepare_container_def(
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ def register(
drift_check_baselines=None,
customer_metadata_properties=None,
validation_specification=None,
domain=None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.

Expand Down Expand Up @@ -336,6 +337,8 @@ def register(
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
metadata properties (default: None).
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
"MACHINE_LEARNING" (default: None).

Returns:
A `sagemaker.model.ModelPackage` instance.
Expand Down Expand Up @@ -365,6 +368,7 @@ def register(
drift_check_baselines=drift_check_baselines,
customer_metadata_properties=customer_metadata_properties,
validation_specification=validation_specification,
domain=domain,
)
model_package = self.sagemaker_session.create_model_package_from_containers(
**model_pkg_args
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/mxnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def register(
description=None,
drift_check_baselines=None,
customer_metadata_properties=None,
domain=None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.

Expand Down Expand Up @@ -185,6 +186,8 @@ def register(
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
metadata properties (default: None).
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
"MACHINE_LEARNING" (default: None).

Returns:
A `sagemaker.model.ModelPackage` instance.
Expand Down Expand Up @@ -214,6 +217,7 @@ def register(
description,
drift_check_baselines=drift_check_baselines,
customer_metadata_properties=customer_metadata_properties,
domain=domain,
)

def prepare_container_def(
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ def register(
description=None,
drift_check_baselines=None,
customer_metadata_properties=None,
domain=None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.

Expand Down Expand Up @@ -186,6 +187,8 @@ def register(
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
metadata properties (default: None).
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
"MACHINE_LEARNING" (default: None).

Returns:
A `sagemaker.model.ModelPackage` instance.
Expand Down Expand Up @@ -215,6 +218,7 @@ def register(
description,
drift_check_baselines=drift_check_baselines,
customer_metadata_properties=customer_metadata_properties,
domain=domain,
)

def prepare_container_def(
Expand Down
14 changes: 14 additions & 0 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2803,6 +2803,7 @@ def create_model_package_from_containers(
drift_check_baselines=None,
customer_metadata_properties=None,
validation_specification=None,
domain=None,
):
"""Get request dictionary for CreateModelPackage API.

Expand Down Expand Up @@ -2830,6 +2831,8 @@ def create_model_package_from_containers(
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
metadata properties (default: None).
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
"MACHINE_LEARNING" (default: None).
"""

model_pkg_request = get_create_model_package_request(
Expand All @@ -2848,6 +2851,7 @@ def create_model_package_from_containers(
drift_check_baselines=drift_check_baselines,
customer_metadata_properties=customer_metadata_properties,
validation_specification=validation_specification,
domain=domain,
)

def submit(request):
Expand Down Expand Up @@ -4218,6 +4222,7 @@ def get_model_package_args(
drift_check_baselines=None,
customer_metadata_properties=None,
validation_specification=None,
domain=None,
):
"""Get arguments for create_model_package method.

Expand Down Expand Up @@ -4248,6 +4253,8 @@ def get_model_package_args(
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
metadata properties (default: None).
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
"MACHINE_LEARNING" (default: None).
Returns:
dict: A dictionary of method argument names and values.
"""
Expand Down Expand Up @@ -4289,6 +4296,8 @@ def get_model_package_args(
model_package_args["customer_metadata_properties"] = customer_metadata_properties
if validation_specification is not None:
model_package_args["validation_specification"] = validation_specification
if domain is not None:
model_package_args["domain"] = domain
return model_package_args


Expand All @@ -4309,6 +4318,7 @@ def get_create_model_package_request(
drift_check_baselines=None,
customer_metadata_properties=None,
validation_specification=None,
domain=None,
):
"""Get request dictionary for CreateModelPackage API.

Expand Down Expand Up @@ -4337,6 +4347,8 @@ def get_create_model_package_request(
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
metadata properties (default: None).
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
"MACHINE_LEARNING" (default: None).
"""

if all([model_package_name, model_package_group_name]):
Expand All @@ -4362,6 +4374,8 @@ def get_create_model_package_request(
request_dict["CustomerMetadataProperties"] = customer_metadata_properties
if validation_specification:
request_dict["ValidationSpecification"] = validation_specification
if domain is not None:
request_dict["Domain"] = domain
if containers is not None:
if not all([content_types, response_types, inference_instances, transform_instances]):
raise ValueError(
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/sklearn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def register(
marketplace_cert=False,
approval_status=None,
description=None,
domain=None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.

Expand All @@ -175,6 +176,8 @@ def register(
approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
or "PendingManualApproval" (default: "PendingManualApproval").
description (str): Model Package description (default: None).
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
"MACHINE_LEARNING" (default: None).

Returns:
A `sagemaker.model.ModelPackage` instance.
Expand Down Expand Up @@ -202,6 +205,7 @@ def register(
marketplace_cert,
approval_status,
description,
domain=domain,
)

def prepare_container_def(
Expand Down
5 changes: 4 additions & 1 deletion src/sagemaker/tensorflow/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ def register(
description=None,
drift_check_baselines=None,
customer_metadata_properties=None,
domain=None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.

Expand Down Expand Up @@ -232,7 +233,8 @@ def register(
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
metadata properties (default: None).

domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
"MACHINE_LEARNING" (default: None).

Returns:
A `sagemaker.model.ModelPackage` instance.
Expand Down Expand Up @@ -262,6 +264,7 @@ def register(
description,
drift_check_baselines=drift_check_baselines,
customer_metadata_properties=customer_metadata_properties,
domain=domain,
)

def deploy(
Expand Down
5 changes: 5 additions & 0 deletions src/sagemaker/workflow/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ def __init__(
container_def_list=None,
drift_check_baselines=None,
customer_metadata_properties=None,
domain=None,
**kwargs,
):
"""Constructor of a register model step.
Expand Down Expand Up @@ -326,6 +327,8 @@ def __init__(
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
metadata properties (default: None).
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
"MACHINE_LEARNING" (default: None).
**kwargs: additional arguments to `create_model`.
"""
super(_RegisterModelStep, self).__init__(
Expand Down Expand Up @@ -356,6 +359,7 @@ def __init__(
self.model_metrics = model_metrics
self.drift_check_baselines = drift_check_baselines
self.customer_metadata_properties = customer_metadata_properties
self.domain = domain
self.metadata_properties = metadata_properties
self.approval_status = approval_status
self.image_uri = image_uri
Expand Down Expand Up @@ -433,6 +437,7 @@ def arguments(self) -> RequestType:
tags=self.tags,
container_def_list=self.container_def_list,
customer_metadata_properties=self.customer_metadata_properties,
domain=self.domain,
)

request_dict = get_create_model_package_request(**model_package_args)
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/workflow/step_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def __init__(
model: Union[Model, PipelineModel] = None,
drift_check_baselines=None,
customer_metadata_properties=None,
domain=None,
**kwargs,
):
"""Construct steps `_RepackModelStep` and `_RegisterModelStep` based on the estimator.
Expand Down Expand Up @@ -122,6 +123,8 @@ def __init__(
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
metadata properties (default: None).
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
"MACHINE_LEARNING" (default: None).

**kwargs: additional arguments to `create_model`.
"""
Expand Down Expand Up @@ -241,6 +244,7 @@ def __init__(
container_def_list=self.container_def_list,
retry_policies=register_model_step_retry_policies,
customer_metadata_properties=customer_metadata_properties,
domain=domain,
**kwargs,
)
if not repack_model:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,7 @@ def test_model_registration_with_drift_check_baselines(
),
)
customer_metadata_properties = {"key1": "value1"}
domain = "COMPUTER_VISION"
estimator = XGBoost(
entry_point="training.py",
source_dir=os.path.join(DATA_DIR, "sip"),
Expand All @@ -572,6 +573,7 @@ def test_model_registration_with_drift_check_baselines(
model_metrics=model_metrics,
drift_check_baselines=drift_check_baselines,
customer_metadata_properties=customer_metadata_properties,
domain=domain,
)

pipeline = Pipeline(
Expand Down Expand Up @@ -643,6 +645,7 @@ def test_model_registration_with_drift_check_baselines(
== "application/json"
)
assert response["CustomerMetadataProperties"] == customer_metadata_properties
assert response["Domain"] == domain
break
finally:
try:
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2386,6 +2386,7 @@ def test_create_model_package_from_containers_all_args(sagemaker_session):
approval_status = ("Approved",)
description = "description"
customer_metadata_properties = {"key1": "value1"}
domain = "COMPUTER_VISION"
sagemaker_session.create_model_package_from_containers(
containers=containers,
content_types=content_types,
Expand All @@ -2400,6 +2401,7 @@ def test_create_model_package_from_containers_all_args(sagemaker_session):
description=description,
drift_check_baselines=drift_check_baselines,
customer_metadata_properties=customer_metadata_properties,
domain=domain,
)
expected_args = {
"ModelPackageName": model_package_name,
Expand All @@ -2417,6 +2419,7 @@ def test_create_model_package_from_containers_all_args(sagemaker_session):
"ModelApprovalStatus": approval_status,
"DriftCheckBaselines": drift_check_baselines,
"CustomerMetadataProperties": customer_metadata_properties,
"Domain": domain,
}
sagemaker_session.sagemaker_client.create_model_package.assert_called_with(**expected_args)

Expand Down