Skip to content

Commit 134ab56

Browse files
akrishna1995Ashwin Krishna
authored andcommitted
feature: adding resourcekey and tags for api in config for intelligent defaults
Adding ResourceKey to the Config schema for the createEdgePackaging Job API Adding a new entry into the Config Schema for CreateEndpoint API with the addition of tags attribute
1 parent a5719f8 commit 134ab56

File tree

7 files changed

+335
-120
lines changed

7 files changed

+335
-120
lines changed

src/sagemaker/config/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
COMPILATION_JOB_VPC_CONFIG_PATH,
4949
COMPILATION_JOB,
5050
EDGE_PACKAGING_ROLE_ARN_PATH,
51+
EDGE_PACKAGING_RESOURCE_KEY_PATH,
5152
EDGE_PACKAGING_OUTPUT_CONFIG_PATH,
5253
EDGE_PACKAGING_JOB,
5354
TRANSFORM_JOB,
@@ -69,10 +70,13 @@
6970
MODEL_PRIMARY_CONTAINER_ENVIRONMENT_PATH,
7071
ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH,
7172
KMS_KEY_ID,
73+
RESOURCE_KEY,
7274
ENDPOINT_CONFIG_KMS_KEY_ID_PATH,
7375
ENDPOINT_CONFIG,
7476
ENDPOINT_CONFIG_DATA_CAPTURE_PATH,
7577
ENDPOINT_CONFIG_ASYNC_INFERENCE_PATH,
78+
ENDPOINT,
79+
ENDPOINT_TAGS_PATH,
7680
SAGEMAKER,
7781
FEATURE_GROUP,
7882
TAGS,

src/sagemaker/config/config_schema.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
ENABLE_NETWORK_ISOLATION = "EnableNetworkIsolation"
1919
VOLUME_KMS_KEY_ID = "VolumeKmsKeyId"
2020
KMS_KEY_ID = "KmsKeyId"
21+
RESOURCE_KEY = "ResourceKey"
2122
ROLE_ARN = "RoleArn"
2223
TAGS = "Tags"
2324
KEY = "Key"
@@ -78,6 +79,7 @@
7879
MODEL = "Model"
7980
MONITORING_SCHEDULE = "MonitoringSchedule"
8081
ENDPOINT_CONFIG = "EndpointConfig"
82+
ENDPOINT = "Endpoint"
8183
AUTO_ML_JOB = "AutoMLJob"
8284
COMPILATION_JOB = "CompilationJob"
8385
CUSTOM_PARAMETERS = "CustomParameters"
@@ -131,6 +133,7 @@ def _simple_path(*args: str):
131133
)
132134
EDGE_PACKAGING_OUTPUT_CONFIG_PATH = _simple_path(SAGEMAKER, EDGE_PACKAGING_JOB, OUTPUT_CONFIG)
133135
EDGE_PACKAGING_ROLE_ARN_PATH = _simple_path(SAGEMAKER, EDGE_PACKAGING_JOB, ROLE_ARN)
136+
EDGE_PACKAGING_RESOURCE_KEY_PATH = _simple_path(SAGEMAKER, EDGE_PACKAGING_JOB, RESOURCE_KEY)
134137
ENDPOINT_CONFIG_DATA_CAPTURE_KMS_KEY_ID_PATH = _simple_path(
135138
SAGEMAKER, ENDPOINT_CONFIG, DATA_CAPTURE_CONFIG, KMS_KEY_ID
136139
)
@@ -145,6 +148,7 @@ def _simple_path(*args: str):
145148
SAGEMAKER, ENDPOINT_CONFIG, ASYNC_INFERENCE_CONFIG, OUTPUT_CONFIG, KMS_KEY_ID
146149
)
147150
ENDPOINT_CONFIG_KMS_KEY_ID_PATH = _simple_path(SAGEMAKER, ENDPOINT_CONFIG, KMS_KEY_ID)
151+
ENDPOINT_TAGS_PATH = _simple_path(SAGEMAKER, ENDPOINT, TAGS)
148152
FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH = _simple_path(SAGEMAKER, FEATURE_GROUP, ONLINE_STORE_CONFIG)
149153
FEATURE_GROUP_OFFLINE_STORE_CONFIG_PATH = _simple_path(
150154
SAGEMAKER, FEATURE_GROUP, OFFLINE_STORE_CONFIG
@@ -746,6 +750,15 @@ def _simple_path(*args: str):
746750
TAGS: {"$ref": "#/definitions/tags"},
747751
},
748752
},
753+
# Endpoint
754+
# https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateEndpoint.html
755+
ENDPOINT: {
756+
TYPE: OBJECT,
757+
ADDITIONAL_PROPERTIES: False,
758+
PROPERTIES: {
759+
TAGS: {"$ref": "#/definitions/tags"}
760+
}
761+
},
749762
# Endpoint Config
750763
# https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateEndpointConfig.html
751764
# Note: there is a separate API for creating Endpoints.
@@ -992,6 +1005,7 @@ def _simple_path(*args: str):
9921005
ADDITIONAL_PROPERTIES: False,
9931006
PROPERTIES: {KMS_KEY_ID: {"$ref": "#/definitions/kmsKeyId"}},
9941007
},
1008+
RESOURCE_KEY: {"$ref": "#/definitions/kmsKeyId"},
9951009
ROLE_ARN: {"$ref": "#/definitions/roleArn"},
9961010
TAGS: {"$ref": "#/definitions/tags"},
9971011
},

