Skip to content

Commit 6c9bc6a

Browse files
rubanhRuban Hussain
andcommitted
intelligent defaults - tags and encryption (aws#842)
* feature: sagemaker config - support tags for all APIs * feature: sagemaker config - support EnableInterContainerTrafficEncryption for relevant APIs --------- Co-authored-by: Ruban Hussain <[email protected]>
1 parent cc090d4 commit 6c9bc6a

34 files changed

+1328
-28
lines changed

src/sagemaker/automl/automl.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919

2020
from sagemaker import Model, PipelineModel
2121
from sagemaker.automl.candidate_estimator import CandidateEstimator
22+
from sagemaker.config.config_schema import (
23+
PATH_V1_AUTO_ML_INTER_CONTAINER_ENCRYPTION,
24+
)
2225
from sagemaker.job import _Job
2326
from sagemaker.session import Session
2427
from sagemaker.utils import name_from_base
@@ -106,7 +109,7 @@ def __init__(
106109
compression_type: Optional[str] = None,
107110
sagemaker_session: Optional[Session] = None,
108111
volume_kms_key: Optional[str] = None,
109-
encrypt_inter_container_traffic: Optional[bool] = False,
112+
encrypt_inter_container_traffic: Optional[bool] = None,
110113
vpc_config: Optional[Dict[str, List]] = None,
111114
problem_type: Optional[str] = None,
112115
max_candidates: Optional[int] = None,
@@ -182,7 +185,6 @@ def __init__(
182185
self.base_job_name = base_job_name
183186
self.compression_type = compression_type
184187
self.volume_kms_key = volume_kms_key
185-
self.encrypt_inter_container_traffic = encrypt_inter_container_traffic
186188
self.vpc_config = vpc_config
187189
self.problem_type = problem_type
188190
self.max_candidate = max_candidates
@@ -205,6 +207,12 @@ def __init__(
205207
self._best_candidate = None
206208
self.sagemaker_session = sagemaker_session or Session()
207209

210+
self.encrypt_inter_container_traffic = self.sagemaker_session.resolve_value_from_config(
211+
direct_input=encrypt_inter_container_traffic,
212+
config_path=PATH_V1_AUTO_ML_INTER_CONTAINER_ENCRYPTION,
213+
default_value=False,
214+
)
215+
208216
self._check_problem_type_and_job_objective(self.problem_type, self.job_objective)
209217

210218
@runnable_by_pipeline
@@ -276,6 +284,8 @@ def attach(cls, auto_ml_job_name, sagemaker_session=None):
276284
volume_kms_key=auto_ml_job_desc.get("AutoMLJobConfig", {})
277285
.get("SecurityConfig", {})
278286
.get("VolumeKmsKeyId"),
287+
# Do not override encrypt_inter_container_traffic from config because this info
288+
# is pulled from an existing automl job
279289
encrypt_inter_container_traffic=auto_ml_job_desc.get("AutoMLJobConfig", {})
280290
.get("SecurityConfig", {})
281291
.get("EnableInterContainerTrafficEncryption", False),

src/sagemaker/automl/candidate_estimator.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from six import string_types
1717

1818
from sagemaker import Session
19+
from sagemaker.config.config_schema import PATH_V1_TRAINING_JOB_INTER_CONTAINER_ENCRYPTION
1920
from sagemaker.job import _Job
2021
from sagemaker.utils import name_from_base
2122

@@ -72,7 +73,8 @@ def fit(
7273
inputs,
7374
candidate_name=None,
7475
volume_kms_key=None,
75-
encrypt_inter_container_traffic=False,
76+
# default of False for training job, checked inside function
77+
encrypt_inter_container_traffic=None,
7678
vpc_config=None,
7779
wait=True,
7880
logs=True,
@@ -87,7 +89,8 @@ def fit(
8789
volume_kms_key (str): The KMS key id to encrypt data on the storage volume attached to
8890
the ML compute instance(s).
8991
encrypt_inter_container_traffic (bool): To encrypt all communications between ML compute
90-
instances in distributed training. Default: False.
92+
instances in distributed training. If not passed, will be fetched from
93+
sagemaker_config. Default: False.
9194
vpc_config (dict): Specifies a VPC that jobs and hosted models have access to.
9295
Control access to and from training and model containers by configuring the VPC
9396
wait (bool): Whether the call should wait until all jobs completes (default: True).
@@ -131,12 +134,21 @@ def fit(
131134
base_name = "sagemaker-automl-training-rerun"
132135
step_name = name_from_base(base_name)
133136
step["name"] = step_name
137+
138+
# Check training_job config not auto_ml_job config because this function calls
139+
# training job API
140+
_encrypt_inter_container_traffic = self.sagemaker_session.resolve_value_from_config(
141+
direct_input=encrypt_inter_container_traffic,
142+
config_path=PATH_V1_TRAINING_JOB_INTER_CONTAINER_ENCRYPTION,
143+
default_value=False,
144+
)
145+
134146
train_args = self._get_train_args(
135147
desc,
136148
channels,
137149
step_name,
138150
volume_kms_key,
139-
encrypt_inter_container_traffic,
151+
_encrypt_inter_container_traffic,
140152
vpc_config,
141153
)
142154
self.sagemaker_session.train(**train_args)

src/sagemaker/config/config_schema.py

Lines changed: 82 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
VOLUME_KMS_KEY_ID = "VolumeKmsKeyId"
2020
KMS_KEY_ID = "KmsKeyId"
2121
ROLE_ARN = "RoleArn"
22+
TAGS = "Tags"
23+
KEY = "Key"
24+
VALUE = "Value"
2225
EXECUTION_ROLE_ARN = "ExecutionRoleArn"
2326
CLUSTER_ROLE_ARN = "ClusterRoleArn"
2427
VPC_CONFIG = "VpcConfig"
@@ -73,6 +76,35 @@
7376
TYPE = "type"
7477
OBJECT = "object"
7578
ADDITIONAL_PROPERTIES = "additionalProperties"
79+
ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION = "EnableInterContainerTrafficEncryption"
80+
81+
82+
def _simple_path(*args: str):
83+
"""Appends an arbitrary number of strings to use as path constants"""
84+
return ".".join(args)
85+
86+
87+
# Paths for reference elsewhere in the SDK.
88+
# Names include the schema version since the paths could change with other schema versions
89+
PATH_V1_MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION = _simple_path(
90+
SAGEMAKER,
91+
MONITORING_SCHEDULE,
92+
MONITORING_SCHEDULE_CONFIG,
93+
MONITORING_JOB_DEFINITION,
94+
NETWORK_CONFIG,
95+
ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION,
96+
)
97+
PATH_V1_AUTO_ML_INTER_CONTAINER_ENCRYPTION = _simple_path(
98+
SAGEMAKER, AUTO_ML, SECURITY_CONFIG, ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION
99+
)
100+
PATH_V1_PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION = _simple_path(
101+
SAGEMAKER, PROCESSING_JOB, NETWORK_CONFIG, ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION
102+
)
103+
PATH_V1_TRAINING_JOB_INTER_CONTAINER_ENCRYPTION = _simple_path(
104+
SAGEMAKER, TRAINING_JOB, ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION
105+
)
106+
107+
76108
SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA = {
77109
"$schema": "https://json-schema.org/draft/2020-12/schema",
78110
TYPE: OBJECT,
@@ -164,6 +196,31 @@
164196
}
165197
},
166198
},
199+
# https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_Tag.html
200+
"tags": {
201+
TYPE: "array",
202+
"items": {
203+
TYPE: OBJECT,
204+
ADDITIONAL_PROPERTIES: False,
205+
PROPERTIES: {
206+
KEY: {
207+
TYPE: "string",
208+
"pattern": r"^[\w\s\d_.:/=+\-@]*$",
209+
"minLength": 1,
210+
"maxLength": 128,
211+
},
212+
VALUE: {
213+
TYPE: "string",
214+
"pattern": r"^[\w\s\d_.:/=+\-@]*$",
215+
"minLength": 0,
216+
"maxLength": 256,
217+
},
218+
},
219+
},
220+
"minItems": 0,
221+
"maxItems": 50,
222+
},
223+
SUBNETS: {TYPE: "array", "items": {"$ref": "#/definitions/subnet"}},
167224
},
168225
PROPERTIES: {
169226
SCHEMA_VERSION: {
@@ -219,6 +276,7 @@
219276
},
220277
},
221278
ROLE_ARN: {"$ref": "#/definitions/roleArn"},
279+
TAGS: {"$ref": "#/definitions/tags"},
222280
},
223281
},
224282
# Monitoring Schedule
@@ -257,6 +315,9 @@
257315
TYPE: OBJECT,
258316
ADDITIONAL_PROPERTIES: False,
259317
PROPERTIES: {
318+
ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION: {
319+
TYPE: "boolean"
320+
},
260321
ENABLE_NETWORK_ISOLATION: {TYPE: "boolean"},
261322
VPC_CONFIG: {"$ref": "#/definitions/vpcConfig"},
262323
},
@@ -265,7 +326,8 @@
265326
},
266327
}
267328
},
268-
}
329+
},
330+
TAGS: {"$ref": "#/definitions/tags"},
269331
},
270332
},
271333
# Endpoint Config
@@ -302,6 +364,7 @@
302364
TYPE: "array",
303365
"items": {"$ref": "#/definitions/productionVariant"},
304366
},
367+
TAGS: {"$ref": "#/definitions/tags"},
305368
},
306369
},
307370
# Auto ML
@@ -318,6 +381,9 @@
318381
TYPE: OBJECT,
319382
ADDITIONAL_PROPERTIES: False,
320383
PROPERTIES: {
384+
ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION: {
385+
TYPE: "boolean"
386+
},
321387
VOLUME_KMS_KEY_ID: {
322388
TYPE: "string",
323389
},
@@ -332,6 +398,7 @@
332398
PROPERTIES: {KMS_KEY_ID: {TYPE: "string"}},
333399
},
334400
ROLE_ARN: {"$ref": "#/definitions/roleArn"},
401+
TAGS: {"$ref": "#/definitions/tags"},
335402
},
336403
},
337404
# Transform Job
@@ -355,6 +422,7 @@
355422
ADDITIONAL_PROPERTIES: False,
356423
PROPERTIES: {VOLUME_KMS_KEY_ID: {TYPE: "string"}},
357424
},
425+
TAGS: {"$ref": "#/definitions/tags"},
358426
},
359427
},
360428
# Compilation Job
@@ -371,14 +439,18 @@
371439
},
372440
ROLE_ARN: {"$ref": "#/definitions/roleArn"},
373441
VPC_CONFIG: {"$ref": "#/definitions/vpcConfig"},
442+
TAGS: {"$ref": "#/definitions/tags"},
374443
},
375444
},
376445
# Pipeline
377446
# https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreatePipeline.html
378447
PIPELINE: {
379448
TYPE: OBJECT,
380449
ADDITIONAL_PROPERTIES: False,
381-
PROPERTIES: {ROLE_ARN: {"$ref": "#/definitions/roleArn"}},
450+
PROPERTIES: {
451+
ROLE_ARN: {"$ref": "#/definitions/roleArn"},
452+
TAGS: {"$ref": "#/definitions/tags"},
453+
},
382454
},
383455
# Model
384456
# https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateModel.html
@@ -389,6 +461,7 @@
389461
ENABLE_NETWORK_ISOLATION: {TYPE: "boolean"},
390462
EXECUTION_ROLE_ARN: {"$ref": "#/definitions/roleArn"},
391463
VPC_CONFIG: {"$ref": "#/definitions/vpcConfig"},
464+
TAGS: {"$ref": "#/definitions/tags"},
392465
},
393466
},
394467
# Model Package
@@ -407,7 +480,8 @@
407480
},
408481
VALIDATION_ROLE: {"$ref": "#/definitions/roleArn"},
409482
},
410-
}
483+
},
484+
TAGS: {"$ref": "#/definitions/tags"},
411485
},
412486
},
413487
# Processing Job
@@ -420,6 +494,7 @@
420494
TYPE: OBJECT,
421495
ADDITIONAL_PROPERTIES: False,
422496
PROPERTIES: {
497+
ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION: {TYPE: "boolean"},
423498
ENABLE_NETWORK_ISOLATION: {TYPE: "boolean"},
424499
VPC_CONFIG: {"$ref": "#/definitions/vpcConfig"},
425500
},
@@ -445,6 +520,7 @@
445520
},
446521
},
447522
ROLE_ARN: {"$ref": "#/definitions/roleArn"},
523+
TAGS: {"$ref": "#/definitions/tags"},
448524
},
449525
},
450526
# Training Job
@@ -453,6 +529,7 @@
453529
TYPE: OBJECT,
454530
ADDITIONAL_PROPERTIES: False,
455531
PROPERTIES: {
532+
ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION: {TYPE: "boolean"},
456533
ENABLE_NETWORK_ISOLATION: {TYPE: "boolean"},
457534
OUTPUT_DATA_CONFIG: {
458535
TYPE: OBJECT,
@@ -466,6 +543,7 @@
466543
},
467544
ROLE_ARN: {"$ref": "#/definitions/roleArn"},
468545
VPC_CONFIG: {"$ref": "#/definitions/vpcConfig"},
546+
TAGS: {"$ref": "#/definitions/tags"},
469547
},
470548
},
471549
# Edge Packaging Job
@@ -480,6 +558,7 @@
480558
PROPERTIES: {KMS_KEY_ID: {TYPE: "string"}},
481559
},
482560
ROLE_ARN: {"$ref": "#/definitions/roleArn"},
561+
TAGS: {"$ref": "#/definitions/tags"},
483562
},
484563
},
485564
},

