Skip to content

Commit e735e3f

Browse files
committed
add doc string and fix unit tests
1 parent 5d97eb6 commit e735e3f

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

src/sagemaker/inference_recommender/inference_recommender_mixin.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -616,6 +616,7 @@ def _search_recommendation(self, recommendation_list, inference_recommendation_i
616616
)
617617

618618
def _add_client_type_tag(self, tags, client_type):
619+
"""Tagging for Inference Recommender and Deployment Recommendations"""
619620
client_type_tag = {"Key": "ClientType", "Value": client_type}
620621
tags = tags.append(client_type_tag) if tags else [client_type_tag]
621622
return tags

src/sagemaker/model.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1738,10 +1738,10 @@ def _create_sagemaker_model(self, *args, **kwargs): # pylint: disable=unused-ar
17381738
17391739
Args:
17401740
args: Positional arguments coming from the caller. This class does not require
1741-
any but will specifically look for Tags (3rd arg positionally) if specified
1741+
any but will look for tags in the 3rd parameter.
17421742
17431743
kwargs: Keyword arguments coming from the caller. This class does not require
1744-
any so they are ignored.
1744+
any but will search for tags if not in args.
17451745
"""
17461746
if self.algorithm_arn:
17471747
# When ModelPackage is created using an algorithm_arn we need to first
@@ -1763,13 +1763,17 @@ def _create_sagemaker_model(self, *args, **kwargs): # pylint: disable=unused-ar
17631763
self._ensure_base_name_if_needed(model_package_name.split("/")[-1])
17641764
self._set_model_name_if_needed()
17651765

1766+
# If tags are in args, it must be the 3rd param
1767+
# If not, then check kwargs and set to either tags or None
1768+
tags = args[2] if len(args) >= 3 else kwargs.get('tags')
1769+
17661770
self.sagemaker_session.create_model(
17671771
self.name,
17681772
self.role,
17691773
container_def,
17701774
vpc_config=self.vpc_config,
17711775
enable_network_isolation=self.enable_network_isolation(),
1772-
tags=args[2],
1776+
tags=tags,
17731777
)
17741778

17751779
def _ensure_base_name_if_needed(self, base_name):

0 commit comments

Comments
 (0)