Skip to content

Commit 2b8727e

Browse files
authored
Merge branch 'master' into callback-param-bug-fix
2 parents b3bd0bb + d341d76 commit 2b8727e

File tree

6 files changed

+25
-0
lines changed

6 files changed

+25
-0
lines changed

doc/amazon_sagemaker_featurestore.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,13 @@ example identifier to retrieve the record.
291291
record_identifier_value = str(2990130)
292292
featurestore_runtime.get_record(FeatureGroupName=transaction_feature_group_name, RecordIdentifierValueAsString=record_identifier_value)
293293
294+
You can use the ``batch_get_record`` function to retrieve multiple records simultaneously from your feature store. The following example uses this API to retrieve a batch of records.
295+
296+
.. code:: python
297+
298+
record_identifier_values = ["573291", "109382", "828400", "124013"]
299+
featurestore_runtime.batch_get_record(Identifiers=[{"FeatureGroupName": transaction_feature_group_name, "RecordIdentifiersValueAsString": record_identifier_values}])
300+
294301
An example response from the fraud detection example:
295302
296303
.. code:: python

src/sagemaker/model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ def _get_model_package_args(
195195
marketplace_cert=False,
196196
approval_status=None,
197197
description=None,
198+
tags=None,
198199
):
199200
"""Get arguments for session.create_model_package method.
200201
@@ -250,6 +251,8 @@ def _get_model_package_args(
250251
model_package_args["approval_status"] = approval_status
251252
if description is not None:
252253
model_package_args["description"] = description
254+
if tags is not None:
255+
model_package_args["tags"] = tags
253256
return model_package_args
254257

255258
def _init_sagemaker_session_if_does_not_exist(self, instance_type):

src/sagemaker/session.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2724,6 +2724,7 @@ def _get_create_model_package_request(
27242724
marketplace_cert=False,
27252725
approval_status="PendingManualApproval",
27262726
description=None,
2727+
tags=None,
27272728
):
27282729
"""Get request dictionary for CreateModelPackage API.
27292730
@@ -2761,6 +2762,8 @@ def _get_create_model_package_request(
27612762
request_dict["ModelPackageGroupName"] = model_package_group_name
27622763
if description is not None:
27632764
request_dict["ModelPackageDescription"] = description
2765+
if tags is not None:
2766+
request_dict["Tags"] = tags
27642767
if model_metrics:
27652768
request_dict["ModelMetrics"] = model_metrics
27662769
if metadata_properties:

src/sagemaker/workflow/_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@ def __init__(
225225
compile_model_family=None,
226226
description=None,
227227
depends_on: List[str] = None,
228+
tags=None,
228229
**kwargs,
229230
):
230231
"""Constructor of a register model step.
@@ -264,6 +265,7 @@ def __init__(
264265
self.inference_instances = inference_instances
265266
self.transform_instances = transform_instances
266267
self.model_package_group_name = model_package_group_name
268+
self.tags = tags
267269
self.model_metrics = model_metrics
268270
self.metadata_properties = metadata_properties
269271
self.approval_status = approval_status
@@ -324,10 +326,12 @@ def arguments(self) -> RequestType:
324326
metadata_properties=self.metadata_properties,
325327
approval_status=self.approval_status,
326328
description=self.description,
329+
tags=self.tags,
327330
)
328331
request_dict = model.sagemaker_session._get_create_model_package_request(
329332
**model_package_args
330333
)
334+
331335
# these are not available in the workflow service and will cause rejection
332336
if "CertifyForMarketplace" in request_dict:
333337
request_dict.pop("CertifyForMarketplace")

src/sagemaker/workflow/step_collections.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def __init__(
6767
image_uri=None,
6868
compile_model_family=None,
6969
description=None,
70+
tags=None,
7071
**kwargs,
7172
):
7273
"""Construct steps `_RepackModelStep` and `_RegisterModelStep` based on the estimator.
@@ -94,6 +95,10 @@ def __init__(
9495
compile_model_family (str): The instance family for the compiled model. If
9596
specified, a compiled model is used (default: None).
9697
description (str): Model Package description (default: None).
98+
tags (List[dict[str, str]]): The list of tags to attach to the model package group. Note
99+
that tags will only be applied to newly created model package groups; if the
100+
name of an existing group is passed to "model_package_group_name",
101+
tags will not be applied.
97102
**kwargs: additional arguments to `create_model`.
98103
"""
99104
steps: List[Step] = []
@@ -134,6 +139,7 @@ def __init__(
134139
image_uri=image_uri,
135140
compile_model_family=compile_model_family,
136141
description=description,
142+
tags=tags,
137143
**kwargs,
138144
)
139145
if not repack_model:

tests/unit/sagemaker/workflow/test_step_collections.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ def test_register_model(estimator, model_metrics):
182182
approval_status="Approved",
183183
description="description",
184184
depends_on=["TestStep"],
185+
tags=[{"Key": "myKey", "Value": "myValue"}],
185186
)
186187
assert ordered(register_model.request_dicts()) == ordered(
187188
[
@@ -210,6 +211,7 @@ def test_register_model(estimator, model_metrics):
210211
},
211212
"ModelPackageDescription": "description",
212213
"ModelPackageGroupName": "mpg",
214+
"Tags": [{"Key": "myKey", "Value": "myValue"}],
213215
},
214216
},
215217
]

0 commit comments

Comments
 (0)