Skip to content

Commit ecb60fc

Browse files
author
Raymond Liu
committed
address comments in PR. Improve unit tests
1 parent 7c524dd commit ecb60fc

File tree

4 files changed

+176
-33
lines changed

4 files changed

+176
-33
lines changed

src/sagemaker/inference_recommender/inference_recommender_mixin.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -148,25 +148,32 @@ def right_size(
148148

149149
self._init_sagemaker_session_if_does_not_exist()
150150

151-
model_name = None
152-
if isinstance(self, sagemaker.model.FrameworkModel):
151+
temp_model_name = None
152+
if isinstance(self, sagemaker.model.Model) and not isinstance(self, sagemaker.model.ModelPackage):
153153

154154
unique_tail = uuid.uuid4()
155-
model_name = "SMPYTHONSDK-" + str(unique_tail)
155+
temp_model_name = "SMPYTHONSDK-" + str(unique_tail)
156156

157-
self.sagemaker_session.create_model(
158-
name=model_name,
157+
create_model_args = dict(
158+
name=temp_model_name,
159159
role=self.role,
160160
container_defs=None,
161161
primary_container=self.prepare_container_def(),
162+
vpc_config=self.vpc_config,
163+
enable_network_isolation=self.enable_network_isolation()
162164
)
165+
print(f"Creating temporary model with name: {temp_model_name}" \
166+
"for Inference Recommender.", flush=True)
167+
self.sagemaker_session.create_model(**create_model_args)
168+
print("Temporary model created. Start to run Inference Recommender...", flush=True)
169+
163170

164171
ret_name = self.sagemaker_session.create_inference_recommendations_job(
165172
role=self.role,
166173
job_name=job_name,
167174
job_type=job_type,
168175
job_duration_in_seconds=job_duration_in_seconds,
169-
model_name=model_name,
176+
model_name=temp_model_name,
170177
model_package_version_arn=getattr(self, "model_package_arn", None),
171178
framework=framework,
172179
framework_version=framework_version,
@@ -188,8 +195,11 @@ def right_size(
188195
"InferenceRecommendations"
189196
)
190197

191-
if model_name is not None:
192-
self.sagemaker_session.delete_model(model_name)
198+
if temp_model_name is not None:
199+
print(f"Deleting temporary model with name: {temp_model_name}" \
200+
"for Inference Recommender.", flush=True)
201+
self.sagemaker_session.delete_model(temp_model_name)
202+
print("Delete complete.")
193203
return self
194204

195205
def _update_params(

src/sagemaker/session.py

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4886,26 +4886,20 @@ def _create_inference_recommendations_job_request(
48864886
if supported_instance_types:
48874887
containerConfig["SupportedInstanceTypes"] = supported_instance_types
48884888

4889-
if model_package_version_arn:
4890-
request = {
4891-
"JobName": job_name,
4892-
"JobType": job_type,
4893-
"RoleArn": role,
4894-
"InputConfig": {
4895-
"ContainerConfig": containerConfig,
4896-
"ModelPackageVersionArn": model_package_version_arn,
4897-
},
4898-
}
4899-
else:
4900-
request = {
4901-
"JobName": job_name,
4902-
"JobType": job_type,
4903-
"RoleArn": role,
4904-
"InputConfig": {
4905-
"ContainerConfig": containerConfig,
4906-
"ModelName": model_name,
4907-
},
4908-
}
4889+
request = {
4890+
"JobName": job_name,
4891+
"JobType": job_type,
4892+
"RoleArn": role,
4893+
"InputConfig": {
4894+
"ContainerConfig": containerConfig,
4895+
},
4896+
}
4897+
4898+
request.get("InputConfig").update(
4899+
{ "ModelPackageVersionArn": model_package_version_arn}
4900+
if model_package_version_arn
4901+
else { "ModelName": model_name }
4902+
)
49094903

49104904
if job_description:
49114905
request["JobDescription"] = job_description
@@ -4980,7 +4974,12 @@ def create_inference_recommendations_job(
49804974
"""
49814975

49824976
if model_name is None and model_package_version_arn is None:
4983-
raise ValueError("Either model_name or model_package_version_arn should be provided.")
4977+
raise ValueError("Missing model_name and model_package_version_arn,"\
4978+
" please provide one of them.")
4979+
4980+
if model_name is not None and model_package_version_arn is not None:
4981+
raise ValueError("Please provide either model_name or model_package_version_arn" \
4982+
" should be provided, not both.")
49844983

49854984
if not job_name:
49864985
unique_tail = uuid.uuid4()

tests/unit/sagemaker/inference_recommender/test_inference_recommender_mixin.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
IR_SUPPORTED_CONTENT_TYPES = ["text/csv"]
2727
IR_JOB_NAME = "SMPYTHONSDK-1234567891"
2828
IR_SAMPLE_INSTANCE_TYPE = "ml.c5.xlarge"
29+
IR_MODEL_NAME = "SMPYTHONSDK-sample-unique-uuid"
2930

3031
IR_SAMPLE_LIST_OF_INSTANCES_HYPERPARAMETER_RANGES = [
3132
{
@@ -174,7 +175,7 @@ def default_right_sized_model(model_package):
174175
framework=IR_SAMPLE_FRAMEWORK,
175176
)
176177

177-
178+
@patch("uuid.uuid4", MagicMock(return_value="sample-unique-uuid"))
178179
def test_right_size_default_with_model_name_successful(sagemaker_session, model):
179180
inference_recommender_model = model.right_size(
180181
sample_payload_url=IR_SAMPLE_PAYLOAD_URL,
@@ -184,13 +185,23 @@ def test_right_size_default_with_model_name_successful(sagemaker_session, model)
184185
framework=IR_SAMPLE_FRAMEWORK,
185186
)
186187

188+
# assert that the create model api has been called with default parameters
189+
assert sagemaker_session.create_model.called_with(
190+
name=IR_MODEL_NAME,
191+
role=IR_ROLE_ARN,
192+
container_defs=None,
193+
primary_container={},
194+
vpc_config=None,
195+
enable_network_isolation=False
196+
)
197+
187198
# assert that the create api has been called with default parameters with model name
188199
assert sagemaker_session.create_inference_recommendations_job.called_with(
189200
role=IR_ROLE_ARN,
190201
job_name=IR_JOB_NAME,
191202
job_type="Default",
192203
job_duration_in_seconds=None,
193-
model_name=ANY,
204+
model_name=IR_MODEL_NAME,
194205
model_package_version_arn=None,
195206
framework=IR_SAMPLE_FRAMEWORK,
196207
framework_version=None,
@@ -218,6 +229,7 @@ def test_right_size_default_with_model_name_successful(sagemaker_session, model)
218229
# confirm that the returned object of right_size is itself
219230
assert inference_recommender_model == model
220231

232+
@patch("uuid.uuid4", MagicMock(return_value="sample-unique-uuid"))
221233
def test_right_size_advanced_list_instances_model_name_successful(sagemaker_session, model):
222234
inference_recommender_model = model.right_size(
223235
sample_payload_url=IR_SAMPLE_PAYLOAD_URL,
@@ -239,7 +251,7 @@ def test_right_size_advanced_list_instances_model_name_successful(sagemaker_sess
239251
job_name=IR_JOB_NAME,
240252
job_type="Advanced",
241253
job_duration_in_seconds=7200,
242-
model_name=ANY,
254+
model_name=IR_MODEL_NAME,
243255
model_package_version_arn=None,
244256
framework=IR_SAMPLE_FRAMEWORK,
245257
framework_version=None,
@@ -267,6 +279,7 @@ def test_right_size_advanced_list_instances_model_name_successful(sagemaker_sess
267279
# confirm that the returned object of right_size is itself
268280
assert inference_recommender_model == model
269281

282+
@patch("uuid.uuid4", MagicMock(return_value="sample-unique-uuid"))
270283
def test_right_size_advanced_single_instances_model_name_successful(sagemaker_session, model):
271284
model.right_size(
272285
sample_payload_url=IR_SAMPLE_PAYLOAD_URL,
@@ -288,7 +301,7 @@ def test_right_size_advanced_single_instances_model_name_successful(sagemaker_se
288301
job_name=IR_JOB_NAME,
289302
job_type="Advanced",
290303
job_duration_in_seconds=7200,
291-
model_name=ANY,
304+
model_name=IR_MODEL_NAME,
292305
model_package_version_arn=None,
293306
framework=IR_SAMPLE_FRAMEWORK,
294307
framework_version=None,

tests/unit/test_session.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3194,6 +3194,7 @@ def test_batch_get_record(sagemaker_session):
31943194
IR_MODEL_PACKAGE_VERSION_ARN = (
31953195
"arn:aws:sagemaker:us-west-2:123456789123:model-package/unit-test-package-version/1"
31963196
)
3197+
IR_MODEL_NAME = "MODEL_NAME"
31973198
IR_NEAREST_MODEL_NAME = "xgboost"
31983199
IR_SUPPORTED_INSTANCE_TYPES = ["ml.c5.xlarge", "ml.c5.2xlarge"]
31993200
IR_FRAMEWORK = "XGBOOST"
@@ -3243,6 +3244,29 @@ def create_inference_recommendations_job_default_happy_response():
32433244
"JobDescription": "#python-sdk-create",
32443245
}
32453246

3247+
def create_inference_recommendations_job_default_model_name_happy_response():
3248+
return {
3249+
"JobName": IR_USER_JOB_NAME,
3250+
"JobType": "Default",
3251+
"RoleArn": IR_ROLE_ARN,
3252+
"InputConfig": {
3253+
"ContainerConfig": {
3254+
"Domain": "MACHINE_LEARNING",
3255+
"Task": "OTHER",
3256+
"Framework": IR_FRAMEWORK,
3257+
"PayloadConfig": {
3258+
"SamplePayloadUrl": IR_SAMPLE_PAYLOAD_URL,
3259+
"SupportedContentTypes": IR_SUPPORTED_CONTENT_TYPES,
3260+
},
3261+
"FrameworkVersion": IR_FRAMEWORK_VERSION,
3262+
"NearestModelName": IR_NEAREST_MODEL_NAME,
3263+
"SupportedInstanceTypes": IR_SUPPORTED_INSTANCE_TYPES,
3264+
},
3265+
"ModelName": IR_MODEL_NAME,
3266+
},
3267+
"JobDescription": "#python-sdk-create",
3268+
}
3269+
32463270

32473271
def create_inference_recommendations_job_advanced_happy_response():
32483272
base_advanced_job_response = create_inference_recommendations_job_default_happy_response()
@@ -3258,6 +3282,20 @@ def create_inference_recommendations_job_advanced_happy_response():
32583282
return base_advanced_job_response
32593283

32603284

3285+
def create_inference_recommendations_job_advanced_model_name_happy_response():
3286+
base_advanced_job_response = create_inference_recommendations_job_default_model_name_happy_response()
3287+
3288+
base_advanced_job_response["JobName"] = IR_JOB_NAME
3289+
base_advanced_job_response["JobType"] = IR_ADVANCED_JOB
3290+
base_advanced_job_response["StoppingConditions"] = IR_STOPPING_CONDITIONS
3291+
base_advanced_job_response["InputConfig"]["JobDurationInSeconds"] = IR_JOB_DURATION_IN_SECONDS
3292+
base_advanced_job_response["InputConfig"]["EndpointConfigurations"] = IR_ENDPOINT_CONFIGURATIONS
3293+
base_advanced_job_response["InputConfig"]["TrafficPattern"] = IR_TRAFFIC_PATTERN
3294+
base_advanced_job_response["InputConfig"]["ResourceLimit"] = IR_RESOURCE_LIMIT
3295+
3296+
return base_advanced_job_response
3297+
3298+
32613299
def test_create_inference_recommendations_job_default_happy(sagemaker_session):
32623300
job_name = sagemaker_session.create_inference_recommendations_job(
32633301
role=IR_ROLE_ARN,
@@ -3304,6 +3342,89 @@ def test_create_inference_recommendations_job_advanced_happy(sagemaker_session):
33043342
assert IR_JOB_NAME == job_name
33053343

33063344

3345+
def test_create_inference_recommendations_job_default_model_name_happy(sagemaker_session):
3346+
job_name = sagemaker_session.create_inference_recommendations_job(
3347+
role=IR_ROLE_ARN,
3348+
sample_payload_url=IR_SAMPLE_PAYLOAD_URL,
3349+
supported_content_types=IR_SUPPORTED_CONTENT_TYPES,
3350+
model_name = IR_MODEL_NAME,
3351+
model_package_version_arn=None,
3352+
framework=IR_FRAMEWORK,
3353+
framework_version=IR_FRAMEWORK_VERSION,
3354+
nearest_model_name=IR_NEAREST_MODEL_NAME,
3355+
supported_instance_types=IR_SUPPORTED_INSTANCE_TYPES,
3356+
job_name=IR_USER_JOB_NAME,
3357+
)
3358+
3359+
sagemaker_session.sagemaker_client.create_inference_recommendations_job.assert_called_with(
3360+
**create_inference_recommendations_job_default_model_name_happy_response()
3361+
)
3362+
3363+
assert IR_USER_JOB_NAME == job_name
3364+
3365+
@patch("uuid.uuid4", MagicMock(return_value="sample-unique-uuid"))
3366+
def test_create_inference_recommendations_job_advanced_model_name_happy(sagemaker_session):
3367+
job_name = sagemaker_session.create_inference_recommendations_job(
3368+
role=IR_ROLE_ARN,
3369+
sample_payload_url=IR_SAMPLE_PAYLOAD_URL,
3370+
supported_content_types=IR_SUPPORTED_CONTENT_TYPES,
3371+
model_name=IR_MODEL_NAME,
3372+
model_package_version_arn=None,
3373+
framework=IR_FRAMEWORK,
3374+
framework_version=IR_FRAMEWORK_VERSION,
3375+
nearest_model_name=IR_NEAREST_MODEL_NAME,
3376+
supported_instance_types=IR_SUPPORTED_INSTANCE_TYPES,
3377+
endpoint_configurations=IR_ENDPOINT_CONFIGURATIONS,
3378+
traffic_pattern=IR_TRAFFIC_PATTERN,
3379+
stopping_conditions=IR_STOPPING_CONDITIONS,
3380+
resource_limit=IR_RESOURCE_LIMIT,
3381+
job_type=IR_ADVANCED_JOB,
3382+
job_duration_in_seconds=IR_JOB_DURATION_IN_SECONDS,
3383+
)
3384+
3385+
sagemaker_session.sagemaker_client.create_inference_recommendations_job.assert_called_with(
3386+
**create_inference_recommendations_job_advanced_model_name_happy_response()
3387+
)
3388+
3389+
assert IR_JOB_NAME == job_name
3390+
3391+
def test_create_inference_recommendations_job_missing_model_name_and_pkg(sagemaker_session):
3392+
with pytest.raises(
3393+
ValueError,
3394+
match="Missing model_name and model_package_version_arn, please provide one of them."
3395+
):
3396+
sagemaker_session.create_inference_recommendations_job(
3397+
role=IR_ROLE_ARN,
3398+
sample_payload_url=IR_SAMPLE_PAYLOAD_URL,
3399+
supported_content_types=IR_SUPPORTED_CONTENT_TYPES,
3400+
model_name = None,
3401+
model_package_version_arn=None,
3402+
framework=IR_FRAMEWORK,
3403+
framework_version=IR_FRAMEWORK_VERSION,
3404+
nearest_model_name=IR_NEAREST_MODEL_NAME,
3405+
supported_instance_types=IR_SUPPORTED_INSTANCE_TYPES,
3406+
job_name=IR_USER_JOB_NAME,
3407+
)
3408+
3409+
def test_create_inference_recommendations_job_provided_model_name_and_pkg(sagemaker_session):
3410+
with pytest.raises(
3411+
ValueError,
3412+
match="Please provide either model_name or model_package_version_arn should be provided, not both."
3413+
):
3414+
sagemaker_session.create_inference_recommendations_job(
3415+
role=IR_ROLE_ARN,
3416+
sample_payload_url=IR_SAMPLE_PAYLOAD_URL,
3417+
supported_content_types=IR_SUPPORTED_CONTENT_TYPES,
3418+
model_name=IR_MODEL_NAME,
3419+
model_package_version_arn=IR_MODEL_PACKAGE_VERSION_ARN,
3420+
framework=IR_FRAMEWORK,
3421+
framework_version=IR_FRAMEWORK_VERSION,
3422+
nearest_model_name=IR_NEAREST_MODEL_NAME,
3423+
supported_instance_types=IR_SUPPORTED_INSTANCE_TYPES,
3424+
job_name=IR_USER_JOB_NAME,
3425+
)
3426+
3427+
33073428
def test_create_inference_recommendations_job_propogate_validation_exception(sagemaker_session):
33083429
validation_exception_message = (
33093430
"Failed to describe model due to validation failure with following error: test_error"

0 commit comments

Comments
 (0)