Skip to content

Commit 5301b31

Browse files
committed
feature: adding 'Domain' property to RegisterModel step
1 parent e4ede31 commit 5301b31

File tree

10 files changed

+47
-1
lines changed

10 files changed

+47
-1
lines changed

src/sagemaker/estimator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1280,6 +1280,7 @@ def register(
12801280
model_name=None,
12811281
drift_check_baselines=None,
12821282
customer_metadata_properties=None,
1283+
domain=None,
12831284
**kwargs,
12841285
):
12851286
"""Creates a model package for creating SageMaker models or listing on Marketplace.
@@ -1311,6 +1312,8 @@ def register(
13111312
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
13121313
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
13131314
metadata properties (default: None).
1315+
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
1316+
"MACHINE_LEARNING" (default: None).
13141317
**kwargs: Passed to invocation of ``create_model()``. Implementations may customize
13151318
``create_model()`` to accept ``**kwargs`` to customize model creation during
13161319
deploy. For more, see the implementation docs.
@@ -1342,6 +1345,7 @@ def register(
13421345
description,
13431346
drift_check_baselines=drift_check_baselines,
13441347
customer_metadata_properties=customer_metadata_properties,
1348+
domain=domain
13451349
)
13461350

13471351
@property

src/sagemaker/model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,7 @@ def register(
309309
drift_check_baselines=None,
310310
customer_metadata_properties=None,
311311
validation_specification=None,
312+
domain=None
312313
):
313314
"""Creates a model package for creating SageMaker models or listing on Marketplace.
314315
@@ -336,6 +337,8 @@ def register(
336337
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
337338
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
338339
metadata properties (default: None).
340+
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
341+
"MACHINE_LEARNING" (default: None).
339342
340343
Returns:
341344
A `sagemaker.model.ModelPackage` instance.
@@ -365,6 +368,7 @@ def register(
365368
drift_check_baselines=drift_check_baselines,
366369
customer_metadata_properties=customer_metadata_properties,
367370
validation_specification=validation_specification,
371+
domain=domain
368372
)
369373
model_package = self.sagemaker_session.create_model_package_from_containers(
370374
**model_pkg_args

src/sagemaker/mxnet/model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ def register(
158158
description=None,
159159
drift_check_baselines=None,
160160
customer_metadata_properties=None,
161+
domain=None
161162
):
162163
"""Creates a model package for creating SageMaker models or listing on Marketplace.
163164
@@ -185,6 +186,8 @@ def register(
185186
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
186187
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
187188
metadata properties (default: None).
189+
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
190+
"MACHINE_LEARNING" (default: None).
188191
189192
Returns:
190193
A `sagemaker.model.ModelPackage` instance.
@@ -214,6 +217,7 @@ def register(
214217
description,
215218
drift_check_baselines=drift_check_baselines,
216219
customer_metadata_properties=customer_metadata_properties,
220+
domain=domain
217221
)
218222

219223
def prepare_container_def(

src/sagemaker/pytorch/model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ def register(
159159
description=None,
160160
drift_check_baselines=None,
161161
customer_metadata_properties=None,
162+
domain=None
162163
):
163164
"""Creates a model package for creating SageMaker models or listing on Marketplace.
164165
@@ -186,6 +187,8 @@ def register(
186187
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
187188
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
188189
metadata properties (default: None).
190+
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
191+
"MACHINE_LEARNING" (default: None).
189192
190193
Returns:
191194
A `sagemaker.model.ModelPackage` instance.
@@ -215,6 +218,7 @@ def register(
215218
description,
216219
drift_check_baselines=drift_check_baselines,
217220
customer_metadata_properties=customer_metadata_properties,
221+
domain=domain
218222
)
219223

220224
def prepare_container_def(

src/sagemaker/session.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2803,6 +2803,7 @@ def create_model_package_from_containers(
28032803
drift_check_baselines=None,
28042804
customer_metadata_properties=None,
28052805
validation_specification=None,
2806+
domain=None
28062807
):
28072808
"""Get request dictionary for CreateModelPackage API.
28082809
@@ -2830,6 +2831,8 @@ def create_model_package_from_containers(
28302831
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
28312832
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
28322833
metadata properties (default: None).
2834+
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
2835+
"MACHINE_LEARNING" (default: None).
28332836
"""
28342837

28352838
model_pkg_request = get_create_model_package_request(
@@ -2848,6 +2851,7 @@ def create_model_package_from_containers(
28482851
drift_check_baselines=drift_check_baselines,
28492852
customer_metadata_properties=customer_metadata_properties,
28502853
validation_specification=validation_specification,
2854+
domain=domain
28512855
)
28522856

28532857
def submit(request):
@@ -4218,6 +4222,7 @@ def get_model_package_args(
42184222
drift_check_baselines=None,
42194223
customer_metadata_properties=None,
42204224
validation_specification=None,
4225+
domain=None
42214226
):
42224227
"""Get arguments for create_model_package method.
42234228
@@ -4248,6 +4253,8 @@ def get_model_package_args(
42484253
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
42494254
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
42504255
metadata properties (default: None).
4256+
domain (str): Domain values can be "COMPUTER_VISION, NATURAL_LANGUAGE_PROCESSING,
4257+
MACHINE_LEARNING" (default: None).
42514258
Returns:
42524259
dict: A dictionary of method argument names and values.
42534260
"""
@@ -4289,6 +4296,8 @@ def get_model_package_args(
42894296
model_package_args["customer_metadata_properties"] = customer_metadata_properties
42904297
if validation_specification is not None:
42914298
model_package_args["validation_specification"] = validation_specification
4299+
if domain is not None:
4300+
model_package_args["domain"] = domain
42924301
return model_package_args
42934302

42944303

@@ -4309,6 +4318,7 @@ def get_create_model_package_request(
43094318
drift_check_baselines=None,
43104319
customer_metadata_properties=None,
43114320
validation_specification=None,
4321+
domain=None
43124322
):
43134323
"""Get request dictionary for CreateModelPackage API.
43144324
@@ -4362,6 +4372,8 @@ def get_create_model_package_request(
43624372
request_dict["CustomerMetadataProperties"] = customer_metadata_properties
43634373
if validation_specification:
43644374
request_dict["ValidationSpecification"] = validation_specification
4375+
if domain is not None:
4376+
request_dict["Domain"] = domain
43654377
if containers is not None:
43664378
if not all([content_types, response_types, inference_instances, transform_instances]):
43674379
raise ValueError(

src/sagemaker/tensorflow/model.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ def register(
205205
description=None,
206206
drift_check_baselines=None,
207207
customer_metadata_properties=None,
208+
domain=None
208209
):
209210
"""Creates a model package for creating SageMaker models or listing on Marketplace.
210211
@@ -232,7 +233,8 @@ def register(
232233
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
233234
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
234235
metadata properties (default: None).
235-
236+
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
237+
"MACHINE_LEARNING" (default: None).
236238
237239
Returns:
238240
A `sagemaker.model.ModelPackage` instance.
@@ -262,6 +264,7 @@ def register(
262264
description,
263265
drift_check_baselines=drift_check_baselines,
264266
customer_metadata_properties=customer_metadata_properties,
267+
domain=domain
265268
)
266269

267270
def deploy(

src/sagemaker/workflow/_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,7 @@ def __init__(
280280
container_def_list=None,
281281
drift_check_baselines=None,
282282
customer_metadata_properties=None,
283+
domain=None,
283284
**kwargs,
284285
):
285286
"""Constructor of a register model step.
@@ -321,6 +322,8 @@ def __init__(
321322
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
322323
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
323324
metadata properties (default: None).
325+
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
326+
"MACHINE_LEARNING" (default: None).
324327
**kwargs: additional arguments to `create_model`.
325328
"""
326329
super(_RegisterModelStep, self).__init__(
@@ -351,6 +354,7 @@ def __init__(
351354
self.model_metrics = model_metrics
352355
self.drift_check_baselines = drift_check_baselines
353356
self.customer_metadata_properties = customer_metadata_properties
357+
self.domain = domain
354358
self.metadata_properties = metadata_properties
355359
self.approval_status = approval_status
356360
self.image_uri = image_uri
@@ -428,6 +432,7 @@ def arguments(self) -> RequestType:
428432
tags=self.tags,
429433
container_def_list=self.container_def_list,
430434
customer_metadata_properties=self.customer_metadata_properties,
435+
domain=self.domain
431436
)
432437

433438
request_dict = get_create_model_package_request(**model_package_args)

src/sagemaker/workflow/step_collections.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def __init__(
7878
model: Union[Model, PipelineModel] = None,
7979
drift_check_baselines=None,
8080
customer_metadata_properties=None,
81+
domain=None,
8182
**kwargs,
8283
):
8384
"""Construct steps `_RepackModelStep` and `_RegisterModelStep` based on the estimator.
@@ -118,6 +119,8 @@ def __init__(
118119
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
119120
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
120121
metadata properties (default: None).
122+
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
123+
"MACHINE_LEARNING" (default: None).
121124
122125
**kwargs: additional arguments to `create_model`.
123126
"""
@@ -236,6 +239,7 @@ def __init__(
236239
container_def_list=self.container_def_list,
237240
retry_policies=register_model_step_retry_policies,
238241
customer_metadata_properties=customer_metadata_properties,
242+
domain=domain,
239243
**kwargs,
240244
)
241245
if not repack_model:

tests/integ/sagemaker/workflow/test_model_create_and_registration.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,7 @@ def test_model_registration_with_drift_check_baselines(
550550
),
551551
)
552552
customer_metadata_properties = {"key1": "value1"}
553+
domain = "COMPUTER_VISION"
553554
estimator = XGBoost(
554555
entry_point="training.py",
555556
source_dir=os.path.join(DATA_DIR, "sip"),
@@ -572,6 +573,7 @@ def test_model_registration_with_drift_check_baselines(
572573
model_metrics=model_metrics,
573574
drift_check_baselines=drift_check_baselines,
574575
customer_metadata_properties=customer_metadata_properties,
576+
domain=domain
575577
)
576578

577579
pipeline = Pipeline(
@@ -643,6 +645,7 @@ def test_model_registration_with_drift_check_baselines(
643645
== "application/json"
644646
)
645647
assert response["CustomerMetadataProperties"] == customer_metadata_properties
648+
assert response["Domain"] == domain
646649
break
647650
finally:
648651
try:

tests/unit/test_session.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2386,6 +2386,7 @@ def test_create_model_package_from_containers_all_args(sagemaker_session):
23862386
approval_status = ("Approved",)
23872387
description = "description"
23882388
customer_metadata_properties = {"key1": "value1"}
2389+
domain = "COMPUTER_VISION"
23892390
sagemaker_session.create_model_package_from_containers(
23902391
containers=containers,
23912392
content_types=content_types,
@@ -2400,6 +2401,7 @@ def test_create_model_package_from_containers_all_args(sagemaker_session):
24002401
description=description,
24012402
drift_check_baselines=drift_check_baselines,
24022403
customer_metadata_properties=customer_metadata_properties,
2404+
domain=domain
24032405
)
24042406
expected_args = {
24052407
"ModelPackageName": model_package_name,
@@ -2417,6 +2419,7 @@ def test_create_model_package_from_containers_all_args(sagemaker_session):
24172419
"ModelApprovalStatus": approval_status,
24182420
"DriftCheckBaselines": drift_check_baselines,
24192421
"CustomerMetadataProperties": customer_metadata_properties,
2422+
"Domain": domain
24202423
}
24212424
sagemaker_session.sagemaker_client.create_model_package.assert_called_with(**expected_args)
24222425

0 commit comments

Comments
 (0)