Skip to content

Commit cb45a4b

Browse files
xinlutu2 (Xinlu Tu)
authored and committed
fix: bug on AutoMLInput to allow PipelineVariable (aws#736)
Co-authored-by: Xinlu Tu <[email protected]>
1 parent 8989ae3 commit cb45a4b

File tree

6 files changed

+123
-37
lines changed

6 files changed

+123
-37
lines changed

src/sagemaker/automl/automl.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from sagemaker.job import _Job
2323
from sagemaker.session import Session
2424
from sagemaker.utils import name_from_base
25+
from sagemaker.workflow.entities import PipelineVariable
2526
from sagemaker.workflow.pipeline_context import runnable_by_pipeline
2627

2728
logger = logging.getLogger("sagemaker")
@@ -44,18 +45,20 @@ def __init__(
4445
):
4546
"""Convert an S3 Uri or a list of S3 Uri to an AutoMLInput object.
4647
47-
:param inputs (str, list[str]): a string or a list of string that points to (a)
48-
S3 location(s) where input data is stored.
49-
:param target_attribute_name (str): the target attribute name for regression
50-
or classification.
51-
:param compression (str): if training data is compressed, the compression type.
52-
The default value is None.
53-
:param channel_type (str): The channel type an enum to specify
54-
whether the input resource is for training or validation.
55-
Valid values: training or validation.
56-
:param content_type (str): The content type of the data from the input source.
57-
:param s3_data_type (str): The data type for S3 data source.
58-
Valid values: ManifestFile or S3Prefix.
48+
Args:
49+
inputs (str, list[str], PipelineVariable):
50+
a string or a list of string or a PipelineVariable that points to (a)
51+
S3 location(s) where input data is stored.
52+
target_attribute_name (str): the target attribute name for regression
53+
or classification.
54+
compression (str): if training data is compressed, the compression type.
55+
The default value is None.
56+
channel_type (str): The channel type an enum to specify
57+
whether the input resource is for training or validation.
58+
Valid values: training or validation.
59+
content_type (str): The content type of the data from the input source.
60+
s3_data_type (str): The data type for S3 data source.
61+
Valid values: ManifestFile or S3Prefix.
5962
"""
6063
self.inputs = inputs
6164
self.target_attribute_name = target_attribute_name
@@ -70,6 +73,8 @@ def to_request_dict(self):
7073
auto_ml_input = []
7174
if isinstance(self.inputs, string_types):
7275
self.inputs = [self.inputs]
76+
if isinstance(self.inputs, PipelineVariable):
77+
self.inputs = [self.inputs]
7378
for entry in self.inputs:
7479
input_entry = {
7580
"DataSource": {"S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": entry}},
@@ -106,7 +111,7 @@ def __init__(
106111
max_candidates: Optional[int] = None,
107112
max_runtime_per_training_job_in_seconds: Optional[int] = None,
108113
total_job_runtime_in_seconds: Optional[int] = None,
109-
job_objective: Optional[str] = None,
114+
job_objective: Optional[Dict[str, str]] = None,
110115
generate_candidate_definitions_only: Optional[bool] = False,
111116
tags: Optional[List[Dict[str, str]]] = None,
112117
content_type: Optional[str] = None,
@@ -142,8 +147,9 @@ def __init__(
142147
that each training job executed inside hyperparameter tuning
143148
is allowed to run as part of a hyperparameter tuning job.
144149
total_job_runtime_in_seconds (int): the total wait time of an AutoML job.
145-
job_objective (str): Defines the objective metric
150+
job_objective (dict[str, str]): Defines the objective metric
146151
used to measure the predictive quality of an AutoML job.
152+
In the format of: {"MetricName": str}
147153
generate_candidate_definitions_only (bool): Whether to generates
148154
possible candidates without training the models.
149155
tags (List[dict[str, str]]): The list of tags to attach to this
@@ -969,8 +975,10 @@ def _prepare_auto_ml_stop_condition(
969975
970976
Returns (dict): an AutoML CompletionCriteria.
971977
"""
972-
stopping_condition = {"MaxCandidates": max_candidates}
978+
stopping_condition = {}
973979

980+
if max_candidates is not None:
981+
stopping_condition["MaxCandidates"] = max_candidates
974982
if max_runtime_per_training_job_in_seconds is not None:
975983
stopping_condition[
976984
"MaxRuntimePerTrainingJobInSeconds"

src/sagemaker/workflow/automl_step.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,15 +71,15 @@ def __init__(
7171

7272
root_property = Properties(step_name=name, shape_name="DescribeAutoMLJobResponse")
7373

74-
best_candidate_properties = Properties(step_name=name, path="bestCandidateProperties")
75-
best_candidate_properties.__dict__["modelInsightsJsonReportPath"] = Properties(
76-
step_name=name, path="bestCandidateProperties.modelInsightsJsonReportPath"
74+
best_candidate_properties = Properties(step_name=name, path="BestCandidateProperties")
75+
best_candidate_properties.__dict__["ModelInsightsJsonReportPath"] = Properties(
76+
step_name=name, path="BestCandidateProperties.ModelInsightsJsonReportPath"
7777
)
78-
best_candidate_properties.__dict__["explainabilityJsonReportPath"] = Properties(
79-
step_name=name, path="bestCandidateProperties.explainabilityJsonReportPath"
78+
best_candidate_properties.__dict__["ExplainabilityJsonReportPath"] = Properties(
79+
step_name=name, path="BestCandidateProperties.ExplainabilityJsonReportPath"
8080
)
8181

82-
root_property.__dict__["bestCandidateProperties"] = best_candidate_properties
82+
root_property.__dict__["BestCandidateProperties"] = best_candidate_properties
8383
self._properties = root_property
8484

8585
@property

tests/integ/sagemaker/workflow/test_automl_steps.py

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,15 @@
1414

1515
import os
1616

17+
import boto3
1718
import pytest
1819
from botocore.exceptions import WaiterError
1920

21+
from sagemaker.workflow import ParameterString
2022
from sagemaker.workflow.automl_step import AutoMLStep
2123
from sagemaker.automl.automl import AutoML, AutoMLInput
2224

23-
from sagemaker import utils, get_execution_role
24-
from sagemaker.utils import unique_name_from_base
25+
from sagemaker import utils, get_execution_role, ModelMetrics, MetricsSource
2526
from sagemaker.workflow.model_step import ModelStep
2627
from sagemaker.workflow.pipeline import Pipeline
2728

@@ -50,10 +51,8 @@ def test_automl_step(pipeline_session, role, pipeline_name):
5051
role=role,
5152
target_attribute_name=TARGET_ATTRIBUTE_NAME,
5253
sagemaker_session=pipeline_session,
53-
max_candidates=1,
5454
mode=MODE,
5555
)
56-
job_name = unique_name_from_base("auto-ml", max_length=32)
5756
s3_input_training = pipeline_session.upload_data(
5857
path=TRAINING_DATA, key_prefix=PREFIX + "/input"
5958
)
@@ -72,27 +71,56 @@ def test_automl_step(pipeline_session, role, pipeline_name):
7271
)
7372
inputs = [input_training, input_validation]
7473

75-
step_args = auto_ml.fit(inputs=inputs, job_name=job_name)
74+
step_args = auto_ml.fit(inputs=inputs)
7675

7776
automl_step = AutoMLStep(
7877
name="MyAutoMLStep",
7978
step_args=step_args,
8079
)
8180

8281
automl_model = automl_step.get_best_auto_ml_model(sagemaker_session=pipeline_session, role=role)
83-
8482
step_args_create_model = automl_model.create(
8583
instance_type="c4.4xlarge",
8684
)
87-
8885
automl_model_step = ModelStep(
8986
name="MyAutoMLModelStep",
9087
step_args=step_args_create_model,
9188
)
9289

90+
model_package_group_name = ParameterString(
91+
name="ModelPackageName", default_value="AutoMlModelPackageGroup"
92+
)
93+
model_approval_status = ParameterString(name="ModelApprovalStatus", default_value="Approved")
94+
model_metrics = ModelMetrics(
95+
model_statistics=MetricsSource(
96+
s3_uri=automl_step.properties.BestCandidateProperties.ModelInsightsJsonReportPath,
97+
content_type="application/json",
98+
),
99+
explainability=MetricsSource(
100+
s3_uri=automl_step.properties.BestCandidateProperties.ExplainabilityJsonReportPath,
101+
content_type="application/json",
102+
),
103+
)
104+
step_args_register_model = automl_model.register(
105+
content_types=["text/csv"],
106+
response_types=["text/csv"],
107+
inference_instances=["ml.m5.xlarge"],
108+
transform_instances=["ml.m5.xlarge"],
109+
model_package_group_name=model_package_group_name,
110+
approval_status=model_approval_status,
111+
model_metrics=model_metrics,
112+
)
113+
register_model_step = ModelStep(
114+
name="ModelRegistrationStep", step_args=step_args_register_model
115+
)
116+
93117
pipeline = Pipeline(
94118
name=pipeline_name,
95-
steps=[automl_step, automl_model_step],
119+
parameters=[
120+
model_approval_status,
121+
model_package_group_name,
122+
],
123+
steps=[automl_step, automl_model_step, register_model_step],
96124
sagemaker_session=pipeline_session,
97125
)
98126

@@ -114,9 +142,20 @@ def test_automl_step(pipeline_session, role, pipeline_name):
114142
assert step["Metadata"]["AutoMLJob"]["Arn"] is not None
115143

116144
assert has_automl_job
117-
assert len(execution_steps) == 2
145+
assert len(execution_steps) == 3
118146
finally:
119147
try:
148+
sagemaker_client = boto3.client("sagemaker")
149+
for model_package in sagemaker_client.list_model_packages(
150+
ModelPackageGroupName="AutoMlModelPackageGroup"
151+
)["ModelPackageSummaryList"]:
152+
sagemaker_client.delete_model_package(
153+
ModelPackageName=model_package["ModelPackageArn"]
154+
)
155+
sagemaker_client.delete_model_package_group(
156+
ModelPackageGroupName="AutoMlModelPackageGroup"
157+
)
158+
120159
pipeline.delete()
121160
except Exception:
122161
pass

tests/integ/test_auto_ml.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,6 @@ def test_auto_ml_fit_local_input(sagemaker_session):
7878
role=ROLE,
7979
target_attribute_name=TARGET_ATTRIBUTE_NAME,
8080
sagemaker_session=sagemaker_session,
81-
max_candidates=1,
8281
generate_candidate_definitions_only=True,
8382
)
8483

tests/unit/sagemaker/automl/test_auto_ml.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from mock import Mock, patch
1919
from sagemaker import AutoML, AutoMLJob, AutoMLInput, CandidateEstimator, PipelineModel
2020
from sagemaker.predictor import Predictor
21+
from sagemaker.workflow.functions import Join
2122

2223
MODEL_DATA = "s3://bucket/model.tar.gz"
2324
MODEL_IMAGE = "mi"
@@ -52,7 +53,7 @@
5253
MAX_RUNTIME_PER_TRAINING_JOB = 3600
5354
TOTAL_JOB_RUNTIME = 36000
5455
TARGET_OBJECTIVE = "0.01"
55-
JOB_OBJECTIVE = {"fake job objective"}
56+
JOB_OBJECTIVE = {"MetricName": "F1"}
5657
TAGS = [{"Name": "some-tag", "Value": "value-for-tag"}]
5758
CONTENT_TYPE = "x-application/vnd.amazon+parquet"
5859
S3_DATA_TYPE = "ManifestFile"
@@ -503,7 +504,46 @@ def test_auto_ml_default_fit(strftime, sagemaker_session):
503504
],
504505
"output_config": {"S3OutputPath": DEFAULT_OUTPUT_PATH},
505506
"auto_ml_job_config": {
506-
"CompletionCriteria": {"MaxCandidates": DEFAULT_MAX_CANDIDATES},
507+
"CompletionCriteria": {},
508+
"SecurityConfig": {
509+
"EnableInterContainerTrafficEncryption": ENCRYPT_INTER_CONTAINER_TRAFFIC
510+
},
511+
},
512+
"role": ROLE,
513+
"job_name": DEFAULT_JOB_NAME,
514+
"problem_type": None,
515+
"job_objective": None,
516+
"generate_candidate_definitions_only": GENERATE_CANDIDATE_DEFINITIONS_ONLY,
517+
"tags": None,
518+
}
519+
520+
521+
@patch("time.strftime", return_value=TIMESTAMP)
522+
def test_auto_ml_default_fit_with_pipeline_variable(strftime, sagemaker_session):
523+
auto_ml = AutoML(
524+
role=ROLE,
525+
target_attribute_name=TARGET_ATTRIBUTE_NAME,
526+
sagemaker_session=sagemaker_session,
527+
)
528+
inputs = Join(on="/", values=[DEFAULT_S3_INPUT_DATA, "ProcessingJobName"])
529+
auto_ml.fit(inputs=AutoMLInput(inputs=inputs, target_attribute_name=TARGET_ATTRIBUTE_NAME))
530+
sagemaker_session.auto_ml.assert_called_once()
531+
_, args = sagemaker_session.auto_ml.call_args
532+
assert args == {
533+
"input_config": [
534+
{
535+
"DataSource": {
536+
"S3DataSource": {
537+
"S3DataType": "S3Prefix",
538+
"S3Uri": Join(on="/", values=["s3://mybucket/data", "ProcessingJobName"]),
539+
}
540+
},
541+
"TargetAttributeName": TARGET_ATTRIBUTE_NAME,
542+
}
543+
],
544+
"output_config": {"S3OutputPath": DEFAULT_OUTPUT_PATH},
545+
"auto_ml_job_config": {
546+
"CompletionCriteria": {},
507547
"SecurityConfig": {
508548
"EnableInterContainerTrafficEncryption": ENCRYPT_INTER_CONTAINER_TRAFFIC
509549
},

tests/unit/sagemaker/workflow/test_automl_step.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -231,11 +231,11 @@ def test_single_automl_step_with_parameter(pipeline_session):
231231
step_args=step_args,
232232
)
233233

234-
assert automl_step.properties.bestCandidateProperties.modelInsightsJsonReportPath.expr == {
235-
"Get": "Steps.MyAutoMLStep.bestCandidateProperties.modelInsightsJsonReportPath"
234+
assert automl_step.properties.BestCandidateProperties.ModelInsightsJsonReportPath.expr == {
235+
"Get": "Steps.MyAutoMLStep.BestCandidateProperties.ModelInsightsJsonReportPath"
236236
}
237-
assert automl_step.properties.bestCandidateProperties.explainabilityJsonReportPath.expr == {
238-
"Get": "Steps.MyAutoMLStep.bestCandidateProperties.explainabilityJsonReportPath"
237+
assert automl_step.properties.BestCandidateProperties.ExplainabilityJsonReportPath.expr == {
238+
"Get": "Steps.MyAutoMLStep.BestCandidateProperties.ExplainabilityJsonReportPath"
239239
}
240240

241241
pipeline = Pipeline(

0 commit comments

Comments
 (0)