
Commit 1b6f4ac

Author: Namrata Madan
fix: use workflow parameters in training hyperparameters (#2114) (#2115)
1 parent 334f942 commit 1b6f4ac

File tree

4 files changed: +48 −19 lines

  src/sagemaker/estimator.py
  src/sagemaker/processing.py
  src/sagemaker/workflow/pipeline.py
  tests/unit/sagemaker/workflow/test_steps.py

src/sagemaker/estimator.py

Lines changed: 7 additions & 1 deletion

@@ -52,6 +52,9 @@
     _region_supports_profiler,
     get_mp_parameters,
 )
+from sagemaker.workflow.properties import Properties
+from sagemaker.workflow.parameters import Parameter
+from sagemaker.workflow.entities import Expression
 from sagemaker.inputs import TrainingInput
 from sagemaker.job import _Job
 from sagemaker.local import LocalSession
@@ -1456,7 +1459,10 @@ def _get_train_args(cls, estimator, inputs, experiment_config):

         current_hyperparameters = estimator.hyperparameters()
         if current_hyperparameters is not None:
-            hyperparameters = {str(k): str(v) for (k, v) in current_hyperparameters.items()}
+            hyperparameters = {
+                str(k): (v if isinstance(v, (Parameter, Expression, Properties)) else str(v))
+                for (k, v) in current_hyperparameters.items()
+            }

         train_args = config.copy()
         train_args["input_mode"] = estimator.input_mode

src/sagemaker/processing.py

Lines changed: 6 additions & 4 deletions

@@ -31,6 +31,7 @@
 from sagemaker.session import Session
 from sagemaker.network import NetworkConfig  # noqa: F401 # pylint: disable=unused-import
 from sagemaker.workflow.properties import Properties
+from sagemaker.workflow.parameters import Parameter
 from sagemaker.workflow.entities import Expression
 from sagemaker.dataset_definition.inputs import S3Input, DatasetDefinition
 from sagemaker.apiutils._base_types import ApiObject
@@ -291,7 +292,9 @@ def _normalize_inputs(self, inputs=None, kms_key=None):
             if isinstance(file_input.source, Properties) or file_input.dataset_definition:
                 normalized_inputs.append(file_input)
                 continue
-
+            if isinstance(file_input.s3_input.s3_uri, (Parameter, Expression, Properties)):
+                normalized_inputs.append(file_input)
+                continue
             # If the source is a local path, upload it to S3
             # and save the S3 uri in the ProcessingInput source.
             parse_result = urlparse(file_input.s3_input.s3_uri)
@@ -339,8 +342,7 @@ def _normalize_outputs(self, outputs=None):
             # Generate a name for the ProcessingOutput if it doesn't have one.
             if output.output_name is None:
                 output.output_name = "output-{}".format(count)
-            # if the output's destination is a workflow expression, do no normalization
-            if isinstance(output.destination, Expression):
+            if isinstance(output.destination, (Parameter, Expression, Properties)):
                 normalized_outputs.append(output)
                 continue
             # If the output's destination is not an s3_uri, create one.
@@ -1070,7 +1072,7 @@ def _create_s3_input(self):
             self.s3_data_type = self.s3_input.s3_data_type
             self.s3_input_mode = self.s3_input.s3_input_mode
             self.s3_data_distribution_type = self.s3_input.s3_data_distribution_type
-        elif self.source and self.destination:
+        elif self.source is not None and self.destination is not None:
             self.s3_input = S3Input(
                 s3_uri=self.source,
                 local_path=self.destination,
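On the processing side, the same pass-through now applies to inputs and outputs: a ProcessingInput whose S3 URI is a workflow Parameter, Expression, or Properties object skips the upload/normalization path, and parameterized ProcessingOutput destinations get the same treatment. A small sketch with illustrative bucket names and container paths (not taken from this commit):

from sagemaker.processing import ProcessingInput, ProcessingOutput
from sagemaker.workflow.parameters import ParameterString

input_uri = ParameterString(
    name="ProcessingInputDataUri", default_value="s3://example-bucket/processing_manifest"
)

# The source is a Parameter, so _normalize_inputs leaves it untouched
# instead of trying to parse or upload it.
processing_input = ProcessingInput(
    source=input_uri,
    destination="/opt/ml/processing/input",  # illustrative container path
)

# Likewise, a parameterized destination is passed through by _normalize_outputs.
output_uri = ParameterString(
    name="ProcessingOutputUri", default_value="s3://example-bucket/output"
)
processing_output = ProcessingOutput(
    source="/opt/ml/processing/output",  # illustrative container path
    destination=output_uri,
)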

src/sagemaker/workflow/pipeline.py

Lines changed: 2 additions & 2 deletions

@@ -75,7 +75,7 @@ def to_request(self) -> RequestType:
     def create(
         self,
         role_arn: str,
-        description: str = None,
+        pipeline_description: str = None,
         experiment_name: str = None,
         tags: List[Dict[str, str]] = None,
     ) -> Dict[str, Any]:
@@ -93,7 +93,7 @@ def create(
         """
         tags = _append_project_tags(tags)

-        kwargs = self._create_args(role_arn, description)
+        kwargs = self._create_args(role_arn, pipeline_description)
         update_args(
             kwargs,
             ExperimentName=experiment_name,
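After the keyword rename, callers pass the pipeline description to Pipeline.create via pipeline_description. A hedged sketch that reuses the parameterized training step from the earlier sketch; the role ARN is a placeholder and sagemaker_session is assumed to be a configured SageMaker session:

from sagemaker.workflow.pipeline import Pipeline

# `epochs`, `data_uri`, and `training_step` are the objects defined in the
# sketch after the estimator.py diff above.
pipeline = Pipeline(
    name="MyTrainingPipeline",
    parameters=[epochs, data_uri],
    steps=[training_step],
    sagemaker_session=sagemaker_session,
)

pipeline.create(
    role_arn="arn:aws:iam::123456789012:role/SageMakerExecutionRole",  # placeholder
    pipeline_description="Training pipeline with parameterized hyperparameters",
)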

tests/unit/sagemaker/workflow/test_steps.py

Lines changed: 33 additions & 12 deletions

@@ -31,6 +31,7 @@
 )
 from sagemaker.transformer import Transformer
 from sagemaker.workflow.properties import Properties
+from sagemaker.workflow.parameters import ParameterString, ParameterInteger
 from sagemaker.workflow.steps import (
     ProcessingStep,
     Step,
@@ -108,16 +109,27 @@ def test_custom_step():


 def test_training_step(sagemaker_session):
+    instance_type_parameter = ParameterString(name="InstanceType", default_value="c4.4xlarge")
+    instance_count_parameter = ParameterInteger(name="InstanceCount", default_value=1)
+    data_source_uri_parameter = ParameterString(
+        name="DataSourceS3Uri", default_value=f"s3://{BUCKET}/train_manifest"
+    )
+    training_epochs_parameter = ParameterInteger(name="TrainingEpochs", default_value=5)
+    training_batch_size_parameter = ParameterInteger(name="TrainingBatchSize", default_value=500)
     estimator = Estimator(
         image_uri=IMAGE_URI,
         role=ROLE,
-        instance_count=1,
-        instance_type="c4.4xlarge",
+        instance_count=instance_count_parameter,
+        instance_type=instance_type_parameter,
         profiler_config=ProfilerConfig(system_monitor_interval_millis=500),
+        hyperparameters={
+            "batch-size": training_batch_size_parameter,
+            "epochs": training_epochs_parameter,
+        },
         rules=[],
         sagemaker_session=sagemaker_session,
     )
-    inputs = TrainingInput(f"s3://{BUCKET}/train_manifest")
+    inputs = TrainingInput(s3_data=data_source_uri_parameter)
     cache_config = CacheConfig(enable_caching=True, expire_after="PT1H")
     step = TrainingStep(
         name="MyTrainingStep", estimator=estimator, inputs=inputs, cache_config=cache_config
@@ -127,22 +139,26 @@ def test_training_step(sagemaker_session):
         "Type": "Training",
         "Arguments": {
             "AlgorithmSpecification": {"TrainingImage": IMAGE_URI, "TrainingInputMode": "File"},
+            "HyperParameters": {
+                "batch-size": training_batch_size_parameter,
+                "epochs": training_epochs_parameter,
+            },
             "InputDataConfig": [
                 {
                     "ChannelName": "training",
                     "DataSource": {
                         "S3DataSource": {
                             "S3DataDistributionType": "FullyReplicated",
                             "S3DataType": "S3Prefix",
-                            "S3Uri": f"s3://{BUCKET}/train_manifest",
+                            "S3Uri": data_source_uri_parameter,
                         }
                     },
                 }
             ],
             "OutputDataConfig": {"S3OutputPath": f"s3://{BUCKET}/"},
             "ResourceConfig": {
-                "InstanceCount": 1,
-                "InstanceType": "c4.4xlarge",
+                "InstanceCount": instance_count_parameter,
+                "InstanceType": instance_type_parameter,
                 "VolumeSizeInGB": 30,
             },
             "RoleArn": ROLE,
@@ -158,16 +174,21 @@ def test_training_step(sagemaker_session):


 def test_processing_step(sagemaker_session):
+    processing_input_data_uri_parameter = ParameterString(
+        name="ProcessingInputDataUri", default_value=f"s3://{BUCKET}/processing_manifest"
+    )
+    instance_type_parameter = ParameterString(name="InstanceType", default_value="ml.m4.4xlarge")
+    instance_count_parameter = ParameterInteger(name="InstanceCount", default_value=1)
     processor = Processor(
         image_uri=IMAGE_URI,
         role=ROLE,
-        instance_count=1,
-        instance_type="ml.m4.4xlarge",
+        instance_count=instance_count_parameter,
+        instance_type=instance_type_parameter,
         sagemaker_session=sagemaker_session,
     )
     inputs = [
         ProcessingInput(
-            source=f"s3://{BUCKET}/processing_manifest",
+            source=processing_input_data_uri_parameter,
             destination="processing_manifest",
         )
     ]
@@ -194,14 +215,14 @@ def test_processing_step(sagemaker_session):
                         "S3DataDistributionType": "FullyReplicated",
                         "S3DataType": "S3Prefix",
                         "S3InputMode": "File",
-                        "S3Uri": "s3://my-bucket/processing_manifest",
+                        "S3Uri": processing_input_data_uri_parameter,
                     },
                 }
             ],
            "ProcessingResources": {
                "ClusterConfig": {
-                    "InstanceCount": 1,
-                    "InstanceType": "ml.m4.4xlarge",
+                    "InstanceCount": instance_count_parameter,
+                    "InstanceType": instance_type_parameter,
                    "VolumeSizeInGB": 30,
                }
            },
