Skip to content

Commit 3662c8d

Browse files
authored
fix: tags for jumpstart model package models (#4061)
1 parent d83c7ce commit 3662c8d

File tree

5 files changed

+62
-12
lines changed

5 files changed

+62
-12
lines changed

src/sagemaker/jumpstart/model.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -310,13 +310,34 @@ def _is_valid_model_id_hook():
310310

311311
super(JumpStartModel, self).__init__(**model_init_kwargs.to_kwargs_dict())
312312

313-
def _create_sagemaker_model(self, *args, **kwargs): # pylint: disable=unused-argument
313+
def _create_sagemaker_model(
314+
self,
315+
instance_type=None,
316+
accelerator_type=None,
317+
tags=None,
318+
serverless_inference_config=None,
319+
**kwargs,
320+
):
314321
"""Create a SageMaker Model Entity
315322
316323
Args:
317-
args: Positional arguments coming from the caller. This class does not require
318-
any so they are ignored.
319-
324+
instance_type (str): Optional. The EC2 instance type that this Model will be
325+
used for, this is only used to determine if the image needs GPU
326+
support or not. (Default: None).
327+
accelerator_type (str): Optional. Type of Elastic Inference accelerator to
328+
attach to an endpoint for model loading and inference, for
329+
example, 'ml.eia1.medium'. If not specified, no Elastic
330+
Inference accelerator will be attached to the endpoint. (Default: None).
331+
tags (List[dict[str, str]]): Optional. The list of tags to add to
332+
the model. Example: >>> tags = [{'Key': 'tagname', 'Value':
333+
'tagvalue'}] For more information about tags, see
334+
https://boto3.amazonaws.com/v1/documentation
335+
/api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags
336+
(Default: None).
337+
serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
338+
Optional. Specifies configuration related to serverless endpoint. Instance type is
339+
not provided in serverless inference. So this is used to find image URIs.
340+
(Default: None).
320341
kwargs: Additional keyword arguments coming from the caller. They are passed through
321342
unchanged when this method delegates to the parent class's ``_create_sagemaker_model``.
322343
"""
@@ -347,10 +368,16 @@ def _create_sagemaker_model(self, *args, **kwargs): # pylint: disable=unused-ar
347368
container_def,
348369
vpc_config=self.vpc_config,
349370
enable_network_isolation=self.enable_network_isolation(),
350-
tags=kwargs.get("tags"),
371+
tags=tags,
351372
)
352373
else:
353-
super(JumpStartModel, self)._create_sagemaker_model(*args, **kwargs)
374+
super(JumpStartModel, self)._create_sagemaker_model(
375+
instance_type=instance_type,
376+
accelerator_type=accelerator_type,
377+
tags=tags,
378+
serverless_inference_config=serverless_inference_config,
379+
**kwargs,
380+
)
354381

355382
def deploy(
356383
self,

src/sagemaker/model.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1375,7 +1375,10 @@ def deploy(
13751375
self._base_name = "-".join((self._base_name, compiled_model_suffix))
13761376

13771377
self._create_sagemaker_model(
1378-
instance_type, accelerator_type, tags, serverless_inference_config
1378+
instance_type=instance_type,
1379+
accelerator_type=accelerator_type,
1380+
tags=tags,
1381+
serverless_inference_config=serverless_inference_config,
13791382
)
13801383

13811384
serverless_inference_config_dict = (

tests/unit/sagemaker/jumpstart/model/test_model.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -589,7 +589,10 @@ def test_jumpstart_model_package_arn(
589589

590590
model = JumpStartModel(model_id=model_id)
591591

592-
model.deploy()
592+
tag = {"Key": "foo", "Value": "bar"}
593+
tags = [tag]
594+
595+
model.deploy(tags=tags)
593596

594597
self.assertEqual(
595598
mock_session.return_value.create_model.call_args[0][2],
@@ -599,6 +602,8 @@ def test_jumpstart_model_package_arn(
599602
},
600603
)
601604

605+
self.assertIn(tag, mock_session.return_value.create_model.call_args[1]["tags"])
606+
602607
@mock.patch("sagemaker.jumpstart.model.is_valid_model_id")
603608
@mock.patch("sagemaker.jumpstart.factory.model.Session")
604609
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")

tests/unit/sagemaker/model/test_deploy.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,12 @@ def test_deploy_accelerator_type(
159159
accelerator_type=ACCELERATOR_TYPE,
160160
)
161161

162-
create_sagemaker_model.assert_called_with(INSTANCE_TYPE, ACCELERATOR_TYPE, None, None)
162+
create_sagemaker_model.assert_called_with(
163+
instance_type=INSTANCE_TYPE,
164+
accelerator_type=ACCELERATOR_TYPE,
165+
tags=None,
166+
serverless_inference_config=None,
167+
)
163168
production_variant.assert_called_with(
164169
MODEL_NAME,
165170
INSTANCE_TYPE,
@@ -271,7 +276,12 @@ def test_deploy_tags(create_sagemaker_model, production_variant, name_from_base,
271276
tags = [{"Key": "ModelName", "Value": "TestModel"}]
272277
model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT, tags=tags)
273278

274-
create_sagemaker_model.assert_called_with(INSTANCE_TYPE, None, tags, None)
279+
create_sagemaker_model.assert_called_with(
280+
instance_type=INSTANCE_TYPE,
281+
accelerator_type=None,
282+
tags=tags,
283+
serverless_inference_config=None,
284+
)
275285
sagemaker_session.endpoint_from_production_variants.assert_called_with(
276286
name=ENDPOINT_NAME,
277287
production_variants=[BASE_PRODUCTION_VARIANT],
@@ -463,7 +473,12 @@ def test_deploy_serverless_inference(production_variant, create_sagemaker_model,
463473
serverless_inference_config=serverless_inference_config,
464474
)
465475

466-
create_sagemaker_model.assert_called_with(None, None, None, serverless_inference_config)
476+
create_sagemaker_model.assert_called_with(
477+
instance_type=None,
478+
accelerator_type=None,
479+
tags=None,
480+
serverless_inference_config=serverless_inference_config,
481+
)
467482
production_variant.assert_called_with(
468483
MODEL_NAME,
469484
None,

tests/unit/sagemaker/model/test_model_package.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def test_create_sagemaker_model_include_tags(sagemaker_session):
197197
sagemaker_session=sagemaker_session,
198198
)
199199

200-
model_package._create_sagemaker_model(tags=tags)
200+
model_package.deploy(tags=tags, instance_type="ml.p2.xlarge", initial_instance_count=1)
201201

202202
sagemaker_session.create_model.assert_called_with(
203203
model_name,

0 commit comments

Comments
 (0)