Skip to content

Commit dacac74

Browse files
author
Ashwin Krishna
committed
fix: comments from PR
fixed comments on PR, added UTs, style changes
1 parent 72b00b6 commit dacac74

File tree

8 files changed

+151
-154
lines changed

8 files changed

+151
-154
lines changed

src/sagemaker/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -990,7 +990,7 @@ def package_for_edge(
990990
role, EDGE_PACKAGING_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session
991991
)
992992
resource_key = resolve_value_from_config(
993-
resource_key, EDGE_PACKAGING_RESOURCE_KEY_PATH, sagemaker_session=self
993+
resource_key, EDGE_PACKAGING_RESOURCE_KEY_PATH, sagemaker_session=self.sagemaker_session
994994
)
995995
if role is not None:
996996
role = self.sagemaker_session.expand_role(role)

src/sagemaker/session.py

Lines changed: 39 additions & 127 deletions
Large diffs are not rendered by default.

tests/data/config/config.yaml

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -177,26 +177,3 @@ SageMaker:
177177
KmsKeyId: 'kmskeyid1'
178178
RoleArn: 'arn:aws:iam::555555555555:role/IMRole'
179179
ResourceKey: 'kmskeyid1'
180-
PythonSDK:
181-
Modules:
182-
RemoteFunction:
183-
Dependencies: "./requirements.txt"
184-
EnvironmentVariables:
185-
"var1": "value1"
186-
"var2": "value2"
187-
ImageUri: "123456789012.dkr.ecr.us-west-2.amazonaws.com/myimage:latest"
188-
IncludeLocalWorkDir: true
189-
InstanceType: "ml.m5.xlarge"
190-
JobCondaEnvironment: "some_conda_env"
191-
RoleArn: "arn:aws:iam::555555555555:role/IMRole"
192-
S3KmsKeyId: "kmskeyid1"
193-
S3RootUri: "s3://my-bucket/key"
194-
Tags:
195-
- Key: "tag1"
196-
Value: "tagValue1"
197-
VolumeKmsKeyId: "kmskeyid2"
198-
VpcConfig:
199-
SecurityGroupIds:
200-
- 'sg123'
201-
Subnets:
202-
- 'subnet-1234'

tests/unit/__init__.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,31 @@
188188
},
189189
}
190190

