Skip to content

Commit 1ed2d18

Browse files
akrishna1995Ashwin Krishna
authored andcommitted
fix: adding Unit tests for resourcekey and tags for api in config for intelligent defaults
Added Unit tests to test the config injection for SAGEMAKER_CONFIG_ENDPOINT Fixed a couple of unit tests
1 parent 134ab56 commit 1ed2d18

File tree

2 files changed

+10
-14
lines changed

2 files changed

+10
-14
lines changed

src/sagemaker/session.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@
9696
ENDPOINT_CONFIG,
9797
ENDPOINT_CONFIG_DATA_CAPTURE_PATH,
9898
ENDPOINT_CONFIG_ASYNC_INFERENCE_PATH,
99+
ENDPOINT,
99100
ENDPOINT_TAGS_PATH,
100101
SAGEMAKER,
101102
FEATURE_GROUP,
@@ -4032,7 +4033,7 @@ def create_endpoint(self, endpoint_name, config_name, tags=None, wait=True):
40324033
tags = tags or []
40334034
tags = _append_project_tags(tags)
40344035
tags = self._append_sagemaker_config_tags(
4035-
tags, "{}.{}.{}".format(SAGEMAKER, ENDPOINT_TAGS_PATH, TAGS)
4036+
tags, "{}.{}.{}".format(SAGEMAKER, ENDPOINT, TAGS)
40364037
)
40374038

40384039
self.sagemaker_client.create_endpoint(
@@ -4532,9 +4533,6 @@ def endpoint_from_model_data(
45324533
model_vpc_config = vpc_utils.sanitize(model_vpc_config)
45334534
endpoint_config_tags = _append_project_tags(tags)
45344535
endpoint_tags = _append_project_tags(tags)
4535-
endpoint_tags = self._append_sagemaker_config_tags(
4536-
endpoint_tags, "{}.{}.{}".format(SAGEMAKER, ENDPOINT_TAGS_PATH, TAGS)
4537-
)
45384536
endpoint_config_tags = self._append_sagemaker_config_tags(
45394537
endpoint_config_tags, "{}.{}.{}".format(SAGEMAKER, ENDPOINT_CONFIG, TAGS)
45404538
)
@@ -4635,9 +4633,6 @@ def endpoint_from_production_variants(
46354633
endpoint_config_tags = _append_project_tags(tags)
46364634
endpoint_tags = _append_project_tags(tags)
46374635

4638-
endpoint_tags = self._append_sagemaker_config_tags(
4639-
endpoint_tags, "{}.{}.{}".format(SAGEMAKER, ENDPOINT_TAGS_PATH, TAGS)
4640-
)
46414636
endpoint_config_tags = self._append_sagemaker_config_tags(
46424637
endpoint_config_tags, "{}.{}.{}".format(SAGEMAKER, ENDPOINT_CONFIG, TAGS)
46434638
)

tests/unit/test_session.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2595,7 +2595,7 @@ def test_create_edge_packaging_with_sagemaker_config_injection(sagemaker_session
25952595
"OutputConfig"
25962596
]["KmsKeyId"]
25972597
expected_tags = SAGEMAKER_CONFIG_EDGE_PACKAGING_JOB["SageMaker"]["EdgePackagingJob"]["Tags"]
2598-
expected_resource_key = SAGEMAKER_CONFIG_EDGE_PACKAGING_JOB["SageMaker"]["EdgePackagingJob"]["ResourceKey"]
2598+
expected_resource_key = (SAGEMAKER_CONFIG_EDGE_PACKAGING_JOB["SageMaker"]["EdgePackagingJob"]["ResourceKey"],)
25992599
sagemaker_session.sagemaker_client.create_edge_packaging_job.assert_called_with(
26002600
RoleArn=expected_role_arn, # provided from config
26012601
OutputConfig={
@@ -3105,7 +3105,7 @@ def test_endpoint_from_production_variants_with_sagemaker_config_injection(
31053105
sagemaker_session.sagemaker_client.create_endpoint.assert_called_with(
31063106
EndpointConfigName="some-endpoint",
31073107
EndpointName="some-endpoint",
3108-
Tags=expected_tags, # from config
3108+
Tags=[],
31093109
)
31103110

31113111

@@ -3173,7 +3173,7 @@ def test_endpoint_from_production_variants_with_sagemaker_config_injection_parti
31733173
sagemaker_session.sagemaker_client.create_endpoint.assert_called_with(
31743174
EndpointConfigName="some-endpoint",
31753175
EndpointName="some-endpoint",
3176-
Tags=expected_tags, # from config
3176+
Tags=[],
31773177
)
31783178

31793179

@@ -3234,7 +3234,7 @@ def test_endpoint_from_production_variants_with_sagemaker_config_injection_no_km
32343234
sagemaker_session.sagemaker_client.create_endpoint.assert_called_with(
32353235
EndpointConfigName="some-endpoint",
32363236
EndpointName="some-endpoint",
3237-
Tags=expected_tags, # from config
3237+
Tags=[],
32383238
)
32393239

32403240

@@ -3315,7 +3315,7 @@ def test_endpoint_from_production_variants_with_sagemaker_config_injection_tags(
33153315
EndpointConfigName="some-endpoint", EndpointName="some-endpoint", Tags=expected_tags
33163316
)
33173317
sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with(
3318-
EndpointConfigName="some-endpoint", ProductionVariants=pvs, Tags=[]
3318+
EndpointConfigName="some-endpoint", ProductionVariants=pvs
33193319
)
33203320

33213321

@@ -3366,11 +3366,12 @@ def test_endpoint_from_production_variants_with_accelerator_type_sagemaker_confi
33663366
)
33673367
ims.sagemaker_client.describe_endpoint_config = Mock(side_effect=ex)
33683368
expected_tags = SAGEMAKER_CONFIG_ENDPOINT["SageMaker"]["Endpoint"]["Tags"]
3369+
sagemaker_session.endpoint_from_production_variants("some-endpoint", pvs)
33693370
sagemaker_session.sagemaker_client.create_endpoint.assert_called_with(
33703371
EndpointConfigName="some-endpoint", EndpointName="some-endpoint", Tags=expected_tags
33713372
)
33723373
sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with(
3373-
EndpointConfigName="some-endpoint", ProductionVariants=pvs, Tags=[]
3374+
EndpointConfigName="some-endpoint", ProductionVariants=pvs
33743375
)
33753376

33763377

@@ -3442,7 +3443,7 @@ def test_endpoint_from_production_variants_with_serverless_inference_config_sage
34423443
EndpointConfigName="some-endpoint", EndpointName="some-endpoint", Tags=expected_tags
34433444
)
34443445
sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with(
3445-
EndpointConfigName="some-endpoint", ProductionVariants=pvs, Tags=[]
3446+
EndpointConfigName="some-endpoint", ProductionVariants=pvs
34463447
)
34473448

34483449

0 commit comments

Comments
 (0)