Commit 7c524dd

Author: Raymond Liu (committed)

decouple right_size() from model registry
1 parent e913520 commit 7c524dd

File tree

3 files changed: +180 -32 lines changed

src/sagemaker/inference_recommender/inference_recommender_mixin.py

Lines changed: 20 additions & 5 deletions
@@ -15,6 +15,7 @@
 
 import logging
 import re
+import uuid
 
 from typing import List, Dict, Optional
 import sagemaker
@@ -38,7 +39,7 @@ class Phase:
     """
 
     def __init__(self, duration_in_seconds: int, initial_number_of_users: int, spawn_rate: int):
-        """Initialze a `Phase`"""
+        """Initialize a `Phase`"""
         self.to_json = {
             "DurationInSeconds": duration_in_seconds,
             "InitialNumberOfUsers": initial_number_of_users,
@@ -53,7 +54,7 @@ class ModelLatencyThreshold:
     """
 
     def __init__(self, percentile: str, value_in_milliseconds: int):
-        """Initialze a `ModelLatencyThreshold`"""
+        """Initialize a `ModelLatencyThreshold`"""
         self.to_json = {"Percentile": percentile, "ValueInMilliseconds": value_in_milliseconds}
 
 
@@ -119,8 +120,6 @@ def right_size(
             sagemaker.model.Model: A SageMaker ``Model`` object. See
                 :func:`~sagemaker.model.Model` for full details.
         """
-        if not isinstance(self, sagemaker.model.ModelPackage):
-            raise ValueError("right_size() is currently only supported with a registered model")
 
         if not framework and self._framework():
             framework = INFERENCE_RECOMMENDER_FRAMEWORK_MAPPING.get(self._framework(), framework)
@@ -149,12 +148,26 @@ def right_size(
 
         self._init_sagemaker_session_if_does_not_exist()
 
+        model_name = None
+        if isinstance(self, sagemaker.model.FrameworkModel):
+
+            unique_tail = uuid.uuid4()
+            model_name = "SMPYTHONSDK-" + str(unique_tail)
+
+            self.sagemaker_session.create_model(
+                name=model_name,
+                role=self.role,
+                container_defs=None,
+                primary_container=self.prepare_container_def(),
+            )
+
         ret_name = self.sagemaker_session.create_inference_recommendations_job(
             role=self.role,
             job_name=job_name,
             job_type=job_type,
             job_duration_in_seconds=job_duration_in_seconds,
-            model_package_version_arn=self.model_package_arn,
+            model_name=model_name,
+            model_package_version_arn=getattr(self, "model_package_arn", None),
             framework=framework,
             framework_version=framework_version,
             sample_payload_url=sample_payload_url,
@@ -175,6 +188,8 @@ def right_size(
             "InferenceRecommendations"
         )
 
+        if model_name is not None:
+            self.sagemaker_session.delete_model(model_name)
         return self
 
     def _update_params(
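
The mixin change above removes the ModelPackage-only guard: when right_size() is called on a FrameworkModel, it now creates a temporary SageMaker Model named "SMPYTHONSDK-<uuid>", passes that name to the Inference Recommender job instead of a model package ARN, and deletes the temporary model once the job results are read back. A minimal usage sketch of what this enables, assuming a scikit-learn model; the artifact location, role ARN, and S3 URIs below are hypothetical placeholders, and the right_size() arguments mirror the ones exercised in the new unit tests:

    from sagemaker.sklearn.model import SKLearnModel

    # Hypothetical unregistered framework model; no Model Registry / ModelPackage involved.
    sklearn_model = SKLearnModel(
        model_data="s3://my-bucket/model.tar.gz",                 # placeholder artifact
        role="arn:aws:iam::111122223333:role/SageMakerRole",      # placeholder role ARN
        entry_point="inference.py",
        framework_version="1.0-1",
    )

    # right_size() now provisions a temporary model behind the scenes, runs the
    # Inference Recommender job against it, and deletes the temporary model afterwards.
    sklearn_model.right_size(
        sample_payload_url="s3://my-bucket/payload.tar.gz",       # placeholder payload
        supported_content_types=["text/csv"],
        supported_instance_types=["ml.c5.xlarge"],
        framework="SAGEMAKER-SCIKIT-LEARN",
    )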

src/sagemaker/session.py

Lines changed: 28 additions & 9 deletions
@@ -4820,6 +4820,7 @@ def _create_inference_recommendations_job_request(
         framework: str,
         sample_payload_url: str,
         supported_content_types: List[str],
+        model_name: str = None,
         model_package_version_arn: str = None,
         job_duration_in_seconds: int = None,
         job_type: str = "Default",
@@ -4843,6 +4844,7 @@ def _create_inference_recommendations_job_request(
             framework (str): The machine learning framework of the Image URI.
             sample_payload_url (str): The S3 path where the sample payload is stored.
             supported_content_types (List[str]): The supported MIME types for the input data.
+            model_name (str): Name of the Amazon SageMaker ``Model`` to be used.
             model_package_version_arn (str): The Amazon Resource Name (ARN) of a
                 versioned model package.
             job_duration_in_seconds (int): The maximum job duration that a job
@@ -4884,15 +4886,26 @@ def _create_inference_recommendations_job_request(
         if supported_instance_types:
             containerConfig["SupportedInstanceTypes"] = supported_instance_types
 
-        request = {
-            "JobName": job_name,
-            "JobType": job_type,
-            "RoleArn": role,
-            "InputConfig": {
-                "ContainerConfig": containerConfig,
-                "ModelPackageVersionArn": model_package_version_arn,
-            },
-        }
+        if model_package_version_arn:
+            request = {
+                "JobName": job_name,
+                "JobType": job_type,
+                "RoleArn": role,
+                "InputConfig": {
+                    "ContainerConfig": containerConfig,
+                    "ModelPackageVersionArn": model_package_version_arn,
+                },
+            }
+        else:
+            request = {
+                "JobName": job_name,
+                "JobType": job_type,
+                "RoleArn": role,
+                "InputConfig": {
+                    "ContainerConfig": containerConfig,
+                    "ModelName": model_name,
+                },
+            }
 
         if job_description:
             request["JobDescription"] = job_description
@@ -4918,6 +4931,7 @@ def create_inference_recommendations_job(
         supported_content_types: List[str],
         job_name: str = None,
         job_type: str = "Default",
+        model_name: str = None,
         model_package_version_arn: str = None,
         job_duration_in_seconds: int = None,
         nearest_model_name: str = None,
@@ -4938,6 +4952,7 @@ def create_inference_recommendations_job(
                 You must grant sufficient permissions to this role.
             sample_payload_url (str): The S3 path where the sample payload is stored.
             supported_content_types (List[str]): The supported MIME types for the input data.
+            model_name (str): Name of the Amazon SageMaker ``Model`` to be used.
             model_package_version_arn (str): The Amazon Resource Name (ARN) of a
                 versioned model package.
             job_name (str): The name of the job being run.
@@ -4964,6 +4979,9 @@ def create_inference_recommendations_job(
             str: The name of the job created. In the form of `SMPYTHONSDK-<timestamp>`
         """
 
+        if model_name is None and model_package_version_arn is None:
+            raise ValueError("Either model_name or model_package_version_arn should be provided.")
+
         if not job_name:
             unique_tail = uuid.uuid4()
             job_name = "SMPYTHONSDK-" + str(unique_tail)
@@ -4972,6 +4990,7 @@ def create_inference_recommendations_job(
         create_inference_recommendations_job_request = (
            self._create_inference_recommendations_job_request(
                 role=role,
+                model_name=model_name,
                 model_package_version_arn=model_package_version_arn,
                 job_name=job_name,
                 job_type=job_type,
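
The session-level change mirrors the mixin: _create_inference_recommendations_job_request() now builds the InputConfig around ModelPackageVersionArn when a versioned package ARN is given and around ModelName otherwise, and create_inference_recommendations_job() rejects calls that supply neither. A rough sketch of a direct call with a plain model name, assuming an existing sagemaker.session.Session instance named sess; the role ARN, S3 URI, and model name are illustrative placeholders:

    # Hypothetical direct call against the Session method extended in this commit.
    job_name = sess.create_inference_recommendations_job(
        role="arn:aws:iam::111122223333:role/SageMakerRole",   # placeholder role ARN
        sample_payload_url="s3://my-bucket/payload.tar.gz",    # placeholder payload
        supported_content_types=["text/csv"],
        model_name="my-existing-sagemaker-model",              # unregistered Model, by name
        framework="SAGEMAKER-SCIKIT-LEARN",
    )

    # Supplying neither model_name nor model_package_version_arn now fails fast with:
    #   ValueError: Either model_name or model_package_version_arn should be provided.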

tests/unit/sagemaker/inference_recommender/test_inference_recommender_mixin.py

Lines changed: 132 additions & 18 deletions
@@ -175,6 +175,134 @@ def default_right_sized_model(model_package):
    )
 
 
+def test_right_size_default_with_model_name_successful(sagemaker_session, model):
+    inference_recommender_model = model.right_size(
+        sample_payload_url=IR_SAMPLE_PAYLOAD_URL,
+        supported_content_types=IR_SUPPORTED_CONTENT_TYPES,
+        supported_instance_types=[IR_SAMPLE_INSTANCE_TYPE],
+        job_name=IR_JOB_NAME,
+        framework=IR_SAMPLE_FRAMEWORK,
+    )
+
+    # assert that the create api has been called with default parameters with model name
+    assert sagemaker_session.create_inference_recommendations_job.called_with(
+        role=IR_ROLE_ARN,
+        job_name=IR_JOB_NAME,
+        job_type="Default",
+        job_duration_in_seconds=None,
+        model_name=ANY,
+        model_package_version_arn=None,
+        framework=IR_SAMPLE_FRAMEWORK,
+        framework_version=None,
+        sample_payload_url=IR_SAMPLE_PAYLOAD_URL,
+        supported_content_types=IR_SUPPORTED_CONTENT_TYPES,
+        supported_instance_types=[IR_SAMPLE_INSTANCE_TYPE],
+        endpoint_configurations=None,
+        traffic_pattern=None,
+        stopping_conditions=None,
+        resource_limit=None,
+    )
+
+    assert sagemaker_session.wait_for_inference_recommendations_job.called_with(IR_JOB_NAME)
+
+    # confirm that the IR instance attributes have been set
+    assert (
+        inference_recommender_model.inference_recommender_job_results
+        == IR_SAMPLE_INFERENCE_RESPONSE
+    )
+    assert (
+        inference_recommender_model.inference_recommendations
+        == IR_SAMPLE_INFERENCE_RESPONSE["InferenceRecommendations"]
+    )
+
+    # confirm that the returned object of right_size is itself
+    assert inference_recommender_model == model
+
+def test_right_size_advanced_list_instances_model_name_successful(sagemaker_session, model):
+    inference_recommender_model = model.right_size(
+        sample_payload_url=IR_SAMPLE_PAYLOAD_URL,
+        supported_content_types=IR_SUPPORTED_CONTENT_TYPES,
+        framework="SAGEMAKER-SCIKIT-LEARN",
+        job_duration_in_seconds=7200,
+        hyperparameter_ranges=IR_SAMPLE_LIST_OF_INSTANCES_HYPERPARAMETER_RANGES,
+        phases=IR_SAMPLE_PHASES,
+        traffic_type="PHASES",
+        max_invocations=100,
+        model_latency_thresholds=IR_SAMPLE_MODEL_LATENCY_THRESHOLDS,
+        max_tests=5,
+        max_parallel_tests=5,
+    )
+
+    # assert that the create api has been called with advanced parameters
+    assert sagemaker_session.create_inference_recommendations_job.called_with(
+        role=IR_ROLE_ARN,
+        job_name=IR_JOB_NAME,
+        job_type="Advanced",
+        job_duration_in_seconds=7200,
+        model_name=ANY,
+        model_package_version_arn=None,
+        framework=IR_SAMPLE_FRAMEWORK,
+        framework_version=None,
+        sample_payload_url=IR_SAMPLE_PAYLOAD_URL,
+        supported_content_types=IR_SUPPORTED_CONTENT_TYPES,
+        supported_instance_types=[IR_SAMPLE_INSTANCE_TYPE],
+        endpoint_configurations=IR_SAMPLE_ENDPOINT_CONFIG,
+        traffic_pattern=IR_SAMPLE_TRAFFIC_PATTERN,
+        stopping_conditions=IR_SAMPLE_STOPPING_CONDITIONS,
+        resource_limit=IR_SAMPLE_RESOURCE_LIMIT,
+    )
+
+    assert sagemaker_session.wait_for_inference_recommendations_job.called_with(IR_JOB_NAME)
+
+    # confirm that the IR instance attributes have been set
+    assert (
+        inference_recommender_model.inference_recommender_job_results
+        == IR_SAMPLE_INFERENCE_RESPONSE
+    )
+    assert (
+        inference_recommender_model.inference_recommendations
+        == IR_SAMPLE_INFERENCE_RESPONSE["InferenceRecommendations"]
+    )
+
+    # confirm that the returned object of right_size is itself
+    assert inference_recommender_model == model
+
+def test_right_size_advanced_single_instances_model_name_successful(sagemaker_session, model):
+    model.right_size(
+        sample_payload_url=IR_SAMPLE_PAYLOAD_URL,
+        supported_content_types=IR_SUPPORTED_CONTENT_TYPES,
+        framework="SAGEMAKER-SCIKIT-LEARN",
+        job_duration_in_seconds=7200,
+        hyperparameter_ranges=IR_SAMPLE_SINGLE_INSTANCES_HYPERPARAMETER_RANGES,
+        phases=IR_SAMPLE_PHASES,
+        traffic_type="PHASES",
+        max_invocations=100,
+        model_latency_thresholds=IR_SAMPLE_MODEL_LATENCY_THRESHOLDS,
+        max_tests=5,
+        max_parallel_tests=5,
+    )
+
+    # assert that the create api has been called with advanced parameters
+    assert sagemaker_session.create_inference_recommendations_job.called_with(
+        role=IR_ROLE_ARN,
+        job_name=IR_JOB_NAME,
+        job_type="Advanced",
+        job_duration_in_seconds=7200,
+        model_name=ANY,
+        model_package_version_arn=None,
+        framework=IR_SAMPLE_FRAMEWORK,
+        framework_version=None,
+        sample_payload_url=IR_SAMPLE_PAYLOAD_URL,
+        supported_content_types=IR_SUPPORTED_CONTENT_TYPES,
+        supported_instance_types=[IR_SAMPLE_INSTANCE_TYPE],
+        endpoint_configurations=IR_SAMPLE_ENDPOINT_CONFIG,
+        traffic_pattern=IR_SAMPLE_TRAFFIC_PATTERN,
+        stopping_conditions=IR_SAMPLE_STOPPING_CONDITIONS,
+        resource_limit=IR_SAMPLE_RESOURCE_LIMIT,
+    )
+
+
+
 def test_right_size_default_with_model_package_successful(sagemaker_session, model_package):
     inference_recommender_model_pkg = model_package.right_size(
         sample_payload_url=IR_SAMPLE_PAYLOAD_URL,
@@ -190,6 +318,7 @@ def test_right_size_default_with_model_package_successful(sagemaker_session, model_package):
         job_name=IR_JOB_NAME,
         job_type="Default",
         job_duration_in_seconds=None,
+        model_name=None,
         model_package_version_arn=model_package.model_package_arn,
         framework=IR_SAMPLE_FRAMEWORK,
         framework_version=None,
@@ -202,7 +331,7 @@ def test_right_size_default_with_model_package_successful(sagemaker_session, model_package):
         resource_limit=None,
     )
 
-    assert sagemaker_session.wait_for_inference_recomendations_job.called_with(IR_JOB_NAME)
+    assert sagemaker_session.wait_for_inference_recommendations_job.called_with(IR_JOB_NAME)
 
     # confirm that the IR instance attributes have been set
     assert (
@@ -216,7 +345,7 @@ def test_right_size_default_with_model_package_successful(sagemaker_session, model_package):
 
     # confirm that the returned object of right_size is itself
    assert inference_recommender_model_pkg == model_package
-
+
 
 def test_right_size_advanced_list_instances_model_package_successful(
     sagemaker_session, model_package
@@ -253,7 +382,7 @@ def test_right_size_advanced_list_instances_model_package_successful(
         resource_limit=IR_SAMPLE_RESOURCE_LIMIT,
     )
 
-    assert sagemaker_session.wait_for_inference_recomendations_job.called_with(IR_JOB_NAME)
+    assert sagemaker_session.wait_for_inference_recommendations_job.called_with(IR_JOB_NAME)
 
     # confirm that the IR instance attributes have been set
     assert (
@@ -359,21 +488,6 @@ def test_right_size_invalid_hyperparameter_ranges(sagemaker_session, model_package):
        )
 
 
-# TODO -> removed once model registry is decoupled
-def test_right_size_missing_model_package_arn(sagemaker_session, model):
-    with pytest.raises(
-        ValueError,
-        match="right_size\\(\\) is currently only supported with a registered model",
-    ):
-        model.right_size(
-            sample_payload_url=IR_SAMPLE_PAYLOAD_URL,
-            supported_content_types=IR_SUPPORTED_CONTENT_TYPES,
-            supported_instance_types=[IR_SAMPLE_INSTANCE_TYPE],
-            job_name=IR_JOB_NAME,
-            framework=IR_SAMPLE_FRAMEWORK,
-        )
-
-
 # TODO check our framework mapping when we add in inference_recommendation_id support
 
 
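Because the temporary model name ends in a random UUID, the new tests match it with mock.ANY rather than an exact string. A self-contained illustration of that matching pattern, using a toy MagicMock instead of the project fixtures:

    from unittest.mock import ANY, MagicMock

    session = MagicMock()
    session.create_inference_recommendations_job(
        model_name="SMPYTHONSDK-1234abcd", role="test-role"
    )

    # ANY matches whichever UUID-suffixed name was generated at call time.
    session.create_inference_recommendations_job.assert_called_with(
        model_name=ANY, role="test-role"
    )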