191+
SAGEMAKER_CONFIG_ENDPOINT_ENDPOINT_CONFIG_COMBINED = {
192+
SCHEMA_VERSION: "1.0",
193+
SAGEMAKER: {
194+
ENDPOINT_CONFIG: {
195+
ASYNC_INFERENCE_CONFIG: {
196+
OUTPUT_CONFIG: {
197+
KMS_KEY_ID: "testOutputKmsKeyId",
198+
}
199+
},
200+
DATA_CAPTURE_CONFIG: {
201+
KMS_KEY_ID: "testDataCaptureKmsKeyId",
202+
},
203+
KMS_KEY_ID: "ConfigKmsKeyId",
204+
PRODUCTION_VARIANTS: [
205+
{"CoreDumpConfig": {"KmsKeyId": "testCoreKmsKeyId"}},
206+
{"CoreDumpConfig": {"KmsKeyId": "testCoreKmsKeyId2"}},
207+
],
208+
TAGS: [{KEY: "some-tag", VALUE: "value-for-tag"}],
209+
},
210+
ENDPOINT: {
211+
TAGS: [{KEY: "some-tag1", VALUE: "value-for-tag1"}],
212+
},
213+
},
214+
}
215+
191216
SAGEMAKER_CONFIG_AUTO_ML = {
192217
SCHEMA_VERSION: "1.0",
193218
SAGEMAKER: {

tests/unit/sagemaker/config/conftest.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def valid_edge_packaging_config(valid_iam_role_arn):
7676
return {
7777
"OutputConfig": {"KmsKeyId": "kmskeyid1"},
7878
"RoleArn": valid_iam_role_arn,
79+
"ResourceKey": "kmskeyid1",
7980
}
8081

8182

@@ -191,6 +192,11 @@ def valid_endpointconfig_config():
191192
}
192193

193194

195+
@pytest.fixture()
196+
def valid_endpoint_config(valid_tags):
197+
return {"Tags": valid_tags}
198+
199+
194200
@pytest.fixture()
195201
def valid_monitoring_schedule_config(
196202
valid_iam_role_arn, valid_vpc_config, valid_environment_config
@@ -232,6 +238,7 @@ def valid_config_with_all_the_scopes(
232238
valid_session_config,
233239
valid_feature_group_config,
234240
valid_monitoring_schedule_config,
241+
valid_endpoint_config,
235242
valid_endpointconfig_config,
236243
valid_automl_config,
237244
valid_transform_job_config,
@@ -253,6 +260,7 @@ def valid_config_with_all_the_scopes(
253260
},
254261
"FeatureGroup": valid_feature_group_config,
255262
"MonitoringSchedule": valid_monitoring_schedule_config,
263+
"Endpoint": valid_endpoint_config,
256264
"EndpointConfig": valid_endpointconfig_config,
257265
"AutoMLJob": valid_automl_config,
258266
"TransformJob": valid_transform_job_config,

tests/unit/sagemaker/model/test_model.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,12 @@
3232
_test_default_bucket_and_prefix_combinations,
3333
DEFAULT_S3_BUCKET_NAME,
3434
DEFAULT_S3_OBJECT_KEY_PREFIX_NAME,
35+
SAGEMAKER_CONFIG_EDGE_PACKAGING_JOB,
3536
)
3637

3738
MODEL_DATA = "s3://bucket/model.tar.gz"
3839
MODEL_IMAGE = "mi"
40+
MODEL_VERSION = "1.0"
3941
TIMESTAMP = "2017-10-10-14-14-15"
4042
MODEL_NAME = "{}-{}".format(MODEL_IMAGE, TIMESTAMP)
4143

@@ -921,3 +923,27 @@ def without_user_input(sess):
921923
),
922924
)
923925
assert actual == expected
926+
927+
928+
def test_package_for_edge_with_sagemaker_config_injection(sagemaker_session):
929+
sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_EDGE_PACKAGING_JOB
930+
sagemaker_session.wait_for_edge_packaging_job.return_value = {"ModelArtifact": "TestArtifact"}
931+
sagemaker_session.expand_role.return_value = SAGEMAKER_CONFIG_EDGE_PACKAGING_JOB["SageMaker"][
932+
"EdgePackagingJob"
933+
]["RoleArn"]
934+
model = Model(MODEL_DATA, MODEL_IMAGE, name=MODEL_NAME, sagemaker_session=sagemaker_session)
935+
model._compilation_job_name = "compiledModel"
936+
model.package_for_edge(output_path="", model_name=MODEL_NAME, model_version=MODEL_VERSION)
937+
sagemaker_session.expand_role.assert_called_with(
938+
SAGEMAKER_CONFIG_EDGE_PACKAGING_JOB["SageMaker"]["EdgePackagingJob"]["RoleArn"]
939+
)
940+
sagemaker_session.package_model_for_edge.assert_called_with(
941+
compilation_job_name="compiledModel",
942+
job_name="packagingel",
943+
model_name=MODEL_NAME,
944+
model_version=MODEL_VERSION,
945+
output_model_config={"S3OutputLocation": "", "KmsKeyId": "configKmsKeyId"},
946+
resource_key="kmskeyid1",
947+
role=SAGEMAKER_CONFIG_EDGE_PACKAGING_JOB["SageMaker"]["EdgePackagingJob"]["RoleArn"],
948+
tags=None,
949+
)

tests/unit/test_endpoint_from_model_data.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,10 @@ def test_all_defaults_no_existing_entities(name_from_image_mock, sagemaker_sessi
7070
instance_type=INSTANCE_TYPE,
7171
accelerator_type=None,
7272
data_capture_config_dict=None,
73+
tags=None,
7374
)
7475
sagemaker_session.create_endpoint.assert_called_once_with(
75-
endpoint_name=NAME_FROM_IMAGE, config_name=NAME_FROM_IMAGE, wait=False
76+
endpoint_name=NAME_FROM_IMAGE, config_name=NAME_FROM_IMAGE, wait=False, tags=None
7677
)
7778
assert returned_name == NAME_FROM_IMAGE
7879

@@ -107,9 +108,10 @@ def test_no_defaults_no_existing_entities(name_from_image_mock, sagemaker_sessio
107108
instance_type=INSTANCE_TYPE,
108109
accelerator_type=ACCELERATOR_TYPE,
109110
data_capture_config_dict=None,
111+
tags=None,
110112
)
111113
sagemaker_session.create_endpoint.assert_called_once_with(
112-
endpoint_name=ENDPOINT_NAME, config_name=ENDPOINT_NAME, wait=False
114+
endpoint_name=ENDPOINT_NAME, config_name=ENDPOINT_NAME, wait=False, tags=None
113115
)
114116
assert returned_name == ENDPOINT_NAME
115117

