Commit 7f9782f

add tags for inf and dep recommendations
1 parent 90d5f19 commit 7f9782f

6 files changed (+280, -25 lines)


src/sagemaker/inference_recommender/inference_recommender_mixin.py

Lines changed: 20 additions & 4 deletions
@@ -30,6 +30,9 @@
 
 LOGGER = logging.getLogger("sagemaker")
 
+DEPLOYMENT_RECOMMENDATION_TAG = "PythonSDK-DeploymentRecommendation"
+
+RIGHT_SIZE_TAG = "PythonSDK-RightSize"
 
 
 class Phase:
     """Used to store phases of a traffic pattern to perform endpoint load testing.
@@ -218,6 +221,7 @@ def _update_params(
         explainer_config = kwargs["explainer_config"]
         inference_recommendation_id = kwargs["inference_recommendation_id"]
         inference_recommender_job_results = kwargs["inference_recommender_job_results"]
+        tags = kwargs["tags"]
         if inference_recommendation_id is not None:
             inference_recommendation = self._update_params_for_recommendation_id(
                 instance_type=instance_type,
@@ -237,7 +241,11 @@ def _update_params(
             async_inference_config,
             explainer_config,
         )
-        return inference_recommendation or (instance_type, initial_instance_count)
+
+        if inference_recommendation:
+            tags = self._add_client_type_tag(tags, inference_recommendation[2])
+            return (inference_recommendation[0], inference_recommendation[1], tags)
+        return (instance_type, initial_instance_count, tags)
 
     def _update_params_for_right_size(
         self,
@@ -301,7 +309,7 @@ def _update_params_for_right_size(
         initial_instance_count = self.inference_recommendations[0]["EndpointConfiguration"][
             "InitialInstanceCount"
         ]
-        return (instance_type, initial_instance_count)
+        return (instance_type, initial_instance_count, "PythonSDK-RightSize")
 
     def _update_params_for_recommendation_id(
         self,
@@ -401,7 +409,7 @@ def _update_params_for_recommendation_id(
                 raise ValueError("Must specify model recommendation id and instance count.")
             self.env.update(model_recommendation["Environment"])
             instance_type = model_recommendation["InstanceType"]
-            return (instance_type, initial_instance_count)
+            return (instance_type, initial_instance_count, DEPLOYMENT_RECOMMENDATION_TAG)
 
         # Update params based on default inference recommendation
         if bool(instance_type) != bool(initial_instance_count):
@@ -465,7 +473,7 @@ def _update_params_for_recommendation_id(
             "InitialInstanceCount"
         ]
 
-        return (instance_type, initial_instance_count)
+        return (instance_type, initial_instance_count, RIGHT_SIZE_TAG)
 
     def _convert_to_endpoint_configurations_json(
         self, hyperparameter_ranges: List[Dict[str, CategoricalParameter]]
@@ -605,3 +613,11 @@ def _search_recommendation(self, recommendation_list, inference_recommendation_id):
             ),
             None,
         )
+
+    def _add_client_type_tag(self, tags, client_type):
+        client_type_tag = {
+            "Key": "ClientType",
+            "Value": client_type
+        }
+        tags = tags.append(client_type_tag) if tags else [client_type_tag]
+        return tags
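
Note: the new _add_client_type_tag helper attaches a single ClientType tag to whatever tags the caller passed in. Below is a minimal standalone sketch of that intended behavior (assumed semantics, not the SDK implementation; it uses list concatenation instead of in-place append, and the tag values mirror the new module constants):

# Minimal sketch of the ClientType tagging behavior (assumption: the tag is
# simply appended to any user-supplied tag list).
DEPLOYMENT_RECOMMENDATION_TAG = "PythonSDK-DeploymentRecommendation"
RIGHT_SIZE_TAG = "PythonSDK-RightSize"


def add_client_type_tag(tags, client_type):
    """Return the caller's tags plus a ClientType tag, without mutating the input."""
    client_type_tag = {"Key": "ClientType", "Value": client_type}
    return (tags or []) + [client_type_tag]


print(add_client_type_tag(None, RIGHT_SIZE_TAG))
# -> [{'Key': 'ClientType', 'Value': 'PythonSDK-RightSize'}]

print(add_client_type_tag([{"Key": "team", "Value": "ml"}], DEPLOYMENT_RECOMMENDATION_TAG))
# -> [{'Key': 'team', 'Value': 'ml'},
#     {'Key': 'ClientType', 'Value': 'PythonSDK-DeploymentRecommendation'}]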

src/sagemaker/model.py

Lines changed: 4 additions & 2 deletions
@@ -1218,7 +1218,7 @@ def deploy(
             inference_recommendation_id is not None
             or self.inference_recommender_job_results is not None
         ):
-            instance_type, initial_instance_count = self._update_params(
+            instance_type, initial_instance_count, tags = self._update_params(
                 instance_type=instance_type,
                 initial_instance_count=initial_instance_count,
                 accelerator_type=accelerator_type,
@@ -1227,6 +1227,7 @@
                 explainer_config=explainer_config,
                 inference_recommendation_id=inference_recommendation_id,
                 inference_recommender_job_results=self.inference_recommender_job_results,
+                tags=tags,
             )
 
         is_async = async_inference_config is not None
@@ -1721,7 +1722,7 @@ def _create_sagemaker_model(self, *args, **kwargs):  # pylint: disable=unused-argument
 
         Args:
             args: Positional arguments coming from the caller. This class does not require
-                any so they are ignored.
+                any but will specifically look for Tags (3rd arg positionally) if specified
 
             kwargs: Keyword arguments coming from the caller. This class does not require
                 any so they are ignored.
@@ -1752,6 +1753,7 @@ def _create_sagemaker_model(self, *args, **kwargs):  # pylint: disable=unused-argument
             container_def,
             vpc_config=self.vpc_config,
             enable_network_isolation=self.enable_network_isolation(),
+            tags=args[2],
         )
 
     def _ensure_base_name_if_needed(self, base_name):
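
Context for the hunks above: deploy() now passes the (possibly tag-augmented) tags list down positionally, and _create_sagemaker_model reads it back as args[2]. A small self-contained illustration of that *args pattern (names and values here are made up, not the SDK's internal call):

def _create_model(*args, **kwargs):
    # Mirrors the tags=args[2] usage above: the 3rd positional argument is
    # treated as the tag list; everything else is ignored in this sketch.
    tags = args[2] if len(args) > 2 else None
    return {"tags": tags}


instance_type = "ml.c5.xlarge"  # illustrative values
accelerator_type = None
tags = [{"Key": "ClientType", "Value": "PythonSDK-RightSize"}]

print(_create_model(instance_type, accelerator_type, tags))
# -> {'tags': [{'Key': 'ClientType', 'Value': 'PythonSDK-RightSize'}]}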

tests/integ/test_inference_recommender.py

Lines changed: 69 additions & 0 deletions
@@ -279,6 +279,30 @@ def default_right_sized_unregistered_base_model(sagemaker_session, cpu_instance_type):
         sagemaker_session.delete_model(ModelName=model.name)
 
 
+@pytest.fixture(scope="module")
+def created_base_model(sagemaker_session, cpu_instance_type):
+    model_data = sagemaker_session.upload_data(path=IR_SKLEARN_MODEL)
+    region = sagemaker_session._region_name
+    image_uri = image_uris.retrieve(
+        framework="sklearn", region=region, version="1.0-1", image_scope="inference"
+    )
+
+    iam_client = sagemaker_session.boto_session.client("iam")
+    role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"]
+
+    model = Model(
+        model_data=model_data,
+        role=role_arn,
+        entry_point=IR_SKLEARN_ENTRY_POINT,
+        image_uri=image_uri,
+        sagemaker_session=sagemaker_session,
+    )
+
+    model.create(instance_type=cpu_instance_type)
+
+    return model
+
+
 @pytest.mark.slow_test
 def test_default_right_size_and_deploy_registered_model_sklearn(
     default_right_sized_model, sagemaker_session
@@ -429,3 +453,48 @@ def test_deploy_inference_recommendation_id_with_registered_model_sklearn(
     )
     predictor.delete_model()
     predictor.delete_endpoint()
+
+
+@pytest.mark.slow_test
+def test_deploy_deployment_recommendation_id_with_model(created_base_model, sagemaker_session):
+    with timeout(minutes=20):
+        try:
+            deployment_recommendation = poll_for_deployment_recommendation(created_base_model, sagemaker_session)
+
+            assert deployment_recommendation != None
+
+            real_time_recommendations = deployment_recommendation.get("RealTimeInferenceRecommendations")
+            recommendation_id = real_time_recommendations[0].get('RecommendationId')
+
+            endpoint_name = unique_name_from_base("test-rec-id-deployment-default-sklearn")
+            created_base_model.predictor_cls = SKLearnPredictor
+            predictor = created_base_model.deploy(
+                inference_recommendation_id=recommendation_id, initial_instance_count=1, endpoint_name=endpoint_name
+            )
+
+            payload = pd.read_csv(IR_SKLEARN_DATA, header=None)
+
+            inference = predictor.predict(payload)
+            assert inference is not None
+            assert 26 == len(inference)
+        finally:
+            predictor.delete_model()
+            predictor.delete_endpoint()
+
+
+def poll_for_deployment_recommendation(created_base_model, sagemaker_session):
+    with timeout(minutes=1):
+        try:
+            completed = False
+            while not completed:
+                describe_model_response = sagemaker_session.sagemaker_client.describe_model(ModelName=created_base_model.name)
+                deployment_recommendation = describe_model_response.get("DeploymentRecommendation")
+
+                completed = (
+                    deployment_recommendation is not None
+                    and "COMPLETED" == deployment_recommendation.get("RecommendationStatus")
+                )
+            return deployment_recommendation
+        except Exception as e:
+            created_base_model.delete_model()
+            raise e
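
The polling helper above only touches a few fields of the DescribeModel response. A response fragment consistent with what the test reads might look like the sketch below (illustrative values; only the keys actually accessed by the test are taken from the source):

# Hypothetical DescribeModel response fragment. Only DeploymentRecommendation,
# RecommendationStatus, RealTimeInferenceRecommendations and RecommendationId
# are grounded in the test above; the concrete values are made up.
describe_model_response = {
    "DeploymentRecommendation": {
        "RecommendationStatus": "COMPLETED",
        "RealTimeInferenceRecommendations": [
            {"RecommendationId": "example-recommendation-id"},
        ],
    },
}

recommendation = describe_model_response.get("DeploymentRecommendation")
if recommendation and recommendation.get("RecommendationStatus") == "COMPLETED":
    recommendation_id = recommendation["RealTimeInferenceRecommendations"][0]["RecommendationId"]
    print(recommendation_id)  # -> example-recommendation-id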

tests/unit/sagemaker/inference_recommender/constants.py

Lines changed: 41 additions & 0 deletions
@@ -152,3 +152,44 @@
     "ModelArtifacts": {"S3ModelArtifacts": IR_COMPILATION_MODEL_DATA},
     "InferenceImage": IR_COMPILATION_IMAGE,
 }
+
+IR_CONTAINER_DEF = {
+    "Image": IR_IMAGE,
+    "Environment": IR_ENV,
+    "ModelDataUrl": IR_MODEL_DATA,
+}
+
+DEPLOYMENT_RECOMMENDATION_CONTAINER_DEF = {
+    "Image": IR_IMAGE,
+    "Environment": MODEL_RECOMMENDATION_ENV,
+    "ModelDataUrl": IR_MODEL_DATA,
+}
+
+IR_COMPILATION_CONTAINER_DEF = {
+    "Image": IR_COMPILATION_IMAGE,
+    "Environment": {},
+    "ModelDataUrl": IR_COMPILATION_MODEL_DATA,
+}
+
+IR_MODEL_PACKAGE_CONTAINER_DEF = {
+    "ModelPackageName": IR_MODEL_PACKAGE_VERSION_ARN,
+    "Environment": IR_ENV,
+}
+
+IR_COMPILATION_MODEL_PACKAGE_CONTAINER_DEF = {
+    "ModelPackageName": IR_MODEL_PACKAGE_VERSION_ARN,
+}
+
+IR_TAGS = [
+    {
+        "Key": "ClientType",
+        "Value": "PythonSDK-RightSize",
+    }
+]
+
+DEPLOYMENT_RECOMMENDATION_TAGS = [
+    {
+        "Key": "ClientType",
+        "Value": "PythonSDK-DeploymentRecommendation",
+    }
+]

tests/unit/sagemaker/inference_recommender/test_inference_recommender_mixin.py

Lines changed: 65 additions & 1 deletion
@@ -177,6 +177,17 @@
     }
 ]
 
+IR_MODEL_PACKAGE_CONTAINER_DEF = {
+    "ModelPackageName": MODEL_PACKAGE_ARN,
+}
+
+IR_TAGS = [
+    {
+        "Key": "ClientType",
+        "Value": "PythonSDK-RightSize",
+    }
+]
+
 
 @pytest.fixture()
 def sagemaker_session():
@@ -371,6 +382,8 @@ def test_right_size_default_with_model_package_successful(sagemaker_session, model_package):
         framework=IR_SAMPLE_FRAMEWORK,
     )
 
+    sagemaker_session.create_model.assert_not_called()
+
     # assert that the create api has been called with default parameters
     sagemaker_session.create_inference_recommendations_job.assert_called_with(
         role=IR_ROLE_ARN,
@@ -426,6 +439,8 @@ def test_right_size_advanced_list_instances_model_package_successful(
         max_parallel_tests=5,
    )
 
+    sagemaker_session.create_model.assert_not_called()
+
     # assert that the create api has been called with advanced parameters
     sagemaker_session.create_inference_recommendations_job.assert_called_with(
         role=IR_ROLE_ARN,
@@ -481,6 +496,8 @@ def test_right_size_advanced_single_instances_model_package_successful(
         max_parallel_tests=5,
     )
 
+    sagemaker_session.create_model.assert_not_called()
+
     # assert that the create api has been called with advanced parameters
     sagemaker_session.create_inference_recommendations_job.assert_called_with(
         role=IR_ROLE_ARN,
@@ -517,6 +534,8 @@ def test_right_size_advanced_model_package_partial_params_successful(
         model_latency_thresholds=IR_SAMPLE_MODEL_LATENCY_THRESHOLDS,
     )
 
+    sagemaker_session.create_model.assert_not_called()
+
     # assert that the create api has been called with advanced parameters
     sagemaker_session.create_inference_recommendations_job.assert_called_with(
         role=IR_ROLE_ARN,
@@ -567,14 +586,23 @@ def test_deploy_right_size_with_model_package_succeeds(
     default_right_sized_model.name = MODEL_NAME
     default_right_sized_model.deploy(endpoint_name=IR_DEPLOY_ENDPOINT_NAME)
 
+    sagemaker_session.create_model.assert_called_with(
+        MODEL_NAME,
+        IR_ROLE_ARN,
+        IR_MODEL_PACKAGE_CONTAINER_DEF,
+        vpc_config=None,
+        enable_network_isolation=False,
+        tags=IR_TAGS,
+    )
+
     sagemaker_session.endpoint_from_production_variants.assert_called_with(
         async_inference_config_dict=None,
         data_capture_config_dict=None,
         explainer_config_dict=None,
         kms_key=None,
         name="ir-endpoint-test",
         production_variants=IR_PRODUCTION_VARIANTS,
-        tags=None,
+        tags=IR_TAGS,
         wait=True,
     )
 
@@ -589,6 +617,15 @@ def test_deploy_right_size_with_both_overrides_succeeds(
         endpoint_name=IR_DEPLOY_ENDPOINT_NAME,
     )
 
+    sagemaker_session.create_model.assert_called_with(
+        MODEL_NAME,
+        IR_ROLE_ARN,
+        IR_MODEL_PACKAGE_CONTAINER_DEF,
+        vpc_config=None,
+        enable_network_isolation=False,
+        tags=None,
+    )
+
     sagemaker_session.endpoint_from_production_variants.assert_called_with(
         async_inference_config_dict=None,
         data_capture_config_dict=None,
@@ -637,6 +674,15 @@ def test_deploy_right_size_serverless_override(sagemaker_session, default_right_sized_model):
     serverless_inference_config = ServerlessInferenceConfig()
     default_right_sized_model.deploy(serverless_inference_config=serverless_inference_config)
 
+    sagemaker_session.create_model.assert_called_with(
+        MODEL_NAME,
+        IR_ROLE_ARN,
+        IR_MODEL_PACKAGE_CONTAINER_DEF,
+        vpc_config=None,
+        enable_network_isolation=False,
+        tags=None,
+    )
+
     sagemaker_session.endpoint_from_production_variants.assert_called_with(
         name=MODEL_NAME,
         production_variants=IR_SERVERLESS_PRODUCTION_VARIANTS,
@@ -661,6 +707,15 @@ def test_deploy_right_size_async_override(sagemaker_session, default_right_sized_model):
         async_inference_config=async_inference_config,
     )
 
+    sagemaker_session.create_model.assert_called_with(
+        MODEL_NAME,
+        IR_ROLE_ARN,
+        IR_MODEL_PACKAGE_CONTAINER_DEF,
+        vpc_config=None,
+        enable_network_isolation=False,
+        tags=None,
+    )
+
     sagemaker_session.endpoint_from_production_variants.assert_called_with(
         name=MODEL_NAME,
         production_variants=[ANY],
@@ -695,6 +750,15 @@ def test_deploy_right_size_explainer_config_override(sagemaker_session, default_right_sized_model):
         explainer_config=explainer_config,
    )
 
+    sagemaker_session.create_model.assert_called_with(
+        MODEL_NAME,
+        IR_ROLE_ARN,
+        IR_MODEL_PACKAGE_CONTAINER_DEF,
+        vpc_config=None,
+        enable_network_isolation=False,
+        tags=None,
+    )
+
     sagemaker_session.endpoint_from_production_variants.assert_called_with(
         name=MODEL_NAME,
         production_variants=[ANY],
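
The assertions in these tests follow the standard unittest.mock pattern: the mocked session records every call, so assert_called_with can verify that the ClientType tag was forwarded to create_model, and assert_not_called can verify that right_size itself creates no model. A tiny standalone illustration of that pattern (argument values are made up):

from unittest.mock import Mock

IR_TAGS = [{"Key": "ClientType", "Value": "PythonSDK-RightSize"}]

session = Mock()
session.create_model.assert_not_called()  # nothing has been recorded yet

# Simulate the SDK forwarding the tags when a model is created on deploy.
session.create_model("model-name", "role-arn", {"Image": "image-uri"}, tags=IR_TAGS)

session.create_model.assert_called_with(
    "model-name", "role-arn", {"Image": "image-uri"}, tags=IR_TAGS
)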
