Skip to content

Commit 1f38b07

Browse files
akrishna1995Ashwin Krishna
andauthored
feature: adding resourcekey and tags for api in config for SDK defaults (#3915)
* feature: adding resourcekey and tags for api in config for intelligent defaults Adding ResourceKey to the Config schema for the createEdgePackaging Job API Adding a new entry into the Config Schema for CreateEndpoint API with the addition of tags attribute * fix: adding Unit tests for resourcekey and tags for api in config for intelligent defaults Added Unit tests to test the config injection for SAGEMAKER_CONFIG_ENDPOINT Fixed a couple of unit tests * fix: SDK defaults formatting changes Making changes wrt formatting to fix tox issues * fix: comments from PR fixed comments on PR, added UTs, style changes --------- Co-authored-by: Ashwin Krishna <[email protected]>
1 parent 123f86a commit 1f38b07

File tree

10 files changed

+495
-54
lines changed

10 files changed

+495
-54
lines changed

src/sagemaker/config/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
COMPILATION_JOB_VPC_CONFIG_PATH,
4949
COMPILATION_JOB,
5050
EDGE_PACKAGING_ROLE_ARN_PATH,
51+
EDGE_PACKAGING_RESOURCE_KEY_PATH,
5152
EDGE_PACKAGING_OUTPUT_CONFIG_PATH,
5253
EDGE_PACKAGING_JOB,
5354
TRANSFORM_JOB,
@@ -69,10 +70,13 @@
6970
MODEL_PRIMARY_CONTAINER_ENVIRONMENT_PATH,
7071
ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH,
7172
KMS_KEY_ID,
73+
RESOURCE_KEY,
7274
ENDPOINT_CONFIG_KMS_KEY_ID_PATH,
7375
ENDPOINT_CONFIG,
7476
ENDPOINT_CONFIG_DATA_CAPTURE_PATH,
7577
ENDPOINT_CONFIG_ASYNC_INFERENCE_PATH,
78+
ENDPOINT,
79+
ENDPOINT_TAGS_PATH,
7680
SAGEMAKER,
7781
FEATURE_GROUP,
7882
TAGS,

src/sagemaker/config/config_schema.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
ENABLE_NETWORK_ISOLATION = "EnableNetworkIsolation"
1919
VOLUME_KMS_KEY_ID = "VolumeKmsKeyId"
2020
KMS_KEY_ID = "KmsKeyId"
21+
RESOURCE_KEY = "ResourceKey"
2122
ROLE_ARN = "RoleArn"
2223
TAGS = "Tags"
2324
KEY = "Key"
@@ -78,6 +79,7 @@
7879
MODEL = "Model"
7980
MONITORING_SCHEDULE = "MonitoringSchedule"
8081
ENDPOINT_CONFIG = "EndpointConfig"
82+
ENDPOINT = "Endpoint"
8183
AUTO_ML_JOB = "AutoMLJob"
8284
COMPILATION_JOB = "CompilationJob"
8385
CUSTOM_PARAMETERS = "CustomParameters"
@@ -131,6 +133,7 @@ def _simple_path(*args: str):
131133
)
132134
EDGE_PACKAGING_OUTPUT_CONFIG_PATH = _simple_path(SAGEMAKER, EDGE_PACKAGING_JOB, OUTPUT_CONFIG)
133135
EDGE_PACKAGING_ROLE_ARN_PATH = _simple_path(SAGEMAKER, EDGE_PACKAGING_JOB, ROLE_ARN)
136+
EDGE_PACKAGING_RESOURCE_KEY_PATH = _simple_path(SAGEMAKER, EDGE_PACKAGING_JOB, RESOURCE_KEY)
134137
ENDPOINT_CONFIG_DATA_CAPTURE_KMS_KEY_ID_PATH = _simple_path(
135138
SAGEMAKER, ENDPOINT_CONFIG, DATA_CAPTURE_CONFIG, KMS_KEY_ID
136139
)
@@ -145,6 +148,7 @@ def _simple_path(*args: str):
145148
SAGEMAKER, ENDPOINT_CONFIG, ASYNC_INFERENCE_CONFIG, OUTPUT_CONFIG, KMS_KEY_ID
146149
)
147150
ENDPOINT_CONFIG_KMS_KEY_ID_PATH = _simple_path(SAGEMAKER, ENDPOINT_CONFIG, KMS_KEY_ID)
151+
ENDPOINT_TAGS_PATH = _simple_path(SAGEMAKER, ENDPOINT, TAGS)
148152
FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH = _simple_path(SAGEMAKER, FEATURE_GROUP, ONLINE_STORE_CONFIG)
149153
FEATURE_GROUP_OFFLINE_STORE_CONFIG_PATH = _simple_path(
150154
SAGEMAKER, FEATURE_GROUP, OFFLINE_STORE_CONFIG
@@ -167,14 +171,20 @@ def _simple_path(*args: str):
167171
)
168172
AUTO_ML_JOB_CONFIG_PATH = _simple_path(SAGEMAKER, AUTO_ML_JOB, AUTO_ML_JOB_CONFIG)
169173
MONITORING_JOB_DEFINITION_PREFIX = _simple_path(
170-
SAGEMAKER, MONITORING_SCHEDULE, MONITORING_SCHEDULE_CONFIG, MONITORING_JOB_DEFINITION
174+
SAGEMAKER,
175+
MONITORING_SCHEDULE,
176+
MONITORING_SCHEDULE_CONFIG,
177+
MONITORING_JOB_DEFINITION,
171178
)
172179
MONITORING_JOB_ENVIRONMENT_PATH = _simple_path(MONITORING_JOB_DEFINITION_PREFIX, ENVIRONMENT)
173180
MONITORING_JOB_OUTPUT_KMS_KEY_ID_PATH = _simple_path(
174181
MONITORING_JOB_DEFINITION_PREFIX, MONITORING_OUTPUT_CONFIG, KMS_KEY_ID
175182
)
176183
MONITORING_JOB_VOLUME_KMS_KEY_ID_PATH = _simple_path(
177-
MONITORING_JOB_DEFINITION_PREFIX, MONITORING_RESOURCES, CLUSTER_CONFIG, VOLUME_KMS_KEY_ID
184+
MONITORING_JOB_DEFINITION_PREFIX,
185+
MONITORING_RESOURCES,
186+
CLUSTER_CONFIG,
187+
VOLUME_KMS_KEY_ID,
178188
)
179189
MONITORING_JOB_NETWORK_CONFIG_PATH = _simple_path(MONITORING_JOB_DEFINITION_PREFIX, NETWORK_CONFIG)
180190
MONITORING_JOB_ENABLE_NETWORK_ISOLATION_PATH = _simple_path(
@@ -288,7 +298,11 @@ def _simple_path(*args: str):
288298
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, VPC_CONFIG, SECURITY_GROUP_IDS
289299
)
290300
REMOTE_FUNCTION_ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION = _simple_path(
291-
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION
301+
SAGEMAKER,
302+
PYTHON_SDK,
303+
MODULES,
304+
REMOTE_FUNCTION,
305+
ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION,
292306
)
293307
MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION_PATH = _simple_path(
294308
SAGEMAKER,
@@ -468,7 +482,11 @@ def _simple_path(*args: str):
468482
"maxProperties": 48,
469483
},
470484
# Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_S3DataSource.html#sagemaker-Type-S3DataSource-S3Uri
471-
"s3Uri": {TYPE: "string", "pattern": "^(https|s3)://([^/]+)/?(.*)$", "maxLength": 1024},
485+
"s3Uri": {
486+
TYPE: "string",
487+
"pattern": "^(https|s3)://([^/]+)/?(.*)$",
488+
"maxLength": 1024,
489+
},
472490
# Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_AlgorithmSpecification.html#sagemaker-Type-AlgorithmSpecification-ContainerEntrypoint
473491
"preExecutionCommand": {TYPE: "string", "pattern": r".*"},
474492
# Regex based on https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_PipelineDefinitionS3Location.html
@@ -746,6 +764,13 @@ def _simple_path(*args: str):
746764
TAGS: {"$ref": "#/definitions/tags"},
747765
},
748766
},
767+
# Endpoint
768+
# https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateEndpoint.html
769+
ENDPOINT: {
770+
TYPE: OBJECT,
771+
ADDITIONAL_PROPERTIES: False,
772+
PROPERTIES: {TAGS: {"$ref": "#/definitions/tags"}},
773+
},
749774
# Endpoint Config
750775
# https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateEndpointConfig.html
751776
# Note: there is a separate API for creating Endpoints.
@@ -992,6 +1017,7 @@ def _simple_path(*args: str):
9921017
ADDITIONAL_PROPERTIES: False,
9931018
PROPERTIES: {KMS_KEY_ID: {"$ref": "#/definitions/kmsKeyId"}},
9941019
},
1020+
RESOURCE_KEY: {"$ref": "#/definitions/kmsKeyId"},
9951021
ROLE_ARN: {"$ref": "#/definitions/roleArn"},
9961022
TAGS: {"$ref": "#/definitions/tags"},
9971023
},

