Skip to content

Commit cdb633b

Browse files
staubhpPayton Staub
andauthored
fix: Allow hyperparameters in Tensorflow estimator to be parameterized (#2296)
* Allow hyperparameters in Tensorflow estimator to be parameterized for a pipeline * Fix linter errors Co-authored-by: Payton Staub <[email protected]>
1 parent a058347 commit cdb633b

File tree

2 files changed

+99
-2
lines changed

2 files changed

+99
-2
lines changed

src/sagemaker/estimator.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2443,7 +2443,13 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name="m
24432443
@staticmethod
24442444
def _json_encode_hyperparameters(hyperparameters):
24452445
"""Placeholder docstring"""
2446-
return {str(k): json.dumps(v) for (k, v) in hyperparameters.items()}
2446+
current_hyperparameters = hyperparameters
2447+
if current_hyperparameters is not None:
2448+
hyperparameters = {
2449+
str(k): (v if isinstance(v, (Parameter, Expression, Properties)) else json.dumps(v))
2450+
for (k, v) in current_hyperparameters.items()
2451+
}
2452+
return hyperparameters
24472453

24482454
@classmethod
24492455
def _update_init_params(cls, hp, tf_arguments):

tests/unit/sagemaker/workflow/test_steps.py

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import pytest
1717
import sagemaker
18+
import os
1819

1920
from mock import (
2021
Mock,
@@ -24,6 +25,7 @@
2425

2526
from sagemaker.debugger import ProfilerConfig
2627
from sagemaker.estimator import Estimator
28+
from sagemaker.tensorflow import TensorFlow
2729
from sagemaker.inputs import TrainingInput, TransformInput, CreateModelInput
2830
from sagemaker.model import Model
2931
from sagemaker.processing import (
@@ -45,6 +47,10 @@
4547
CreateModelStep,
4648
CacheConfig,
4749
)
50+
from tests.unit import DATA_DIR
51+
52+
SCRIPT_FILE = "dummy_script.py"
53+
SCRIPT_PATH = os.path.join(DATA_DIR, SCRIPT_FILE)
4854

4955
REGION = "us-west-2"
5056
BUCKET = "my-bucket"
@@ -112,7 +118,7 @@ def test_custom_step():
112118
assert step.to_request() == {"Name": "MyStep", "Type": "Training", "Arguments": dict()}
113119

114120

115-
def test_training_step(sagemaker_session):
121+
def test_training_step_base_estimator(sagemaker_session):
116122
instance_type_parameter = ParameterString(name="InstanceType", default_value="c4.4xlarge")
117123
instance_count_parameter = ParameterInteger(name="InstanceCount", default_value=1)
118124
data_source_uri_parameter = ParameterString(
@@ -177,6 +183,91 @@ def test_training_step(sagemaker_session):
177183
assert step.properties.TrainingJobName.expr == {"Get": "Steps.MyTrainingStep.TrainingJobName"}
178184

179185

186+
def test_training_step_tensorflow(sagemaker_session):
187+
instance_type_parameter = ParameterString(name="InstanceType", default_value="ml.p3.16xlarge")
188+
instance_count_parameter = ParameterInteger(name="InstanceCount", default_value=1)
189+
data_source_uri_parameter = ParameterString(
190+
name="DataSourceS3Uri", default_value=f"s3://{BUCKET}/train_manifest"
191+
)
192+
training_epochs_parameter = ParameterInteger(name="TrainingEpochs", default_value=5)
193+
training_batch_size_parameter = ParameterInteger(name="TrainingBatchSize", default_value=500)
194+
estimator = TensorFlow(
195+
entry_point=os.path.join(DATA_DIR, SCRIPT_FILE),
196+
role=ROLE,
197+
model_dir=False,
198+
image_uri=IMAGE_URI,
199+
source_dir="s3://mybucket/source",
200+
framework_version="2.4.1",
201+
py_version="py37",
202+
instance_count=instance_count_parameter,
203+
instance_type=instance_type_parameter,
204+
sagemaker_session=sagemaker_session,
205+
# subnets=subnets,
206+
hyperparameters={
207+
"batch-size": training_batch_size_parameter,
208+
"epochs": training_epochs_parameter,
209+
},
210+
# security_group_ids=security_group_ids,
211+
debugger_hook_config=False,
212+
# Training using SMDataParallel Distributed Training Framework
213+
distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
214+
)
215+
216+
inputs = TrainingInput(s3_data=data_source_uri_parameter)
217+
cache_config = CacheConfig(enable_caching=True, expire_after="PT1H")
218+
step = TrainingStep(
219+
name="MyTrainingStep", estimator=estimator, inputs=inputs, cache_config=cache_config
220+
)
221+
step_request = step.to_request()
222+
step_request["Arguments"]["HyperParameters"].pop("sagemaker_job_name", None)
223+
step_request["Arguments"]["HyperParameters"].pop("sagemaker_program", None)
224+
step_request["Arguments"].pop("ProfilerRuleConfigurations", None)
225+
assert step_request == {
226+
"Name": "MyTrainingStep",
227+
"Type": "Training",
228+
"Arguments": {
229+
"AlgorithmSpecification": {
230+
"TrainingInputMode": "File",
231+
"TrainingImage": "fakeimage",
232+
"EnableSageMakerMetricsTimeSeries": True,
233+
},
234+
"OutputDataConfig": {"S3OutputPath": "s3://my-bucket/"},
235+
"StoppingCondition": {"MaxRuntimeInSeconds": 86400},
236+
"ResourceConfig": {
237+
"InstanceCount": instance_count_parameter,
238+
"InstanceType": instance_type_parameter,
239+
"VolumeSizeInGB": 30,
240+
},
241+
"RoleArn": "DummyRole",
242+
"InputDataConfig": [
243+
{
244+
"DataSource": {
245+
"S3DataSource": {
246+
"S3DataType": "S3Prefix",
247+
"S3Uri": data_source_uri_parameter,
248+
"S3DataDistributionType": "FullyReplicated",
249+
}
250+
},
251+
"ChannelName": "training",
252+
}
253+
],
254+
"HyperParameters": {
255+
"batch-size": training_batch_size_parameter,
256+
"epochs": training_epochs_parameter,
257+
"sagemaker_submit_directory": '"s3://mybucket/source"',
258+
"sagemaker_container_log_level": "20",
259+
"sagemaker_region": '"us-west-2"',
260+
"sagemaker_distributed_dataparallel_enabled": "true",
261+
"sagemaker_instance_type": instance_type_parameter,
262+
"sagemaker_distributed_dataparallel_custom_mpi_options": '""',
263+
},
264+
"ProfilerConfig": {"S3OutputPath": "s3://my-bucket/"},
265+
},
266+
"CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"},
267+
}
268+
assert step.properties.TrainingJobName.expr == {"Get": "Steps.MyTrainingStep.TrainingJobName"}
269+
270+
180271
def test_processing_step(sagemaker_session):
181272
processing_input_data_uri_parameter = ParameterString(
182273
name="ProcessingInputDataUri", default_value=f"s3://{BUCKET}/processing_manifest"

0 commit comments

Comments
 (0)