35
35
from sagemaker .network import NetworkConfig
36
36
from sagemaker .transformer import Transformer
37
37
from sagemaker .workflow .properties import Properties
38
+ from sagemaker .workflow .parameters import ParameterString , ParameterInteger
38
39
from sagemaker .workflow .steps import (
39
40
ProcessingStep ,
40
41
Step ,
@@ -112,16 +113,27 @@ def test_custom_step():
112
113
113
114
114
115
def test_training_step (sagemaker_session ):
116
+ instance_type_parameter = ParameterString (name = "InstanceType" , default_value = "c4.4xlarge" )
117
+ instance_count_parameter = ParameterInteger (name = "InstanceCount" , default_value = 1 )
118
+ data_source_uri_parameter = ParameterString (
119
+ name = "DataSourceS3Uri" , default_value = f"s3://{ BUCKET } /train_manifest"
120
+ )
121
+ training_epochs_parameter = ParameterInteger (name = "TrainingEpochs" , default_value = 5 )
122
+ training_batch_size_parameter = ParameterInteger (name = "TrainingBatchSize" , default_value = 500 )
115
123
estimator = Estimator (
116
124
image_uri = IMAGE_URI ,
117
125
role = ROLE ,
118
- instance_count = 1 ,
119
- instance_type = "c4.4xlarge" ,
126
+ instance_count = instance_count_parameter ,
127
+ instance_type = instance_type_parameter ,
120
128
profiler_config = ProfilerConfig (system_monitor_interval_millis = 500 ),
129
+ hyperparameters = {
130
+ "batch-size" : training_batch_size_parameter ,
131
+ "epochs" : training_epochs_parameter ,
132
+ },
121
133
rules = [],
122
134
sagemaker_session = sagemaker_session ,
123
135
)
124
- inputs = TrainingInput (f"s3:// { BUCKET } /train_manifest" )
136
+ inputs = TrainingInput (s3_data = data_source_uri_parameter )
125
137
cache_config = CacheConfig (enable_caching = True , expire_after = "PT1H" )
126
138
step = TrainingStep (
127
139
name = "MyTrainingStep" , estimator = estimator , inputs = inputs , cache_config = cache_config
@@ -131,22 +143,26 @@ def test_training_step(sagemaker_session):
131
143
"Type" : "Training" ,
132
144
"Arguments" : {
133
145
"AlgorithmSpecification" : {"TrainingImage" : IMAGE_URI , "TrainingInputMode" : "File" },
146
+ "HyperParameters" : {
147
+ "batch-size" : training_batch_size_parameter ,
148
+ "epochs" : training_epochs_parameter ,
149
+ },
134
150
"InputDataConfig" : [
135
151
{
136
152
"ChannelName" : "training" ,
137
153
"DataSource" : {
138
154
"S3DataSource" : {
139
155
"S3DataDistributionType" : "FullyReplicated" ,
140
156
"S3DataType" : "S3Prefix" ,
141
- "S3Uri" : f"s3:// { BUCKET } /train_manifest" ,
157
+ "S3Uri" : data_source_uri_parameter ,
142
158
}
143
159
},
144
160
}
145
161
],
146
162
"OutputDataConfig" : {"S3OutputPath" : f"s3://{ BUCKET } /" },
147
163
"ResourceConfig" : {
148
- "InstanceCount" : 1 ,
149
- "InstanceType" : "c4.4xlarge" ,
164
+ "InstanceCount" : instance_count_parameter ,
165
+ "InstanceType" : instance_type_parameter ,
150
166
"VolumeSizeInGB" : 30 ,
151
167
},
152
168
"RoleArn" : ROLE ,
@@ -162,16 +178,21 @@ def test_training_step(sagemaker_session):
162
178
163
179
164
180
def test_processing_step (sagemaker_session ):
181
+ processing_input_data_uri_parameter = ParameterString (
182
+ name = "ProcessingInputDataUri" , default_value = f"s3://{ BUCKET } /processing_manifest"
183
+ )
184
+ instance_type_parameter = ParameterString (name = "InstanceType" , default_value = "ml.m4.4xlarge" )
185
+ instance_count_parameter = ParameterInteger (name = "InstanceCount" , default_value = 1 )
165
186
processor = Processor (
166
187
image_uri = IMAGE_URI ,
167
188
role = ROLE ,
168
- instance_count = 1 ,
169
- instance_type = "ml.m4.4xlarge" ,
189
+ instance_count = instance_count_parameter ,
190
+ instance_type = instance_type_parameter ,
170
191
sagemaker_session = sagemaker_session ,
171
192
)
172
193
inputs = [
173
194
ProcessingInput (
174
- source = f"s3:// { BUCKET } /processing_manifest" ,
195
+ source = processing_input_data_uri_parameter ,
175
196
destination = "processing_manifest" ,
176
197
)
177
198
]
@@ -198,14 +219,14 @@ def test_processing_step(sagemaker_session):
198
219
"S3DataDistributionType" : "FullyReplicated" ,
199
220
"S3DataType" : "S3Prefix" ,
200
221
"S3InputMode" : "File" ,
201
- "S3Uri" : "s3://my-bucket/processing_manifest" ,
222
+ "S3Uri" : processing_input_data_uri_parameter ,
202
223
},
203
224
}
204
225
],
205
226
"ProcessingResources" : {
206
227
"ClusterConfig" : {
207
- "InstanceCount" : 1 ,
208
- "InstanceType" : "ml.m4.4xlarge" ,
228
+ "InstanceCount" : instance_count_parameter ,
229
+ "InstanceType" : instance_type_parameter ,
209
230
"VolumeSizeInGB" : 30 ,
210
231
}
211
232
},
0 commit comments