Skip to content

Commit 78d0b3a

Browse files
Authored by: Namrata Madan (nmadan), ahsan-z-khan, and icywang86rui
Committed by: Chia-Eng
fix: use workflow parameters in training hyperparameters (aws#2114) (aws#2115) (aws#2227)
Co-authored-by: Namrata Madan <[email protected]> Co-authored-by: Ahsan Khan <[email protected]> Co-authored-by: icywang86rui <[email protected]>
1 parent d62fe38 commit 78d0b3a

File tree

4 files changed

+47
-18
lines changed

4 files changed

+47
-18
lines changed

src/sagemaker/estimator.py

Lines changed: 7 additions & 1 deletion
@@ -52,6 +52,9 @@
     _region_supports_profiler,
     get_mp_parameters,
 )
+from sagemaker.workflow.properties import Properties
+from sagemaker.workflow.parameters import Parameter
+from sagemaker.workflow.entities import Expression
 from sagemaker.inputs import TrainingInput
 from sagemaker.job import _Job
 from sagemaker.local import LocalSession
@@ -1460,7 +1463,10 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
 
         current_hyperparameters = estimator.hyperparameters()
         if current_hyperparameters is not None:
-            hyperparameters = {str(k): str(v) for (k, v) in current_hyperparameters.items()}
+            hyperparameters = {
+                str(k): (v if isinstance(v, (Parameter, Expression, Properties)) else str(v))
+                for (k, v) in current_hyperparameters.items()
+            }
 
         train_args = config.copy()
         train_args["input_mode"] = estimator.input_mode

src/sagemaker/processing.py

Lines changed: 6 additions & 4 deletions
@@ -32,6 +32,7 @@
 from sagemaker.session import Session
 from sagemaker.network import NetworkConfig  # noqa: F401 # pylint: disable=unused-import
 from sagemaker.workflow.properties import Properties
+from sagemaker.workflow.parameters import Parameter
 from sagemaker.workflow.entities import Expression
 from sagemaker.dataset_definition.inputs import S3Input, DatasetDefinition
 from sagemaker.apiutils._base_types import ApiObject
@@ -292,7 +293,9 @@ def _normalize_inputs(self, inputs=None, kms_key=None):
             if isinstance(file_input.source, Properties) or file_input.dataset_definition:
                 normalized_inputs.append(file_input)
                 continue
-
+            if isinstance(file_input.s3_input.s3_uri, (Parameter, Expression, Properties)):
+                normalized_inputs.append(file_input)
+                continue
             # If the source is a local path, upload it to S3
             # and save the S3 uri in the ProcessingInput source.
             parse_result = urlparse(file_input.s3_input.s3_uri)
@@ -340,8 +343,7 @@ def _normalize_outputs(self, outputs=None):
             # Generate a name for the ProcessingOutput if it doesn't have one.
             if output.output_name is None:
                 output.output_name = "output-{}".format(count)
-            # if the output's destination is a workflow expression, do no normalization
-            if isinstance(output.destination, Expression):
+            if isinstance(output.destination, (Parameter, Expression, Properties)):
                 normalized_outputs.append(output)
                 continue
             # If the output's destination is not an s3_uri, create one.
@@ -1099,7 +1101,7 @@ def _create_s3_input(self):
             self.s3_data_type = self.s3_input.s3_data_type
             self.s3_input_mode = self.s3_input.s3_input_mode
             self.s3_data_distribution_type = self.s3_input.s3_data_distribution_type