@@ -146,9 +148,10 @@ def test_model_and_endpoint_config_exist(name_from_image_mock, sagemaker_session
146148
instance_type=INSTANCE_TYPE,
147149
accelerator_type=None,
148150
data_capture_config_dict=None,
151+
tags=None,
149152
)
150153
sagemaker_session.create_endpoint.assert_called_once_with(
151-
endpoint_name=NAME_FROM_IMAGE, config_name=NAME_FROM_IMAGE, wait=False
154+
endpoint_name=NAME_FROM_IMAGE, config_name=NAME_FROM_IMAGE, wait=False, tags=None
152155
)
153156

154157

@@ -182,6 +185,7 @@ def test_model_and_endpoint_config_raises_unexpected_error(name_from_image_mock,
182185
instance_type=INSTANCE_TYPE,
183186
accelerator_type=None,
184187
data_capture_config_dict=None,
188+
tags=None,
185189
)
186190
sagemaker_session.create_endpoint.assert_not_called()
187191

tests/unit/test_session.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
SAGEMAKER_CONFIG_COMPILATION_JOB,
4646
SAGEMAKER_CONFIG_EDGE_PACKAGING_JOB,
4747
SAGEMAKER_CONFIG_ENDPOINT_CONFIG,
48+
SAGEMAKER_CONFIG_ENDPOINT_ENDPOINT_CONFIG_COMBINED,
4849
SAGEMAKER_CONFIG_ENDPOINT,
4950
SAGEMAKER_CONFIG_AUTO_ML,
5051
SAGEMAKER_CONFIG_MODEL_PACKAGE,
@@ -3298,6 +3299,50 @@ def test_endpoint_from_production_variants_with_tags(sagemaker_session):
32983299
)
32993300

33003301

3302+
def test_endpoint_from_production_variants_with_combined_sagemaker_config_injection_tags(
3303+
sagemaker_session,
3304+
):
3305+
sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_ENDPOINT_ENDPOINT_CONFIG_COMBINED
3306+
3307+
ims = sagemaker_session
3308+
ims.sagemaker_client.describe_endpoint = Mock(return_value={"EndpointStatus": "InService"})
3309+
pvs = [
3310+
sagemaker.production_variant("A", "ml.p2.xlarge"),
3311+
sagemaker.production_variant("B", "p299.4096xlarge"),
3312+
]
3313+
ex = ClientError(
3314+
{
3315+
"Error": {
3316+
"Code": "ValidationException",
3317+
"Message": "Could not find your thing",
3318+
}
3319+
},
3320+
"b",
3321+
)
3322+
ims.sagemaker_client.describe_endpoint_config = Mock(side_effect=ex)
3323+
expected_endpoint_tags = SAGEMAKER_CONFIG_ENDPOINT_ENDPOINT_CONFIG_COMBINED["SageMaker"][
3324+
"Endpoint"
3325+
]["Tags"]
3326+
expected_endpoint_config_tags = SAGEMAKER_CONFIG_ENDPOINT_ENDPOINT_CONFIG_COMBINED["SageMaker"][
3327+
"EndpointConfig"
3328+
]["Tags"]
3329+
expected_endpoint_config_kms_key_id = SAGEMAKER_CONFIG_ENDPOINT_ENDPOINT_CONFIG_COMBINED[
3330+
"SageMaker"
3331+
]["EndpointConfig"]["KmsKeyId"]
3332+
sagemaker_session.endpoint_from_production_variants("some-endpoint", pvs)
3333+
sagemaker_session.sagemaker_client.create_endpoint.assert_called_with(
3334+
EndpointConfigName="some-endpoint",
3335+
EndpointName="some-endpoint",
3336+
Tags=expected_endpoint_tags,
3337+
)
3338+
sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with(
3339+
EndpointConfigName="some-endpoint",
3340+
ProductionVariants=pvs,
3341+
Tags=expected_endpoint_config_tags,
3342+
KmsKeyId=expected_endpoint_config_kms_key_id,
3343+
)
3344+
3345+
33013346
def test_endpoint_from_production_variants_with_sagemaker_config_injection_tags(
33023347
sagemaker_session,
33033348
):

0 commit comments

Comments
 (0)