src/sagemaker/estimator.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import sagemaker
3030
from sagemaker import git_utils, image_uris, vpc_utils
3131
from sagemaker.analytics import TrainingJobAnalytics
32+
from sagemaker.config.config_schema import PATH_V1_TRAINING_JOB_INTER_CONTAINER_ENCRYPTION
3233
from sagemaker.debugger import ( # noqa: F401 # pylint: disable=unused-import
3334
DEBUGGER_FLAG,
3435
DebuggerHookConfig,
@@ -133,7 +134,7 @@ def __init__(
133134
model_uri: Optional[str] = None,
134135
model_channel_name: Union[str, PipelineVariable] = "model",
135136
metric_definitions: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
136-
encrypt_inter_container_traffic: Union[bool, PipelineVariable] = False,
137+
encrypt_inter_container_traffic: Union[bool, PipelineVariable] = None,
137138
use_spot_instances: Union[bool, PipelineVariable] = False,
138139
max_wait: Optional[Union[int, PipelineVariable]] = None,
139140
checkpoint_s3_uri: Optional[Union[str, PipelineVariable]] = None,
@@ -598,7 +599,12 @@ def __init__(
598599
training_repository_credentials_provider_arn
599600
)
600601

601-
self.encrypt_inter_container_traffic = encrypt_inter_container_traffic
602+
self.encrypt_inter_container_traffic = self.sagemaker_session.resolve_value_from_config(
603+
direct_input=encrypt_inter_container_traffic,
604+
config_path=PATH_V1_TRAINING_JOB_INTER_CONTAINER_ENCRYPTION,
605+
default_value=False,
606+
)
607+
602608
self.use_spot_instances = use_spot_instances
603609
self.max_wait = max_wait
604610
self.checkpoint_s3_uri = checkpoint_s3_uri
@@ -2168,6 +2174,7 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
21682174

21692175
# encrypt_inter_container_traffic may be a pipeline variable place holder object
21702176
# which is parsed in execution time
2177+
# This does not check config because the EstimatorBase constuctor already did that check
21712178
if estimator.encrypt_inter_container_traffic:
21722179
train_args[
21732180
"encrypt_inter_container_traffic"
@@ -2745,6 +2752,7 @@ def __init__(
27452752
model_uri=model_uri,
27462753
model_channel_name=model_channel_name,
27472754
metric_definitions=metric_definitions,
2755+
# Does not check sagemaker config because EstimatorBase will do that check
27482756
encrypt_inter_container_traffic=encrypt_inter_container_traffic,
27492757
use_spot_instances=use_spot_instances,
27502758
max_wait=max_wait,

0 commit comments

Comments
 (0)