Skip to content

Commit 07dd11a

Browse files
chuyang-dengDan Choi
authored andcommitted
fix: add MetadataProperties to create_model_package request (aws#533)
1 parent ffab145 commit 07dd11a

File tree

9 files changed

+85
-1
lines changed

9 files changed

+85
-1
lines changed

src/sagemaker/estimator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -813,6 +813,7 @@ def register(
813813
model_package_name=None,
814814
model_package_group_name=None,
815815
model_metrics=None,
816+
metadata_properties=None,
816817
marketplace_cert=False,
817818
approval_status=None,
818819
description=None,
@@ -837,6 +838,7 @@ def register(
837838
`model_package_name`, using `model_package_group_name` makes the Model Package
838839
versioned (default: None).
839840
model_metrics (ModelMetrics): ModelMetrics object (default: None).
841+
metadata_properties (MetadataProperties): MetadataProperties (default: None).
840842
marketplace_cert (bool): A boolean value indicating if the Model Package is certified
841843
for AWS Marketplace (default: False).
842844
approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
@@ -868,6 +870,7 @@ def register(
868870
model_package_group_name,
869871
image_uri,
870872
model_metrics,
873+
metadata_properties,
871874
marketplace_cert,
872875
approval_status,
873876
description,

src/sagemaker/metadata_properties.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This file contains code related to metadata properties."""
14+
from __future__ import absolute_import
15+
16+
17+
class MetadataProperties(object):
18+
"""Accepts metadata properties parameters for conversion to request dict."""
19+
20+
def __init__(
21+
self,
22+
commit_id=None,
23+
repository=None,
24+
generated_by=None,
25+
project_id=None,
26+
):
27+
"""Initialize a ``MetadataProperties`` instance and turn parameters into dict.
28+
29+
# TODO: flesh out docstrings
30+
Args:
31+
commit_id (str):
32+
repository (str):
33+
generated_by (str):
34+
project_id (str):
35+
"""
36+
self.commit_id = commit_id
37+
self.repository = repository
38+
self.generated_by = generated_by
39+
self.project_id = project_id
40+
41+
def _to_request_dict(self):
42+
"""Generates a request dictionary using the parameters provided to the class."""
43+
metadata_properties_request = dict()
44+
if self.commit_id:
45+
metadata_properties_request["CommitId"] = self.commit_id
46+
if self.repository:
47+
metadata_properties_request["Repository"] = self.repository
48+
if self.generated_by:
49+
metadata_properties_request["GeneratedBy"] = self.generated_by
50+
if self.project_id:
51+
metadata_properties_request["ProjectId"] = self.project_id
52+
return metadata_properties_request

src/sagemaker/model.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ def register(
114114
model_package_group_name=None,
115115
image_uri=None,
116116
model_metrics=None,
117+
metadata_properties=None,
117118
marketplace_cert=False,
118119
approval_status=None,
119120
description=None,
@@ -135,6 +136,7 @@ def register(
135136
image_uri (str): Inference image uri for the container. Model class' self.image will
136137
be used if it is None (default: None).
137138
model_metrics (ModelMetrics): ModelMetrics object (default: None).
139+
metadata_properties (MetadataProperties): MetadataProperties object (default: None).
138140
marketplace_cert (bool): A boolean value indicating if the Model Package is certified
139141
for AWS Marketplace (default: False).
140142
approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
@@ -156,6 +158,7 @@ def register(
156158
model_package_group_name,
157159
image_uri,
158160
model_metrics,
161+
metadata_properties,
159162
marketplace_cert,
160163
approval_status,
161164
description,
@@ -179,6 +182,7 @@ def _get_model_package_args(
179182
model_package_group_name=None,
180183
image_uri=None,
181184
model_metrics=None,
185+
metadata_properties=None,
182186
marketplace_cert=False,
183187
approval_status=None,
184188
description=None,
@@ -199,6 +203,7 @@ def _get_model_package_args(
199203
image_uri (str): Inference image uri for the container. Model class' self.image will
200204
be used if it is None (default: None).
201205
model_metrics (ModelMetrics): ModelMetrics object (default: None).
206+
metadata_properties (MetadataProperties): MetadataProperties object (default: None).
202207
marketplace_cert (bool): A boolean value indicating if the Model Package is certified
203208
for AWS Marketplace (default: False).
204209
approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
@@ -229,6 +234,8 @@ def _get_model_package_args(
229234
model_package_args["model_package_group_name"] = model_package_group_name
230235
if model_metrics is not None:
231236
model_package_args["model_metrics"] = model_metrics._to_request_dict()
237+
if metadata_properties is not None:
238+
model_package_args["metadata_properties"] = metadata_properties._to_request_dict()
232239
if approval_status is not None:
233240
model_package_args["approval_status"] = approval_status
234241
if description is not None:

src/sagemaker/model_metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License"). You
44
# may not use this file except in compliance with the License. A copy of

src/sagemaker/mxnet/model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ def register(
140140
model_package_group_name=None,
141141
image_uri=None,
142142
model_metrics=None,
143+
metadata_properties=None,
143144
marketplace_cert=False,
144145
approval_status=None,
145146
description=None,
@@ -161,6 +162,7 @@ def register(
161162
image_uri (str): Inference image uri for the container. Model class' self.image will
162163
be used if it is None (default: None).
163164
model_metrics (ModelMetrics): ModelMetrics object (default: None).
165+
metadata_properties (MetadataProperties): MetadataProperties (default: None).
164166
marketplace_cert (bool): A boolean value indicating if the Model Package is certified
165167
for AWS Marketplace (default: False).
166168
approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
@@ -189,6 +191,7 @@ def register(
189191
model_package_group_name,
190192
image_uri,
191193
model_metrics,
194+
metadata_properties,
192195
marketplace_cert,
193196
approval_status,
194197
description,

src/sagemaker/pytorch/model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ def register(
140140
model_package_group_name=None,
141141
image_uri=None,
142142
model_metrics=None,
143+
metadata_properties=None,
143144
marketplace_cert=False,
144145
approval_status=None,
145146
description=None,
@@ -161,6 +162,7 @@ def register(
161162
image_uri (str): Inference image uri for the container. Model class' self.image will
162163
be used if it is None (default: None).
163164
model_metrics (ModelMetrics): ModelMetrics object (default: None).
165+
metadata_properties (MetadataProperties): MetadataProperties object (default: None).
164166
marketplace_cert (bool): A boolean value indicating if the Model Package is certified
165167
for AWS Marketplace (default: False).
166168
approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
@@ -189,6 +191,7 @@ def register(
189191
model_package_group_name,
190192
image_uri,
191193
model_metrics,
194+
metadata_properties,
192195
marketplace_cert,
193196
approval_status,
194197
description,

src/sagemaker/tensorflow/model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ def register(
212212
model_package_group_name=None,
213213
image_uri=None,
214214
model_metrics=None,
215+
metadata_properties=None,
215216
marketplace_cert=False,
216217
approval_status=None,
217218
description=None,
@@ -233,6 +234,7 @@ def register(
233234
image_uri (str): Inference image uri for the container. Model class' self.image will
234235
be used if it is None (default: None).
235236
model_metrics (ModelMetrics): ModelMetrics object (default: None).
237+
metadata_properties (MetadataProperties): MetadataProperties object (default: None).
236238
marketplace_cert (bool): A boolean value indicating if the Model Package is certified
237239
for AWS Marketplace (default: False).
238240
approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
@@ -261,6 +263,7 @@ def register(
261263
model_package_group_name,
262264
image_uri,
263265
model_metrics,
266+
metadata_properties,
264267
marketplace_cert,
265268
approval_status,
266269
description,

src/sagemaker/workflow/_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ def __init__(
216216
transform_instances,
217217
model_package_group_name=None,
218218
model_metrics=None,
219+
metadata_properties=None,
219220
approval_status="PendingManualApproval",
220221
image_uri=None,
221222
compile_model_family=None,
@@ -238,6 +239,7 @@ def __init__(
238239
`model_package_name`, using `model_package_group_name` makes the Model Package
239240
versioned (default: None).
240241
model_metrics (ModelMetrics): ModelMetrics object (default: None).
242+
metadata_properties (MetadataProperties): MetadataProperties object (default: None).
241243
approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
242244
or "PendingManualApproval" (default: "PendingManualApproval").
243245
image_uri (str): The container image uri for Model Package, if not specified,
@@ -255,6 +257,7 @@ def __init__(
255257
self.transform_instances = transform_instances
256258
self.model_package_group_name = model_package_group_name
257259
self.model_metrics = model_metrics
260+
self.metadata_properties = metadata_properties
258261
self.approval_status = approval_status
259262
self.image_uri = image_uri
260263
self.compile_model_family = compile_model_family
@@ -309,6 +312,7 @@ def arguments(self) -> RequestType:
309312
transform_instances=self.transform_instances,
310313
model_package_group_name=self.model_package_group_name,
311314
model_metrics=self.model_metrics,
315+
metadata_properties=self.metadata_properties,
312316
approval_status=self.approval_status,
313317
)
314318
request_dict = model.sagemaker_session._get_create_model_package_request(

tests/unit/test_mxnet.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from pkg_resources import parse_version
2323

2424
from sagemaker.fw_utils import UploadedCode
25+
from sagemaker.metadata_properties import MetadataProperties
2526
from sagemaker.model_metrics import MetricsSource, ModelMetrics
2627
from sagemaker.mxnet import defaults
2728
from sagemaker.mxnet import MXNet
@@ -544,13 +545,20 @@ def test_model_register_all_args(
544545
bias=dummy_metrics_source,
545546
explainability=dummy_metrics_source,
546547
)
548+
metadata_properties = MetadataProperties(
549+
commit_id="test-commit-id",
550+
repository="test-repository",
551+
generated_by="sagemaker-python-sdk-test",
552+
project_id="test-project-id",
553+
)
547554
model.register(
548555
content_types,
549556
response_types,
550557
inference_instances,
551558
transform_instances,
552559
model_package_name=model_package_name,
553560
model_metrics=model_metrics,
561+
metadata_properties=metadata_properties,
554562
marketplace_cert=True,
555563
approval_status="Approved",
556564
description="description",
@@ -563,6 +571,7 @@ def test_model_register_all_args(
563571
"transform_instances": transform_instances,
564572
"model_package_name": model_package_name,
565573
"model_metrics": model_metrics._to_request_dict(),
574+
"metadata_properties": metadata_properties._to_request_dict(),
566575
"marketplace_cert": True,
567576
"approval_status": "Approved",
568577
"description": "description",

0 commit comments

Comments
 (0)