Skip to content

Commit 523d01d

Browse files
author
Ruban Hussain
committed
fix: update Schema to match exactly with APIs
1 parent d8e33ea commit 523d01d

File tree

9 files changed

+95
-78
lines changed

9 files changed

+95
-78
lines changed

src/sagemaker/config/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
AUTO_ML_ROLE_ARN_PATH,
4040
AUTO_ML_OUTPUT_CONFIG_PATH,
4141
AUTO_ML_JOB_CONFIG_PATH,
42-
AUTO_ML,
42+
AUTO_ML_JOB,
4343
COMPILATION_JOB_ROLE_ARN_PATH,
4444
COMPILATION_JOB_OUTPUT_CONFIG_PATH,
4545
COMPILATION_JOB_VPC_CONFIG_PATH,

src/sagemaker/config/config_schema.py

Lines changed: 64 additions & 46 deletions
Large diffs are not rendered by default.

src/sagemaker/session.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959
AUTO_ML_ROLE_ARN_PATH,
6060
AUTO_ML_OUTPUT_CONFIG_PATH,
6161
AUTO_ML_JOB_CONFIG_PATH,
62-
AUTO_ML,
62+
AUTO_ML_JOB,
6363
COMPILATION_JOB_ROLE_ARN_PATH,
6464
COMPILATION_JOB_OUTPUT_CONFIG_PATH,
6565
COMPILATION_JOB_VPC_CONFIG_PATH,
@@ -1212,10 +1212,14 @@ def process(
12121212
# Processing Input can either have AthenaDatasetDefinition or RedshiftDatasetDefinition
12131213
# or neither, but not both
12141214
union_key_paths_for_dataset_definition = [
1215+
[
1216+
"DatasetDefinition",
1217+
"S3Input",
1218+
],
12151219
[
12161220
"DatasetDefinition.AthenaDatasetDefinition",
12171221
"DatasetDefinition.RedshiftDatasetDefinition",
1218-
]
1222+
],
12191223
]
12201224
update_list_of_dicts_with_values_from_config(
12211225
inputs,
@@ -2193,7 +2197,9 @@ def _get_auto_ml_request(
21932197
auto_ml_job_request["ProblemType"] = problem_type
21942198

21952199
tags = _append_project_tags(tags)
2196-
tags = self._append_sagemaker_config_tags(tags, "{}.{}.{}".format(SAGEMAKER, AUTO_ML, TAGS))
2200+
tags = self._append_sagemaker_config_tags(
2201+
tags, "{}.{}.{}".format(SAGEMAKER, AUTO_ML_JOB, TAGS)
2202+
)
21972203
if tags is not None:
21982204
auto_ml_job_request["Tags"] = tags
21992205

tests/data/config/config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ SageMaker:
3434
ProductionVariants:
3535
- CoreDumpConfig:
3636
KmsKeyId: 'kmskeyid4'
37-
AutoML:
37+
AutoMLJob:
3838
AutoMLJobConfig:
3939
SecurityConfig:
4040
VolumeKmsKeyId: 'volumekmskeyid1'

tests/unit/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
ENDPOINT_CONFIG,
4141
DATA_CAPTURE_CONFIG,
4242
PRODUCTION_VARIANTS,
43-
AUTO_ML,
43+
AUTO_ML_JOB,
4444
AUTO_ML_JOB_CONFIG,
4545
SECURITY_CONFIG,
4646
OUTPUT_DATA_CONFIG,
@@ -149,7 +149,7 @@
149149
SAGEMAKER_CONFIG_AUTO_ML = {
150150
SCHEMA_VERSION: "1.0",
151151
SAGEMAKER: {
152-
AUTO_ML: {
152+
AUTO_ML_JOB: {
153153
AUTO_ML_JOB_CONFIG: {
154154
SECURITY_CONFIG: {
155155
ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION: True,

tests/unit/sagemaker/automl/test_auto_ml.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -305,18 +305,18 @@ def test_framework_initialization_with_sagemaker_config_injection(sagemaker_sess
305305
sagemaker_session=sagemaker_session,
306306
)
307307

308-
expected_volume_kms_key_id = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoML"]["AutoMLJobConfig"][
309-
"SecurityConfig"
310-
]["VolumeKmsKeyId"]
311-
expected_role_arn = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoML"]["RoleArn"]
312-
expected_kms_key_id = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoML"]["OutputDataConfig"][
308+
expected_volume_kms_key_id = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoMLJob"][
309+
"AutoMLJobConfig"
310+
]["SecurityConfig"]["VolumeKmsKeyId"]
311+
expected_role_arn = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoMLJob"]["RoleArn"]
312+
expected_kms_key_id = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoMLJob"]["OutputDataConfig"][
313313
"KmsKeyId"
314314
]
315-
expected_vpc_config = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoML"]["AutoMLJobConfig"][
315+
expected_vpc_config = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoMLJob"]["AutoMLJobConfig"][
316316
"SecurityConfig"
317317
]["VpcConfig"]
318318
expected_enable_inter_container_traffic_encryption = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"][
319-
"AutoML"
319+
"AutoMLJob"
320320
]["AutoMLJobConfig"]["SecurityConfig"]["EnableInterContainerTrafficEncryption"]
321321
assert auto_ml.role == expected_role_arn
322322
assert auto_ml.output_kms_key == expected_kms_key_id

tests/unit/sagemaker/config/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def valid_config_with_all_the_scopes(
183183
"FeatureGroup": valid_feature_group_config,
184184
"MonitoringSchedule": valid_monitoring_schedule_config,
185185
"EndpointConfig": valid_endpointconfig_config,
186-
"AutoML": valid_automl_config,
186+
"AutoMLJob": valid_automl_config,
187187
"TransformJob": valid_transform_job_config,
188188
"CompilationJob": valid_compilation_job_config,
189189
"Pipeline": valid_pipeline_config,

tests/unit/sagemaker/config/test_config_schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def test_valid_transform_job_schema(base_config_with_schema, valid_transform_job
8080

8181

8282
def test_valid_automl_schema(base_config_with_schema, valid_automl_config):
83-
_validate_config(base_config_with_schema, {"AutoML": valid_automl_config})
83+
_validate_config(base_config_with_schema, {"AutoMLJob": valid_automl_config})
8484

8585

8686
def test_valid_endpoint_config_schema(base_config_with_schema, valid_endpointconfig_config):

tests/unit/test_session.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -274,14 +274,7 @@ def test_create_process_with_sagemaker_config_injection(sagemaker_session):
274274
processing_inputs = [
275275
{
276276
"InputName": "input-1",
277-
"S3Input": {
278-
"S3Uri": "mocked_s3_uri_from_upload_data",
279-
"LocalPath": "/container/path/",
280-
"S3DataType": "Archive",
281-
"S3InputMode": "File",
282-
"S3DataDistributionType": "FullyReplicated",
283-
"S3CompressionType": "None",
284-
},
277+
# No S3Input because the API expects only one of S3Input or DatasetDefinition
285278
"DatasetDefinition": {
286279
"AthenaDatasetDefinition": {},
287280
},
@@ -3184,19 +3177,19 @@ def test_create_auto_ml_with_sagemaker_config_injection(sagemaker_session):
31843177
job_name = JOB_NAME
31853178
sagemaker_session.auto_ml(input_config, output_config, auto_ml_job_config, job_name=job_name)
31863179
expected_call_args = copy.deepcopy(DEFAULT_EXPECTED_AUTO_ML_JOB_ARGS)
3187-
expected_volume_kms_key_id = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoML"]["AutoMLJobConfig"][
3188-
"SecurityConfig"
3189-
]["VolumeKmsKeyId"]
3190-
expected_role_arn = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoML"]["RoleArn"]
3191-
expected_kms_key_id = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoML"]["OutputDataConfig"][
3180+
expected_volume_kms_key_id = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoMLJob"][
3181+
"AutoMLJobConfig"
3182+
]["SecurityConfig"]["VolumeKmsKeyId"]
3183+
expected_role_arn = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoMLJob"]["RoleArn"]
3184+
expected_kms_key_id = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoMLJob"]["OutputDataConfig"][
31923185
"KmsKeyId"
31933186
]
3194-
expected_vpc_config = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoML"]["AutoMLJobConfig"][
3187+
expected_vpc_config = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoMLJob"]["AutoMLJobConfig"][
31953188
"SecurityConfig"
31963189
]["VpcConfig"]
3197-
expected_tags = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoML"]["Tags"]
3190+
expected_tags = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"]["AutoMLJob"]["Tags"]
31983191
expected_enable_inter_container_traffic_encryption = SAGEMAKER_CONFIG_AUTO_ML["SageMaker"][
3199-
"AutoML"
3192+
"AutoMLJob"
32003193
]["AutoMLJobConfig"]["SecurityConfig"]["EnableInterContainerTrafficEncryption"]
32013194
expected_call_args["OutputDataConfig"]["KmsKeyId"] = expected_kms_key_id
32023195
expected_call_args["RoleArn"] = expected_role_arn

0 commit comments

Comments
 (0)