Skip to content

Commit 0c31aa7

Browse files
author
Basil Beirouti
committed
add validation specification
1 parent 4a94ace commit 0c31aa7

File tree

2 files changed

+10
-0
lines changed

2 files changed

+10
-0
lines changed

src/sagemaker/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,7 @@ def register(
305305
description=None,
306306
drift_check_baselines=None,
307307
customer_metadata_properties=None,
308+
validation_specification=None
308309
):
309310
"""Creates a model package for creating SageMaker models or listing on Marketplace.
310311
@@ -360,6 +361,7 @@ def register(
360361
container_def_list=[container_def],
361362
drift_check_baselines=drift_check_baselines,
362363
customer_metadata_properties=customer_metadata_properties,
364+
validation_specification=validation_specification
363365
)
364366
model_package = self.sagemaker_session.create_model_package_from_containers(
365367
**model_pkg_args

src/sagemaker/session.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2787,6 +2787,7 @@ def create_model_package_from_containers(
27872787
description=None,
27882788
drift_check_baselines=None,
27892789
customer_metadata_properties=None,
2790+
validation_specification=None
27902791
):
27912792
"""Get request dictionary for CreateModelPackage API.
27922793
@@ -2832,6 +2833,7 @@ def create_model_package_from_containers(
28322833
description,
28332834
drift_check_baselines=drift_check_baselines,
28342835
customer_metadata_properties=customer_metadata_properties,
2836+
validation_specification=validation_specification
28352837
)
28362838
if model_package_group_name is not None:
28372839
try:
@@ -4180,6 +4182,7 @@ def get_model_package_args(
41804182
container_def_list=None,
41814183
drift_check_baselines=None,
41824184
customer_metadata_properties=None,
4185+
validation_specification=None
41834186
):
41844187
"""Get arguments for create_model_package method.
41854188
@@ -4249,6 +4252,8 @@ def get_model_package_args(
42494252
model_package_args["tags"] = tags
42504253
if customer_metadata_properties is not None:
42514254
model_package_args["customer_metadata_properties"] = customer_metadata_properties
4255+
if validation_specification is not None:
4256+
model_package_args["validation_specification"] = validation_specification
42524257
return model_package_args
42534258

42544259

@@ -4268,6 +4273,7 @@ def get_create_model_package_request(
42684273
tags=None,
42694274
drift_check_baselines=None,
42704275
customer_metadata_properties=None,
4276+
validation_specification=None
42714277
):
42724278
"""Get request dictionary for CreateModelPackage API.
42734279
@@ -4319,6 +4325,8 @@ def get_create_model_package_request(
43194325
request_dict["MetadataProperties"] = metadata_properties
43204326
if customer_metadata_properties is not None:
43214327
request_dict["CustomerMetadataProperties"] = customer_metadata_properties
4328+
if validation_specification:
4329+
request_dict["ValidationSpecification"] = validation_specification
43224330
if containers is not None:
43234331
if not all([content_types, response_types, inference_instances, transform_instances]):
43244332
raise ValueError(

0 commit comments

Comments
 (0)