Skip to content

Commit 9091973

Browse files
author
Payton Staub
committed
Allow hyperparameters in Tensorflow estimator to be parameterized for a pipeline
1 parent a058347 commit 9091973

File tree

2 files changed

+108
-2
lines changed

2 files changed

+108
-2
lines changed

src/sagemaker/estimator.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2443,7 +2443,13 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name="m
24432443
@staticmethod
def _json_encode_hyperparameters(hyperparameters):
    """JSON-encode hyperparameter values, passing pipeline entities through.

    Args:
        hyperparameters (dict or None): mapping of hyperparameter name to
            value. Values that are SageMaker Pipelines entities
            (``Parameter``, ``Expression``, ``Properties``) have no concrete
            value at pipeline-definition time, so they are returned
            unchanged for later resolution; every other value is serialized
            with ``json.dumps``.

    Returns:
        dict or None: a new dict with string keys and encoded values, or
        ``None`` when ``hyperparameters`` is ``None``.
    """
    if hyperparameters is None:
        return None
    return {
        # Pipeline variables must not be JSON-serialized here; they are
        # placeholders resolved at pipeline execution time.
        str(name): (
            value
            if isinstance(value, (Parameter, Expression, Properties))
            else json.dumps(value)
        )
        for name, value in hyperparameters.items()
    }
24472453

24482454
@classmethod
24492455
def _update_init_params(cls, hp, tf_arguments):

tests/unit/sagemaker/workflow/test_steps.py

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import pytest
1717
import sagemaker
18+
import os
1819

1920
from mock import (
2021
Mock,
@@ -24,6 +25,7 @@
2425

2526
from sagemaker.debugger import ProfilerConfig
2627
from sagemaker.estimator import Estimator
28+
from sagemaker.tensorflow import TensorFlow
2729
from sagemaker.inputs import TrainingInput, TransformInput, CreateModelInput
2830
from sagemaker.model import Model
2931
from sagemaker.processing import (
@@ -45,6 +47,10 @@
4547
CreateModelStep,
4648
CacheConfig,
4749
)
50+
from tests.unit import DATA_DIR
51+
52+
SCRIPT_FILE = "dummy_script.py"
53+
SCRIPT_PATH = os.path.join(DATA_DIR, SCRIPT_FILE)
4854

4955
REGION = "us-west-2"
5056
BUCKET = "my-bucket"
@@ -112,7 +118,7 @@ def test_custom_step():
112118
assert step.to_request() == {"Name": "MyStep", "Type": "Training", "Arguments": dict()}
113119

114120

115-
def test_training_step(sagemaker_session):
121+
def test_training_step_base_estimator(sagemaker_session):
116122
instance_type_parameter = ParameterString(name="InstanceType", default_value="c4.4xlarge")
117123
instance_count_parameter = ParameterInteger(name="InstanceCount", default_value=1)
118124
data_source_uri_parameter = ParameterString(
@@ -177,6 +183,100 @@ def test_training_step(sagemaker_session):
177183
assert step.properties.TrainingJobName.expr == {"Get": "Steps.MyTrainingStep.TrainingJobName"}
178184

179185

186+
def test_training_step_tensorflow(sagemaker_session):
    """A TrainingStep built from a TensorFlow estimator keeps pipeline
    parameters un-serialized in the generated training-job request.

    Regression test for parameterized hyperparameters: ``Parameter`` values
    must pass through ``_json_encode_hyperparameters`` untouched.
    """
    instance_type_parameter = ParameterString(name="InstanceType", default_value="ml.p3.16xlarge")
    instance_count_parameter = ParameterInteger(name="InstanceCount", default_value=1)
    data_source_uri_parameter = ParameterString(
        name="DataSourceS3Uri", default_value=f"s3://{BUCKET}/train_manifest"
    )
    training_epochs_parameter = ParameterInteger(name="TrainingEpochs", default_value=5)
    training_batch_size_parameter = ParameterInteger(name="TrainingBatchSize", default_value=500)
    estimator = TensorFlow(
        entry_point=os.path.join(DATA_DIR, SCRIPT_FILE),
        role=ROLE,
        model_dir=False,
        image_uri=IMAGE_URI,
        source_dir="s3://mybucket/source",
        framework_version="2.4.1",
        py_version="py37",
        instance_count=instance_count_parameter,
        instance_type=instance_type_parameter,
        sagemaker_session=sagemaker_session,
        hyperparameters={
            "batch-size": training_batch_size_parameter,
            "epochs": training_epochs_parameter,
        },
        debugger_hook_config=False,
        # Training using the SMDataParallel distributed training framework.
        distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
    )

    inputs = TrainingInput(s3_data=data_source_uri_parameter)
    cache_config = CacheConfig(enable_caching=True, expire_after="PT1H")
    step = TrainingStep(
        name="MyTrainingStep", estimator=estimator, inputs=inputs, cache_config=cache_config
    )
    step_request = step.to_request()
    # These entries vary per run / SDK version; drop them before comparing.
    step_request["Arguments"]["HyperParameters"].pop("sagemaker_job_name", None)
    step_request["Arguments"].pop("ProfilerRuleConfigurations", None)
    assert step_request == {
        "Name": "MyTrainingStep",
        "Type": "Training",
        "Arguments": {
            "AlgorithmSpecification": {
                "TrainingInputMode": "File",
                "TrainingImage": "fakeimage",
                "EnableSageMakerMetricsTimeSeries": True,
            },
            "OutputDataConfig": {"S3OutputPath": "s3://my-bucket/"},
            "StoppingCondition": {"MaxRuntimeInSeconds": 86400},
            "ResourceConfig": {
                # Pipeline parameters must survive un-serialized.
                "InstanceCount": instance_count_parameter,
                "InstanceType": instance_type_parameter,
                "VolumeSizeInGB": 30,
            },
            "RoleArn": "DummyRole",
            "InputDataConfig": [
                {
                    "DataSource": {
                        "S3DataSource": {
                            "S3DataType": "S3Prefix",
                            "S3Uri": data_source_uri_parameter,
                            "S3DataDistributionType": "FullyReplicated",
                        }
                    },
                    "ChannelName": "training",
                }
            ],
            "HyperParameters": {
                # Parameterized hyperparameters must pass through as-is,
                # not be JSON-encoded.
                "batch-size": training_batch_size_parameter,
                "epochs": training_epochs_parameter,
                "sagemaker_submit_directory": '"s3://mybucket/source"',
                # Use the repo-relative script path, not a hard-coded
                # developer-machine absolute path, so the test passes on
                # any checkout.
                "sagemaker_program": f'"{SCRIPT_PATH}"',
                "sagemaker_container_log_level": "20",
                "sagemaker_region": '"us-west-2"',
                "sagemaker_distributed_dataparallel_enabled": "true",
                "sagemaker_instance_type": instance_type_parameter,
                "sagemaker_distributed_dataparallel_custom_mpi_options": '""',
            },
            "ProfilerConfig": {"S3OutputPath": "s3://my-bucket/"},
        },
        "CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"},
    }
    assert step.properties.TrainingJobName.expr == {"Get": "Steps.MyTrainingStep.TrainingJobName"}
180280
def test_processing_step(sagemaker_session):
181281
processing_input_data_uri_parameter = ParameterString(
182282
name="ProcessingInputDataUri", default_value=f"s3://{BUCKET}/processing_manifest"

0 commit comments

Comments
 (0)