Skip to content

Commit 7391fa1

Browse files
xinlutu2Xinlu Tu
authored andcommitted
fix: Add more integ test logic for AutoMLStep
Co-authored-by: Xinlu Tu <[email protected]>
1 parent 30f014d commit 7391fa1

File tree

2 files changed

+51
-18
lines changed

2 files changed

+51
-18
lines changed

src/sagemaker/automl/automl.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,15 +49,16 @@ def __init__(
4949
inputs (str, list[str], PipelineVariable):
5050
a string or a list of string or a PipelineVariable that points to (a)
5151
S3 location(s) where input data is stored.
52-
target_attribute_name (str): the target attribute name for regression
53-
or classification.
54-
compression (str): if training data is compressed, the compression type.
55-
The default value is None.
56-
channel_type (str): The channel type an enum to specify
52+
target_attribute_name (str, PipelineVariable):
53+
the target attribute name for regression or classification.
54+
compression (str, PipelineVariable):
55+
if training data is compressed, the compression type. The default value is None.
56+
channel_type (str, PipelineVariable): The channel type an enum to specify
5757
whether the input resource is for training or validation.
5858
Valid values: training or validation.
59-
content_type (str): The content type of the data from the input source.
60-
s3_data_type (str): The data type for S3 data source.
59+
content_type (str, PipelineVariable):
60+
The content type of the data from the input source.
61+
s3_data_type (str, PipelineVariable): The data type for S3 data source.
6162
Valid values: ManifestFile or S3Prefix.
6263
"""
6364
self.inputs = inputs

tests/integ/sagemaker/workflow/test_automl_steps.py

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
from __future__ import absolute_import
1414

1515
import os
16+
import re
1617

17-
import boto3
1818
import pytest
1919
from botocore.exceptions import WaiterError
2020

@@ -91,13 +91,19 @@ def test_automl_step(pipeline_session, role, pipeline_name):
9191
name="ModelPackageName", default_value="AutoMlModelPackageGroup"
9292
)
9393
model_approval_status = ParameterString(name="ModelApprovalStatus", default_value="Approved")
94+
model_insights_json_report_path = (
95+
automl_step.properties.BestCandidateProperties.ModelInsightsJsonReportPath
96+
)
97+
explainability_json_report_path = (
98+
automl_step.properties.BestCandidateProperties.ExplainabilityJsonReportPath
99+
)
94100
model_metrics = ModelMetrics(
95101
model_statistics=MetricsSource(
96-
s3_uri=automl_step.properties.BestCandidateProperties.ModelInsightsJsonReportPath,
102+
s3_uri=model_insights_json_report_path,
97103
content_type="application/json",
98104
),
99105
explainability=MetricsSource(
100-
s3_uri=automl_step.properties.BestCandidateProperties.ExplainabilityJsonReportPath,
106+
s3_uri=explainability_json_report_path,
101107
content_type="application/json",
102108
),
103109
)
@@ -139,19 +145,45 @@ def test_automl_step(pipeline_session, role, pipeline_name):
139145
assert step["StepStatus"] == "Succeeded"
140146
if "AutoMLJob" in step["Metadata"]:
141147
has_automl_job = True
142-
assert step["Metadata"]["AutoMLJob"]["Arn"] is not None
148+
automl_job_arn = step["Metadata"]["AutoMLJob"]["Arn"]
149+
assert automl_job_arn is not None
150+
automl_job_name = re.findall(r"(?<=automl-job/).*", automl_job_arn)[0]
151+
auto_ml_desc = auto_ml.describe_auto_ml_job(job_name=automl_job_name)
152+
model_insights_json_from_automl = (
153+
auto_ml_desc["BestCandidate"]["CandidateProperties"][
154+
"CandidateArtifactLocations"
155+
]["ModelInsights"]
156+
+ "/statistics.json"
157+
)
158+
explainability_json_from_automl = (
159+
auto_ml_desc["BestCandidate"]["CandidateProperties"][
160+
"CandidateArtifactLocations"
161+
]["Explainability"]
162+
+ "/analysis.json"
163+
)
143164

144165
assert has_automl_job
145166
assert len(execution_steps) == 3
167+
sagemaker_client = pipeline_session.boto_session.client("sagemaker")
168+
model_package = sagemaker_client.list_model_packages(
169+
ModelPackageGroupName="AutoMlModelPackageGroup"
170+
)["ModelPackageSummaryList"][0]
171+
response = sagemaker_client.describe_model_package(
172+
ModelPackageName=model_package["ModelPackageArn"]
173+
)
174+
model_insights_json_report_path = response["ModelMetrics"]["ModelQuality"]["Statistics"][
175+
"S3Uri"
176+
]
177+
explainability_json_report_path = response["ModelMetrics"]["Explainability"]["Report"][
178+
"S3Uri"
179+
]
180+
181+
assert model_insights_json_report_path == model_insights_json_from_automl
182+
assert explainability_json_report_path == explainability_json_from_automl
183+
146184
finally:
147185
try:
148-
sagemaker_client = boto3.client("sagemaker")
149-
for model_package in sagemaker_client.list_model_packages(
150-
ModelPackageGroupName="AutoMlModelPackageGroup"
151-
)["ModelPackageSummaryList"]:
152-
sagemaker_client.delete_model_package(
153-
ModelPackageName=model_package["ModelPackageArn"]
154-
)
186+
sagemaker_client.delete_model_package(ModelPackageName=model_package["ModelPackageArn"])
155187
sagemaker_client.delete_model_package_group(
156188
ModelPackageGroupName="AutoMlModelPackageGroup"
157189
)

0 commit comments

Comments
 (0)