31
31
)
32
32
from sagemaker .transformer import Transformer
33
33
from sagemaker .workflow .properties import Properties
34
+ from sagemaker .workflow .parameters import ParameterString , ParameterInteger
34
35
from sagemaker .workflow .steps import (
35
36
ProcessingStep ,
36
37
Step ,
@@ -108,16 +109,27 @@ def test_custom_step():
108
109
109
110
110
111
def test_training_step (sagemaker_session ):
112
+ instance_type_parameter = ParameterString (name = "InstanceType" , default_value = "c4.4xlarge" )
113
+ instance_count_parameter = ParameterInteger (name = "InstanceCount" , default_value = 1 )
114
+ data_source_uri_parameter = ParameterString (
115
+ name = "DataSourceS3Uri" , default_value = f"s3://{ BUCKET } /train_manifest"
116
+ )
117
+ training_epochs_parameter = ParameterInteger (name = "TrainingEpochs" , default_value = 5 )
118
+ training_batch_size_parameter = ParameterInteger (name = "TrainingBatchSize" , default_value = 500 )
111
119
estimator = Estimator (
112
120
image_uri = IMAGE_URI ,
113
121
role = ROLE ,
114
- instance_count = 1 ,
115
- instance_type = "c4.4xlarge" ,
122
+ instance_count = instance_count_parameter ,
123
+ instance_type = instance_type_parameter ,
116
124
profiler_config = ProfilerConfig (system_monitor_interval_millis = 500 ),
125
+ hyperparameters = {
126
+ "batch-size" : training_batch_size_parameter ,
127
+ "epochs" : training_epochs_parameter ,
128
+ },
117
129
rules = [],
118
130
sagemaker_session = sagemaker_session ,
119
131
)
120
- inputs = TrainingInput (f"s3:// { BUCKET } /train_manifest" )
132
+ inputs = TrainingInput (s3_data = data_source_uri_parameter )
121
133
cache_config = CacheConfig (enable_caching = True , expire_after = "PT1H" )
122
134
step = TrainingStep (
123
135
name = "MyTrainingStep" , estimator = estimator , inputs = inputs , cache_config = cache_config
@@ -127,22 +139,26 @@ def test_training_step(sagemaker_session):
127
139
"Type" : "Training" ,
128
140
"Arguments" : {
129
141
"AlgorithmSpecification" : {"TrainingImage" : IMAGE_URI , "TrainingInputMode" : "File" },
142
+ "HyperParameters" : {
143
+ "batch-size" : training_batch_size_parameter ,
144
+ "epochs" : training_epochs_parameter ,
145
+ },
130
146
"InputDataConfig" : [
131
147
{
132
148
"ChannelName" : "training" ,
133
149
"DataSource" : {
134
150
"S3DataSource" : {
135
151
"S3DataDistributionType" : "FullyReplicated" ,
136
152
"S3DataType" : "S3Prefix" ,
137
- "S3Uri" : f"s3:// { BUCKET } /train_manifest" ,
153
+ "S3Uri" : data_source_uri_parameter ,
138
154
}
139
155
},
140
156
}
141
157
],
142
158
"OutputDataConfig" : {"S3OutputPath" : f"s3://{ BUCKET } /" },
143
159
"ResourceConfig" : {
144
- "InstanceCount" : 1 ,
145
- "InstanceType" : "c4.4xlarge" ,
160
+ "InstanceCount" : instance_count_parameter ,
161
+ "InstanceType" : instance_type_parameter ,
146
162
"VolumeSizeInGB" : 30 ,
147
163
},
148
164
"RoleArn" : ROLE ,
@@ -158,16 +174,21 @@ def test_training_step(sagemaker_session):
158
174
159
175
160
176
def test_processing_step (sagemaker_session ):
177
+ processing_input_data_uri_parameter = ParameterString (
178
+ name = "ProcessingInputDataUri" , default_value = f"s3://{ BUCKET } /processing_manifest"
179
+ )
180
+ instance_type_parameter = ParameterString (name = "InstanceType" , default_value = "ml.m4.4xlarge" )
181
+ instance_count_parameter = ParameterInteger (name = "InstanceCount" , default_value = 1 )
161
182
processor = Processor (
162
183
image_uri = IMAGE_URI ,
163
184
role = ROLE ,
164
- instance_count = 1 ,
165
- instance_type = "ml.m4.4xlarge" ,
185
+ instance_count = instance_count_parameter ,
186
+ instance_type = instance_type_parameter ,
166
187
sagemaker_session = sagemaker_session ,
167
188
)
168
189
inputs = [
169
190
ProcessingInput (
170
- source = f"s3:// { BUCKET } /processing_manifest" ,
191
+ source = processing_input_data_uri_parameter ,
171
192
destination = "processing_manifest" ,
172
193
)
173
194
]
@@ -194,14 +215,14 @@ def test_processing_step(sagemaker_session):
194
215
"S3DataDistributionType" : "FullyReplicated" ,
195
216
"S3DataType" : "S3Prefix" ,
196
217
"S3InputMode" : "File" ,
197
- "S3Uri" : "s3://my-bucket/processing_manifest" ,
218
+ "S3Uri" : processing_input_data_uri_parameter ,
198
219
},
199
220
}
200
221
],
201
222
"ProcessingResources" : {
202
223
"ClusterConfig" : {
203
- "InstanceCount" : 1 ,
204
- "InstanceType" : "ml.m4.4xlarge" ,
224
+ "InstanceCount" : instance_count_parameter ,
225
+ "InstanceType" : instance_type_parameter ,
205
226
"VolumeSizeInGB" : 30 ,
206
227
}
207
228
},
0 commit comments