-        elif self.source and self.destination:
+        elif self.source is not None and self.destination is not None:
             self.s3_input = S3Input(
                 s3_uri=self.source,
                 local_path=self.destination,

src/sagemaker/workflow/pipeline.py

Lines changed: 1 addition & 1 deletion
@@ -83,7 +83,7 @@ def create(
 
         Args:
             role_arn (str): The role arn that is assumed by the pipeline to create step artifacts.
-            pipeline_description (str): A description of the pipeline.
+            description (str): A description of the pipeline.
             experiment_name (str): The name of the experiment.
             tags (List[Dict[str, str]]): A list of {"Key": "string", "Value": "string"} dicts as
                 tags.

tests/unit/sagemaker/workflow/test_steps.py

Lines changed: 33 additions & 12 deletions
@@ -35,6 +35,7 @@
 from sagemaker.network import NetworkConfig
 from sagemaker.transformer import Transformer
 from sagemaker.workflow.properties import Properties
+from sagemaker.workflow.parameters import ParameterString, ParameterInteger
 from sagemaker.workflow.steps import (
     ProcessingStep,
     Step,
@@ -112,16 +113,27 @@ def test_custom_step():
 
 
 def test_training_step(sagemaker_session):
+    instance_type_parameter = ParameterString(name="InstanceType", default_value="c4.4xlarge")
+    instance_count_parameter = ParameterInteger(name="InstanceCount", default_value=1)
+    data_source_uri_parameter = ParameterString(
+        name="DataSourceS3Uri", default_value=f"s3://{BUCKET}/train_manifest"
+    )
+    training_epochs_parameter = ParameterInteger(name="TrainingEpochs", default_value=5)
+    training_batch_size_parameter = ParameterInteger(name="TrainingBatchSize", default_value=500)
     estimator = Estimator(
         image_uri=IMAGE_URI,
         role=ROLE,
-        instance_count=1,
-        instance_type="c4.4xlarge",
+        instance_count=instance_count_parameter,
+        instance_type=instance_type_parameter,
         profiler_config=ProfilerConfig(system_monitor_interval_millis=500),
+        hyperparameters={
+            "batch-size": training_batch_size_parameter,
+            "epochs": training_epochs_parameter,
+        },
         rules=[],
         sagemaker_session=sagemaker_session,
     )
-    inputs = TrainingInput(f"s3://{BUCKET}/train_manifest")
+    inputs = TrainingInput(s3_data=data_source_uri_parameter)
     cache_config = CacheConfig(enable_caching=True, expire_after="PT1H")
     step = TrainingStep(
         name="MyTrainingStep", estimator=estimator, inputs=inputs, cache_config=cache_config
@@ -131,22 +143,26 @@ def test_training_step(sagemaker_session):
         "Type": "Training",
         "Arguments": {
             "AlgorithmSpecification": {"TrainingImage": IMAGE_URI, "TrainingInputMode": "File"},
+            "HyperParameters": {
+                "batch-size": training_batch_size_parameter,
+                "epochs": training_epochs_parameter,
+            },
             "InputDataConfig": [
                 {
                     "ChannelName": "training",
                     "DataSource": {
                         "S3DataSource": {
                             "S3DataDistributionType": "FullyReplicated",
                             "S3DataType": "S3Prefix",
-                            "S3Uri": f"s3://{BUCKET}/train_manifest",
+                            "S3Uri": data_source_uri_parameter,
                         }
                     },
                 }
             ],
             "OutputDataConfig": {"S3OutputPath": f"s3://{BUCKET}/"},
             "ResourceConfig": {
-                "InstanceCount": 1,
-                "InstanceType": "c4.4xlarge",
+                "InstanceCount": instance_count_parameter,
+                "InstanceType": instance_type_parameter,
                 "VolumeSizeInGB": 30,
             },
             "RoleArn": ROLE,
@@ -162,16 +178,21 @@ def test_training_step(sagemaker_session):
 
 
 def test_processing_step(sagemaker_session):
+    processing_input_data_uri_parameter = ParameterString(
+        name="ProcessingInputDataUri", default_value=f"s3://{BUCKET}/processing_manifest"
+    )
+    instance_type_parameter = ParameterString(name="InstanceType", default_value="ml.m4.4xlarge")
+    instance_count_parameter = ParameterInteger(name="InstanceCount", default_value=1)
     processor = Processor(
         image_uri=IMAGE_URI,
         role=ROLE,
-        instance_count=1,
-        instance_type="ml.m4.4xlarge",
+        instance_count=instance_count_parameter,
+        instance_type=instance_type_parameter,
         sagemaker_session=sagemaker_session,
     )
     inputs = [
         ProcessingInput(
-            source=f"s3://{BUCKET}/processing_manifest",
+            source=processing_input_data_uri_parameter,
             destination="processing_manifest",
         )
     ]
@@ -198,14 +219,14 @@ def test_processing_step(sagemaker_session):
                         "S3DataDistributionType": "FullyReplicated",
                         "S3DataType": "S3Prefix",
                         "S3InputMode": "File",
-                        "S3Uri": "s3://my-bucket/processing_manifest",
+                        "S3Uri": processing_input_data_uri_parameter,
                     },
                 }
             ],
             "ProcessingResources": {
                 "ClusterConfig": {
-                    "InstanceCount": 1,
-                    "InstanceType": "ml.m4.4xlarge",
+                    "InstanceCount": instance_count_parameter,
+                    "InstanceType": instance_type_parameter,
                     "VolumeSizeInGB": 30,
                 }
             },

0 commit comments

Comments (0)