src/sagemaker/model.py

Lines changed: 63 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
EDGE_PACKAGING_KMS_KEY_ID_PATH,
3636
EDGE_PACKAGING_ROLE_ARN_PATH,
3737
MODEL_CONTAINERS_PATH,
38+
EDGE_PACKAGING_RESOURCE_KEY_PATH,
3839
MODEL_VPC_CONFIG_PATH,
3940
MODEL_ENABLE_NETWORK_ISOLATION_PATH,
4041
MODEL_EXECUTION_ROLE_ARN_PATH,
@@ -50,7 +51,10 @@
5051
from sagemaker.predictor import PredictorBase
5152
from sagemaker.serverless import ServerlessInferenceConfig
5253
from sagemaker.transformer import Transformer
53-
from sagemaker.jumpstart.utils import add_jumpstart_tags, get_jumpstart_base_name_if_jumpstart_model
54+
from sagemaker.jumpstart.utils import (
55+
add_jumpstart_tags,
56+
get_jumpstart_base_name_if_jumpstart_model,
57+
)
5458
from sagemaker.utils import (
5559
unique_name_from_base,
5660
update_container_with_inference_params,
@@ -63,15 +67,25 @@
6367
from sagemaker.workflow import is_pipeline_variable
6468
from sagemaker.workflow.entities import PipelineVariable
6569
from sagemaker.workflow.pipeline_context import runnable_by_pipeline, PipelineSession
66-
from sagemaker.inference_recommender.inference_recommender_mixin import InferenceRecommenderMixin
70+
from sagemaker.inference_recommender.inference_recommender_mixin import (
71+
InferenceRecommenderMixin,
72+
)
6773

6874
LOGGER = logging.getLogger("sagemaker")
6975

7076
NEO_ALLOWED_FRAMEWORKS = set(
7177
["mxnet", "tensorflow", "keras", "pytorch", "onnx", "xgboost", "tflite"]
7278
)
7379

74-
NEO_IOC_TARGET_DEVICES = ["ml_c4", "ml_c5", "ml_m4", "ml_m5", "ml_p2", "ml_p3", "ml_g4dn"]
80+
NEO_IOC_TARGET_DEVICES = [
81+
"ml_c4",
82+
"ml_c5",
83+
"ml_m4",
84+
"ml_m5",
85+
"ml_p2",
86+
"ml_p3",
87+
"ml_g4dn",
88+
]
7589

7690
NEO_MULTIVERSION_UNSUPPORTED = [
7791
"imx8mplus",
@@ -300,7 +314,9 @@ def __init__(
300314
self._base_name = None
301315
self.sagemaker_session = sagemaker_session
302316
self.role = resolve_value_from_config(
303-
role, MODEL_EXECUTION_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session
317+
role,
318+
MODEL_EXECUTION_ROLE_ARN_PATH,
319+
sagemaker_session=self.sagemaker_session,
304320
)
305321
self.vpc_config = resolve_value_from_config(
306322
vpc_config, MODEL_VPC_CONFIG_PATH, sagemaker_session=self.sagemaker_session
@@ -585,7 +601,9 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
585601
local_code = utils.get_config_value("local.local_code", self.sagemaker_session.config)
586602

587603
bucket, key_prefix = s3.determine_bucket_and_prefix(
588-
bucket=self.bucket, key_prefix=key_prefix, sagemaker_session=self.sagemaker_session
604+
bucket=self.bucket,
605+
key_prefix=key_prefix,
606+
sagemaker_session=self.sagemaker_session,
589607
)
590608

591609
if (self.sagemaker_session.local_mode and local_code) or self.entry_point is None:
@@ -633,7 +651,8 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
633651
else:
634652
repacked_model_data = "s3://" + "/".join([bucket, key_prefix, "model.tar.gz"])
635653
self.uploaded_code = fw_utils.UploadedCode(
636-
s3_prefix=repacked_model_data, script_name=os.path.basename(self.entry_point)
654+
s3_prefix=repacked_model_data,
655+
script_name=os.path.basename(self.entry_point),
637656
)
638657

639658
LOGGER.info(
@@ -693,7 +712,11 @@ def enable_network_isolation(self):
693712
return False if not self._enable_network_isolation else self._enable_network_isolation
694713

695714
def _create_sagemaker_model(
696-
self, instance_type=None, accelerator_type=None, tags=None, serverless_inference_config=None
715+
self,
716+
instance_type=None,
717+
accelerator_type=None,
718+
tags=None,
719+
serverless_inference_config=None,
697720
):
698721
"""Create a SageMaker Model Entity
699722
@@ -734,10 +757,14 @@ def _create_sagemaker_model(
734757
self._init_sagemaker_session_if_does_not_exist(instance_type)
735758
# Depending on the instance type, a local session (or) a session is initialized.
736759
self.role = resolve_value_from_config(
737-
self.role, MODEL_EXECUTION_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session
760+
self.role,
761+
MODEL_EXECUTION_ROLE_ARN_PATH,
762+
sagemaker_session=self.sagemaker_session,
738763
)
739764
self.vpc_config = resolve_value_from_config(
740-
self.vpc_config, MODEL_VPC_CONFIG_PATH, sagemaker_session=self.sagemaker_session
765+
self.vpc_config,
766+
MODEL_VPC_CONFIG_PATH,
767+
sagemaker_session=self.sagemaker_session,
741768
)
742769
self._enable_network_isolation = resolve_value_from_config(
743770
self._enable_network_isolation,
@@ -955,11 +982,16 @@ def package_for_edge(
955982
job_name = f"packaging{self._compilation_job_name[11:]}"
956983
self._init_sagemaker_session_if_does_not_exist(None)
957984
s3_kms_key = resolve_value_from_config(
958-
s3_kms_key, EDGE_PACKAGING_KMS_KEY_ID_PATH, sagemaker_session=self.sagemaker_session
985+
s3_kms_key,
986+
EDGE_PACKAGING_KMS_KEY_ID_PATH,
987+
sagemaker_session=self.sagemaker_session,
959988
)
960989
role = resolve_value_from_config(
961990
role, EDGE_PACKAGING_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session
962991
)
992+
resource_key = resolve_value_from_config(
993+
resource_key, EDGE_PACKAGING_RESOURCE_KEY_PATH, sagemaker_session=self.sagemaker_session
994+
)
963995
if role is not None:
964996
role = self.sagemaker_session.expand_role(role)
965997
config = self._edge_packaging_job_config(
@@ -1065,7 +1097,9 @@ def compile(
10651097

10661098
self._init_sagemaker_session_if_does_not_exist(target_instance_family)
10671099
role = resolve_value_from_config(
1068-
role, COMPILATION_JOB_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session
1100+
role,
1101+
COMPILATION_JOB_ROLE_ARN_PATH,
1102+
sagemaker_session=self.sagemaker_session,
10691103
)
10701104
if not role:
10711105
# Originally IAM role was a required parameter.
@@ -1232,10 +1266,14 @@ def deploy(
12321266
self._init_sagemaker_session_if_does_not_exist(instance_type)
12331267
# Depending on the instance type, a local session (or) a session is initialized.
12341268
self.role = resolve_value_from_config(
1235-
self.role, MODEL_EXECUTION_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session
1269+
self.role,
1270+
MODEL_EXECUTION_ROLE_ARN_PATH,
1271+
sagemaker_session=self.sagemaker_session,
12361272
)
12371273
self.vpc_config = resolve_value_from_config(
1238-
self.vpc_config, MODEL_VPC_CONFIG_PATH, sagemaker_session=self.sagemaker_session
1274+
self.vpc_config,
1275+
MODEL_VPC_CONFIG_PATH,
1276+
sagemaker_session=self.sagemaker_session,
12391277
)
12401278
self._enable_network_isolation = resolve_value_from_config(
12411279
self._enable_network_isolation,
@@ -1244,7 +1282,9 @@ def deploy(
12441282
)
12451283

12461284
tags = add_jumpstart_tags(
1247-
tags=tags, inference_model_uri=self.model_data, inference_script_uri=self.source_dir
1285+
tags=tags,
1286+
inference_model_uri=self.model_data,
1287+
inference_script_uri=self.source_dir,
12481288
)
12491289

12501290
if self.role is None:
@@ -1292,7 +1332,9 @@ def deploy(
12921332
compiled_model_suffix = None if is_serverless else "-".join(instance_type.split(".")[:-1])
12931333
if self._is_compiled_model and not is_serverless:
12941334
self._ensure_base_name_if_needed(
1295-
image_uri=self.image_uri, script_uri=self.source_dir, model_uri=self.model_data
1335+
image_uri=self.image_uri,
1336+
script_uri=self.source_dir,
1337+
model_uri=self.model_data,
12961338
)
12971339
if self._base_name is not None:
12981340
self._base_name = "-".join((self._base_name, compiled_model_suffix))
@@ -1673,7 +1715,12 @@ class ModelPackage(Model):
16731715
"""A SageMaker ``Model`` that can be deployed to an ``Endpoint``."""
16741716

16751717
def __init__(
1676-
self, role=None, model_data=None, algorithm_arn=None, model_package_arn=None, **kwargs
1718+
self,
1719+
role=None,
1720+
model_data=None,
1721+
algorithm_arn=None,
1722+
model_package_arn=None,
1723+
**kwargs,
16771724
):
16781725
"""Initialize a SageMaker ModelPackage.
16791726

0 commit comments

Comments
 (0)