Skip to content

Commit 2df49dd

Browse files
authored
feat: Add tagging support for create ir job (#3901)
Co-authored-by: Gary Wang 😤 <[email protected]>
1 parent 624cac8 commit 2df49dd

File tree

3 files changed

+77
-44
lines changed

3 files changed

+77
-44
lines changed

src/sagemaker/session.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5313,6 +5313,7 @@ def _create_inference_recommendations_job_request(
53135313
framework: str,
53145314
sample_payload_url: str,
53155315
supported_content_types: List[str],
5316+
tags: Dict[str, str],
53165317
model_name: str = None,
53175318
model_package_version_arn: str = None,
53185319
job_duration_in_seconds: int = None,
@@ -5348,6 +5349,8 @@ def _create_inference_recommendations_job_request(
53485349
benchmarked by Amazon SageMaker Inference Recommender that matches your model.
53495350
supported_instance_types (List[str]): A list of the instance types that are used
53505351
to generate inferences in real-time.
5352+
tags (Dict[str, str]): Tags used to identify where the Inference Recommendatons Call
5353+
was made from.
53515354
endpoint_configurations (List[Dict[str, any]]): Specifies the endpoint configurations
53525355
to use for a job. Will be used for `Advanced` jobs.
53535356
traffic_pattern (Dict[str, any]): Specifies the traffic pattern for the job.
@@ -5386,6 +5389,7 @@ def _create_inference_recommendations_job_request(
53865389
"InputConfig": {
53875390
"ContainerConfig": containerConfig,
53885391
},
5392+
"Tags": tags,
53895393
}
53905394

53915395
request.get("InputConfig").update(
@@ -5477,6 +5481,8 @@ def create_inference_recommendations_job(
54775481
job_name = "SMPYTHONSDK-" + str(unique_tail)
54785482
job_description = "#python-sdk-create"
54795483

5484+
tags = [{"Key": "ClientType", "Value": "PythonSDK-RightSize"}]
5485+
54805486
create_inference_recommendations_job_request = (
54815487
self._create_inference_recommendations_job_request(
54825488
role=role,
@@ -5496,6 +5502,7 @@ def create_inference_recommendations_job(
54965502
traffic_pattern=traffic_pattern,
54975503
stopping_conditions=stopping_conditions,
54985504
resource_limit=resource_limit,
5505+
tags=tags,
54995506
)
55005507
)
55015508

tests/integ/test_inference_recommender.py

Lines changed: 68 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@
1313
from __future__ import absolute_import
1414

1515
import os
16+
import time
1617

1718
import pytest
1819

20+
from botocore.exceptions import ClientError
1921
from sagemaker import image_uris
2022
from sagemaker.model import Model
2123
from sagemaker.sklearn.model import SKLearnModel, SKLearnPredictor
@@ -40,6 +42,18 @@
4042
IR_SKLEARN_FRAMEWORK_VERSION = "1.0-1"
4143

4244

45+
def retry_and_back_off(right_size_fn):
46+
tot_retries = 3
47+
retries = 1
48+
while retries <= tot_retries:
49+
try:
50+
return right_size_fn
51+
except ClientError as e:
52+
if e.response["Error"]["Code"] == "ThrottlingException":
53+
retries += 1
54+
time.sleep(5 * retries)
55+
56+
4357
@pytest.fixture(scope="module")
4458
def default_right_sized_model(sagemaker_session, cpu_instance_type):
4559
with timeout(minutes=45):
@@ -68,13 +82,15 @@ def default_right_sized_model(sagemaker_session, cpu_instance_type):
6882
)
6983

7084
return (
71-
sklearn_model_package.right_size(
72-
job_name=ir_job_name,
73-
sample_payload_url=payload_data,
74-
supported_content_types=IR_SKLEARN_CONTENT_TYPE,
75-
supported_instance_types=[cpu_instance_type],
76-
framework=IR_SKLEARN_FRAMEWORK,
77-
log_level="Quiet",
85+
retry_and_back_off(
86+
sklearn_model_package.right_size(
87+
job_name=ir_job_name,
88+
sample_payload_url=payload_data,
89+
supported_content_types=IR_SKLEARN_CONTENT_TYPE,
90+
supported_instance_types=[cpu_instance_type],
91+
framework=IR_SKLEARN_FRAMEWORK,
92+
log_level="Quiet",
93+
)
7894
),
7995
model_package_group_name,
8096
ir_job_name,
@@ -133,17 +149,19 @@ def advanced_right_sized_model(sagemaker_session, cpu_instance_type):
133149
]
134150

135151
return (
136-
sklearn_model_package.right_size(
137-
sample_payload_url=payload_data,
138-
supported_content_types=IR_SKLEARN_CONTENT_TYPE,
139-
framework=IR_SKLEARN_FRAMEWORK,
140-
job_duration_in_seconds=3600,
141-
hyperparameter_ranges=hyperparameter_ranges,
142-
phases=phases,
143-
model_latency_thresholds=model_latency_thresholds,
144-
max_invocations=100,
145-
max_tests=5,
146-
max_parallel_tests=5,
152+
retry_and_back_off(
153+
sklearn_model_package.right_size(
154+
sample_payload_url=payload_data,
155+
supported_content_types=IR_SKLEARN_CONTENT_TYPE,
156+
framework=IR_SKLEARN_FRAMEWORK,
157+
job_duration_in_seconds=3600,
158+
hyperparameter_ranges=hyperparameter_ranges,
159+
phases=phases,
160+
model_latency_thresholds=model_latency_thresholds,
161+
max_invocations=100,
162+
max_tests=5,
163+
max_parallel_tests=5,
164+
)
147165
),
148166
model_package_group_name,
149167
)
@@ -175,13 +193,15 @@ def default_right_sized_unregistered_model(sagemaker_session, cpu_instance_type)
175193
)
176194

177195
return (
178-
sklearn_model.right_size(
179-
job_name=ir_job_name,
180-
sample_payload_url=payload_data,
181-
supported_content_types=IR_SKLEARN_CONTENT_TYPE,
182-
supported_instance_types=[cpu_instance_type],
183-
framework=IR_SKLEARN_FRAMEWORK,
184-
log_level="Quiet",
196+
retry_and_back_off(
197+
sklearn_model.right_size(
198+
job_name=ir_job_name,
199+
sample_payload_url=payload_data,
200+
supported_content_types=IR_SKLEARN_CONTENT_TYPE,
201+
supported_instance_types=[cpu_instance_type],
202+
framework=IR_SKLEARN_FRAMEWORK,
203+
log_level="Quiet",
204+
)
185205
),
186206
ir_job_name,
187207
)
@@ -224,18 +244,20 @@ def advanced_right_sized_unregistered_model(sagemaker_session, cpu_instance_type
224244
ModelLatencyThreshold(percentile="P95", value_in_milliseconds=100)
225245
]
226246

227-
return sklearn_model.right_size(
228-
sample_payload_url=payload_data,
229-
supported_content_types=IR_SKLEARN_CONTENT_TYPE,
230-
framework=IR_SKLEARN_FRAMEWORK,
231-
job_duration_in_seconds=3600,
232-
hyperparameter_ranges=hyperparameter_ranges,
233-
phases=phases,
234-
model_latency_thresholds=model_latency_thresholds,
235-
max_invocations=100,
236-
max_tests=5,
237-
max_parallel_tests=5,
238-
log_level="Quiet",
247+
return retry_and_back_off(
248+
sklearn_model.right_size(
249+
sample_payload_url=payload_data,
250+
supported_content_types=IR_SKLEARN_CONTENT_TYPE,
251+
framework=IR_SKLEARN_FRAMEWORK,
252+
job_duration_in_seconds=3600,
253+
hyperparameter_ranges=hyperparameter_ranges,
254+
phases=phases,
255+
model_latency_thresholds=model_latency_thresholds,
256+
max_invocations=100,
257+
max_tests=5,
258+
max_parallel_tests=5,
259+
log_level="Quiet",
260+
)
239261
)
240262

241263
except Exception:
@@ -265,13 +287,15 @@ def default_right_sized_unregistered_base_model(sagemaker_session, cpu_instance_
265287
)
266288

267289
return (
268-
model.right_size(
269-
job_name=ir_job_name,
270-
sample_payload_url=payload_data,
271-
supported_content_types=IR_SKLEARN_CONTENT_TYPE,
272-
supported_instance_types=[cpu_instance_type],
273-
framework=IR_SKLEARN_FRAMEWORK,
274-
log_level="Quiet",
290+
retry_and_back_off(
291+
model.right_size(
292+
job_name=ir_job_name,
293+
sample_payload_url=payload_data,
294+
supported_content_types=IR_SKLEARN_CONTENT_TYPE,
295+
supported_instance_types=[cpu_instance_type],
296+
framework=IR_SKLEARN_FRAMEWORK,
297+
log_level="Quiet",
298+
)
275299
),
276300
ir_job_name,
277301
)

tests/unit/test_session.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4704,6 +4704,7 @@ def create_inference_recommendations_job_default_happy_response():
47044704
"ModelPackageVersionArn": IR_MODEL_PACKAGE_VERSION_ARN,
47054705
},
47064706
"JobDescription": "#python-sdk-create",
4707+
"Tags": [{"Key": "ClientType", "Value": "PythonSDK-RightSize"}],
47074708
}
47084709

47094710

@@ -4728,6 +4729,7 @@ def create_inference_recommendations_job_default_model_name_happy_response():
47284729
"ModelName": IR_MODEL_NAME,
47294730
},
47304731
"JobDescription": "#python-sdk-create",
4732+
"Tags": [{"Key": "ClientType", "Value": "PythonSDK-RightSize"}],
47314733
}
47324734

47334735

0 commit comments

Comments
 (0)