Skip to content

Commit b7a2b9c

Browse files
authored
fix: model.transformer() passes tags to create_model() (#976)
* fix: model.transformer() passes tags to create_model() * ignore kwargs in ModelPackage.create_sagemaker_model
1 parent 6502810 commit b7a2b9c

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

src/sagemaker/model.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,7 @@ def transformer(
498498
volume_kms_key (str): Optional. KMS key ID for encrypting the volume
499499
attached to the ML compute instance (default: None).
500500
"""
501-
self._create_sagemaker_model(instance_type)
501+
self._create_sagemaker_model(instance_type, tags=tags)
502502
if self.enable_network_isolation():
503503
env = None
504504

@@ -895,11 +895,14 @@ def _is_marketplace(self):
895895
return True
896896
return False
897897

898-
def _create_sagemaker_model(self, *args): # pylint: disable=unused-argument
898+
def _create_sagemaker_model(self, *args, **kwargs): # pylint: disable=unused-argument
899899
"""Create a SageMaker Model Entity
900900
901901
Args:
902-
*args: Arguments coming from the caller. This class does not require
902+
args: Positional arguments coming from the caller. This class does not require
903+
any so they are ignored.
904+
905+
kwargs: Keyword arguments coming from the caller. This class does not require
903906
any so they are ignored.
904907
"""
905908
if self.algorithm_arn:

tests/unit/test_model.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,16 +361,20 @@ def test_model_create_transformer(sagemaker_session):
361361
return_value=DESCRIBE_MODEL_PACKAGE_RESPONSE
362362
)
363363

364+
tags = [{"Key": "k", "Value": "v"}]
364365
model = DummyFrameworkModel(sagemaker_session=sagemaker_session)
366+
instance_type = "ml.m4.xlarge"
365367
model.name = "auto-generated-model"
366368
transformer = model.transformer(
367-
instance_count=1, instance_type="ml.m4.xlarge", env={"test": True}
369+
instance_count=1, instance_type=instance_type, env={"test": True}, tags=tags
368370
)
369371
assert isinstance(transformer, sagemaker.transformer.Transformer)
370372
assert transformer.model_name == "auto-generated-model"
371373
assert transformer.instance_type == "ml.m4.xlarge"
372374
assert transformer.env == {"test": True}
373375

376+
sagemaker.model.Model._create_sagemaker_model.assert_called_with(instance_type, tags=tags)
377+
374378

375379
def test_model_package_enable_network_isolation_with_no_product_id(sagemaker_session):
376380
sagemaker_session.sagemaker_client.describe_model_package = Mock(

0 commit comments

Comments
 (0)