src/sagemaker/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
EDGE_PACKAGING_KMS_KEY_ID_PATH,
3535
EDGE_PACKAGING_ROLE_ARN_PATH,
3636
MODEL_CONTAINERS_PATH,
37+
EDGE_PACKAGING_RESOURCE_KEY_PATH,
3738
MODEL_VPC_CONFIG_PATH,
3839
MODEL_ENABLE_NETWORK_ISOLATION_PATH,
3940
MODEL_EXECUTION_ROLE_ARN_PATH,
@@ -959,6 +960,7 @@ def package_for_edge(
959960
role = resolve_value_from_config(
960961
role, EDGE_PACKAGING_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session
961962
)
963+
resource_key = resolve_value_from_config(resource_key, EDGE_PACKAGING_RESOURCE_KEY_PATH, sagemaker_session=self)
962964
if role is not None:
963965
role = self.sagemaker_session.expand_role(role)
964966
config = self._edge_packaging_job_config(

src/sagemaker/session.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
COMPILATION_JOB,
7171
EDGE_PACKAGING_ROLE_ARN_PATH,
7272
EDGE_PACKAGING_OUTPUT_CONFIG_PATH,
73+
EDGE_PACKAGING_RESOURCE_KEY_PATH,
7374
EDGE_PACKAGING_JOB,
7475
TRANSFORM_JOB,
7576
TRANSFORM_JOB_ENVIRONMENT_PATH,
@@ -95,6 +96,7 @@
9596
ENDPOINT_CONFIG,
9697
ENDPOINT_CONFIG_DATA_CAPTURE_PATH,
9798
ENDPOINT_CONFIG_ASYNC_INFERENCE_PATH,
99+
ENDPOINT_TAGS_PATH,
98100
SAGEMAKER,
99101
FEATURE_GROUP,
100102
TAGS,
@@ -2503,6 +2505,7 @@ def package_model_for_edge(
25032505
"EdgePackagingJobName": job_name,
25042506
"CompilationJobName": compilation_job_name,
25052507
}
2508+
resource_key = resolve_value_from_config(resource_key, EDGE_PACKAGING_RESOURCE_KEY_PATH, sagemaker_session=self)
25062509
tags = _append_project_tags(tags)
25072510
tags = self._append_sagemaker_config_tags(
25082511
tags, "{}.{}.{}".format(SAGEMAKER, EDGE_PACKAGING_JOB, TAGS)
@@ -4018,6 +4021,8 @@ def create_endpoint(self, endpoint_name, config_name, tags=None, wait=True):
40184021
config_name (str): Name of the Amazon SageMaker endpoint configuration to deploy.
40194022
wait (bool): Whether to wait for the endpoint deployment to complete before returning
40204023
(default: True).
4024+
tags (list[dict[str, str]]): A list of key-value pairs for tagging the endpoint
4025+
(default: None).
40214026
40224027
Returns:
40234028
str: Name of the Amazon SageMaker ``Endpoint`` created.
@@ -4026,6 +4031,9 @@ def create_endpoint(self, endpoint_name, config_name, tags=None, wait=True):
40264031

40274032
tags = tags or []
40284033
tags = _append_project_tags(tags)
4034+
tags = self._append_sagemaker_config_tags(
4035+
tags, "{}.{}.{}".format(SAGEMAKER, ENDPOINT_TAGS_PATH, TAGS)
4036+
)
40294037

40304038
self.sagemaker_client.create_endpoint(
40314039
EndpointName=endpoint_name, EndpointConfigName=config_name, Tags=tags
@@ -4481,6 +4489,7 @@ def endpoint_from_model_data(
44814489
model_vpc_config=None,
44824490
accelerator_type=None,
44834491
data_capture_config=None,
4492+
tags=None,
44844493
):
44854494
"""Create and deploy to an ``Endpoint`` using existing model data stored in S3.
44864495
@@ -4512,14 +4521,23 @@ def endpoint_from_model_data(
45124521
data_capture_config (sagemaker.model_monitor.DataCaptureConfig): Specifies
45134522
configuration related to Endpoint data capture for use with
45144523
Amazon SageMaker Model Monitoring. Default: None.
4524+
tags (list[dict[str, str]]): A list of key-value pairs for tagging the endpoint
4525+
(default: None).
45154526
45164527
Returns:
45174528
str: Name of the ``Endpoint`` that is created.
45184529
"""
45194530
model_environment_vars = model_environment_vars or {}
45204531
name = name or name_from_image(image_uri)
45214532
model_vpc_config = vpc_utils.sanitize(model_vpc_config)
4522-
4533+
endpoint_config_tags = _append_project_tags(tags)
4534+
endpoint_tags = _append_project_tags(tags)
4535+
endpoint_tags = self._append_sagemaker_config_tags(
4536+
endpoint_tags, "{}.{}.{}".format(SAGEMAKER, ENDPOINT_TAGS_PATH, TAGS)
4537+
)
4538+
endpoint_config_tags = self._append_sagemaker_config_tags(
4539+
endpoint_config_tags, "{}.{}.{}".format(SAGEMAKER, ENDPOINT_CONFIG, TAGS)
4540+
)
45234541
primary_container = container_def(
45244542
image_uri=image_uri,
45254543
model_data_url=model_s3_location,
@@ -4542,12 +4560,13 @@ def endpoint_from_model_data(
45424560
instance_type=instance_type,
45434561
accelerator_type=accelerator_type,
45444562
data_capture_config_dict=data_capture_config_dict,
4563+
tags=endpoint_config_tags
45454564
)
45464565
)
45474566

45484567
# to make change backwards compatible
45494568
response = _create_resource(
4550-
lambda: self.create_endpoint(endpoint_name=name, config_name=name, wait=wait)
4569+
lambda: self.create_endpoint(endpoint_name=name, config_name=name, Tags=endpoint_tags, wait=wait)
45514570
)
45524571
if not response:
45534572
raise ValueError(
@@ -4612,12 +4631,18 @@ def endpoint_from_production_variants(
46124631
if supports_kms
46134632
else kms_key
46144633
)
4615-
tags = _append_project_tags(tags)
4616-
tags = self._append_sagemaker_config_tags(
4617-
tags, "{}.{}.{}".format(SAGEMAKER, ENDPOINT_CONFIG, TAGS)
4634+
4635+
endpoint_config_tags = _append_project_tags(tags)
4636+
endpoint_tags = _append_project_tags(tags)
4637+
4638+
endpoint_tags = self._append_sagemaker_config_tags(
4639+
endpoint_tags, "{}.{}.{}".format(SAGEMAKER, ENDPOINT_TAGS_PATH, TAGS)
46184640
)
4619-
if tags:
4620-
config_options["Tags"] = tags
4641+
endpoint_config_tags = self._append_sagemaker_config_tags(
4642+
endpoint_config_tags, "{}.{}.{}".format(SAGEMAKER, ENDPOINT_CONFIG, TAGS)
4643+
)
4644+
if endpoint_config_tags:
4645+
config_options["Tags"] = endpoint_config_tags
46214646
if kms_key:
46224647
config_options["KmsKeyId"] = kms_key
46234648
if data_capture_config_dict is not None:
@@ -4638,7 +4663,7 @@ def endpoint_from_production_variants(
46384663
LOGGER.info("Creating endpoint-config with name %s", name)
46394664
self.sagemaker_client.create_endpoint_config(**config_options)
46404665

4641-
return self.create_endpoint(endpoint_name=name, config_name=name, tags=tags, wait=wait)
4666+
return self.create_endpoint(endpoint_name=name, config_name=name, tags=endpoint_tags, wait=wait)
46424667

46434668
def expand_role(self, role):
46444669
"""Expand an IAM role name into an ARN.

tests/data/config/config.yaml

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,10 @@ SageMaker:
5353
Subnets:
5454
- 'subnet-1234'
5555
RoleArn: 'arn:aws:iam::555555555555:role/IMRole'
56+
Endpoint:
57+
Tags:
58+
- Key: "tag1"
59+
Value: "tagValue1"
5660
EndpointConfig:
5761
AsyncInferenceConfig:
5862
OutputConfig:
@@ -171,4 +175,28 @@ SageMaker:
171175
EdgePackagingJob:
172176
OutputConfig:
173177
KmsKeyId: 'kmskeyid1'
174-
RoleArn: 'arn:aws:iam::555555555555:role/IMRole'
178+
RoleArn: 'arn:aws:iam::555555555555:role/IMRole'
179+
ResourceKey: 'kmskeyid1'
180+
PythonSDK:
181+
Modules:
182+
RemoteFunction:
183+
Dependencies: "./requirements.txt"
184+
EnvironmentVariables:
185+
"var1": "value1"
186+
"var2": "value2"
187+
ImageUri: "123456789012.dkr.ecr.us-west-2.amazonaws.com/myimage:latest"
188+
IncludeLocalWorkDir: true
189+
InstanceType: "ml.m5.xlarge"
190+
JobCondaEnvironment: "some_conda_env"
191+
RoleArn: "arn:aws:iam::555555555555:role/IMRole"
192+
S3KmsKeyId: "kmskeyid1"
193+
S3RootUri: "s3://my-bucket/key"
194+
Tags:
195+
- Key: "tag1"
196+
Value: "tagValue1"
197+
VolumeKmsKeyId: "kmskeyid2"
198+
VpcConfig:
199+
SecurityGroupIds:
200+
- 'sg123'
201+
Subnets:
202+
- 'subnet-1234'

tests/unit/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
COMPILATION_JOB,
4040
OUTPUT_CONFIG,
4141
EDGE_PACKAGING_JOB,
42+
RESOURCE_KEY,
43+
ENDPOINT,
4244
ENDPOINT_CONFIG,
4345
DATA_CAPTURE_CONFIG,
4446
PRODUCTION_VARIANTS,
@@ -145,6 +147,7 @@
145147
OUTPUT_CONFIG: {
146148
KMS_KEY_ID: "configKmsKeyId",
147149
},
150+
RESOURCE_KEY: "kmskeyid1",
148151
ROLE_ARN: "arn:aws:iam::111111111111:role/ConfigRole",
149152
TAGS: [{KEY: "some-tag", VALUE: "value-for-tag"}],
150153
},
@@ -173,6 +176,15 @@
173176
},
174177
}
175178

179+
SAGEMAKER_CONFIG_ENDPOINT = {
180+
SCHEMA_VERSION: "1.0",
181+
SAGEMAKER: {
182+
ENDPOINT: {
183+
TAGS: [{KEY: "some-tag", VALUE: "value-for-tag"}],
184+
}
185+
}
186+
}
187+
176188
SAGEMAKER_CONFIG_AUTO_ML = {
177189
SCHEMA_VERSION: "1.0",
178190
SAGEMAKER: {

0 commit comments

Comments
 (0)