15
15
import os
16
16
17
17
import sagemaker
18
- from sagemaker import job , utils , model
18
+ from sagemaker import job , model , utils
19
19
from sagemaker .amazon import amazon_estimator
20
20
21
21
@@ -48,14 +48,19 @@ def prepare_framework(estimator, s3_operations):
48
48
estimator ._hyperparameters [model .SAGEMAKER_REGION_PARAM_NAME ] = estimator .sagemaker_session .boto_region_name
49
49
50
50
51
- def prepare_amazon_algorithm_estimator (estimator , inputs ):
51
+ def prepare_amazon_algorithm_estimator (estimator , inputs , mini_batch_size = None ):
52
52
""" Set up amazon algorithm estimator, adding the required `feature_dim` hyperparameter from training data.
53
53
54
54
Args:
55
55
estimator (sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase):
56
56
An estimator for a built-in Amazon algorithm to get information from and update.
57
- inputs (single or list of sagemaker.amazon.amazon_estimator.RecordSet):
58
- The training data, must be in RecordSet format.
57
+ inputs: The training data.
58
+ * (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of
59
+ Amazon :class:~`Record` objects serialized and stored in S3.
60
+ For use with an estimator for an Amazon algorithm.
61
+ * (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of
62
+ :class:~`sagemaker.amazon.amazon_estimator.RecordSet` objects, where each instance is
63
+ a different channel of training data.
59
64
"""
60
65
if isinstance (inputs , list ):
61
66
for record in inputs :
@@ -66,22 +71,39 @@ def prepare_amazon_algorithm_estimator(estimator, inputs):
66
71
estimator .feature_dim = inputs .feature_dim
67
72
else :
68
73
raise TypeError ('Training data must be represented in RecordSet or list of RecordSets' )
74
+ estimator .mini_batch_size = mini_batch_size
69
75
70
76
71
- def training_config (estimator , inputs = None , job_name = None ): # noqa: C901 - suppress complexity warning for this method
72
- """Export Airflow training config from an estimator
77
+ def training_base_config (estimator , inputs = None , job_name = None , mini_batch_size = None ):
78
+ """Export Airflow base training config from an estimator
73
79
74
80
Args:
75
- estimator (sagemaker.estimator.EstimatroBase ):
81
+ estimator (sagemaker.estimator.EstimatorBase ):
76
82
The estimator to export training config from. Can be a BYO estimator,
77
83
Framework estimator or Amazon algorithm estimator.
78
- inputs (str, dict, single or list of sagemaker.amazon.amazon_estimator.RecordSet):
79
- The training data.
84
+ inputs: Information about the training data. Please refer to the ``fit()`` method of
85
+ the associated estimator, as this can take any of the following forms:
86
+
87
+ * (str) - The S3 location where training data is saved.
88
+ * (dict[str, str] or dict[str, sagemaker.session.s3_input]) - If using multiple channels for
89
+ training data, you can specify a dict mapping channel names
90
+ to strings or :func:`~sagemaker.session.s3_input` objects.
91
+ * (sagemaker.session.s3_input) - Channel configuration for S3 data sources that can provide
92
+ additional information about the training dataset. See :func:`sagemaker.session.s3_input`
93
+ for full details.
94
+ * (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of
95
+ Amazon :class:~`Record` objects serialized and stored in S3.
96
+ For use with an estimator for an Amazon algorithm.
97
+ * (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of
98
+ :class:~`sagemaker.amazon.amazon_estimator.RecordSet` objects, where each instance is
99
+ a different channel of training data.
100
+
80
101
job_name (str): Specify a training job name if needed.
102
+ mini_batch_size (int): Specify this argument only when estimator is a built-in estimator of an
103
+ Amazon algorithm. For other estimators, batch size should be specified in the estimator.
81
104
82
- Returns:
83
- A dict of training config that can be directly used by SageMakerTrainingOperator
84
- in Airflow.
105
+ Returns (dict):
106
+ Training config that can be directly used by SageMakerTrainingOperator in Airflow.
85
107
"""
86
108
default_bucket = estimator .sagemaker_session .default_bucket ()
87
109
s3_operations = {}
@@ -99,8 +121,7 @@ def training_config(estimator, inputs=None, job_name=None): # noqa: C901 - supp
99
121
prepare_framework (estimator , s3_operations )
100
122
101
123
elif isinstance (estimator , amazon_estimator .AmazonAlgorithmEstimatorBase ):
102
- prepare_amazon_algorithm_estimator (estimator , inputs )
103
-
124
+ prepare_amazon_algorithm_estimator (estimator , inputs , mini_batch_size )
104
125
job_config = job ._Job ._load_config (inputs , estimator , expand_role = False , validate_uri = False )
105
126
106
127
train_config = {
@@ -109,7 +130,6 @@ def training_config(estimator, inputs=None, job_name=None): # noqa: C901 - supp
109
130
'TrainingInputMode' : estimator .input_mode
110
131
},
111
132
'OutputDataConfig' : job_config ['output_config' ],
112
- 'TrainingJobName' : estimator ._current_job_name ,
113
133
'StoppingCondition' : job_config ['stop_condition' ],
114
134
'ResourceConfig' : job_config ['resource_config' ],
115
135
'RoleArn' : job_config ['role' ],
@@ -127,10 +147,125 @@ def training_config(estimator, inputs=None, job_name=None): # noqa: C901 - supp
127
147
if hyperparameters and len (hyperparameters ) > 0 :
128
148
train_config ['HyperParameters' ] = hyperparameters
129
149
130
- if estimator .tags is not None :
131
- train_config ['Tags' ] = estimator .tags
132
-
133
150
if s3_operations :
134
151
train_config ['S3Operations' ] = s3_operations
135
152
136
153
return train_config
154
+
155
+
156
def training_config(estimator, inputs=None, job_name=None, mini_batch_size=None):
    """Export a complete Airflow training config from an estimator.

    Builds on :func:`training_base_config`, then fills in the job-level fields
    (job name and tags) that a standalone training job needs.

    Args:
        estimator (sagemaker.estimator.EstimatorBase):
            The estimator to export training config from. Can be a BYO estimator,
            Framework estimator or Amazon algorithm estimator.
        inputs: Information about the training data. Please refer to the ``fit()`` method of
            the associated estimator, as this can take any of the following forms:

            * (str) - The S3 location where training data is saved.
            * (dict[str, str] or dict[str, sagemaker.session.s3_input]) - If using multiple
              channels for training data, you can specify a dict mapping channel names
              to strings or :func:`~sagemaker.session.s3_input` objects.
            * (sagemaker.session.s3_input) - Channel configuration for S3 data sources that can
              provide additional information about the training dataset. See
              :func:`sagemaker.session.s3_input` for full details.
            * (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of
              Amazon :class:~`Record` objects serialized and stored in S3.
              For use with an estimator for an Amazon algorithm.
            * (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of
              :class:~`sagemaker.amazon.amazon_estimator.RecordSet` objects, where each
              instance is a different channel of training data.

        job_name (str): Specify a training job name if needed.
        mini_batch_size (int): Specify this argument only when estimator is a built-in estimator
            of an Amazon algorithm. For other estimators, batch size should be specified in the
            estimator.

    Returns (dict):
        Training config that can be directly used by SageMakerTrainingOperator in Airflow.
    """
    config = training_base_config(estimator, inputs, job_name, mini_batch_size)

    # The base config omits job-level fields so it can be nested inside a tuning
    # job definition; add them here for a standalone training job.
    config['TrainingJobName'] = estimator._current_job_name
    if estimator.tags is not None:
        config['Tags'] = estimator.tags

    return config
196
+
197
+
198
def tuning_config(tuner, inputs, job_name=None):
    """Export an Airflow tuning config from a HyperparameterTuner.

    Args:
        tuner (sagemaker.tuner.HyperparameterTuner): The tuner to export tuning config from.
        inputs: Information about the training data. Please refer to the ``fit()`` method of
            the associated estimator in the tuner, as this can take any of the following forms:

            * (str) - The S3 location where training data is saved.
            * (dict[str, str] or dict[str, sagemaker.session.s3_input]) - If using multiple
              channels for training data, you can specify a dict mapping channel names
              to strings or :func:`~sagemaker.session.s3_input` objects.
            * (sagemaker.session.s3_input) - Channel configuration for S3 data sources that can
              provide additional information about the training dataset. See
              :func:`sagemaker.session.s3_input` for full details.
            * (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of
              Amazon :class:~`Record` objects serialized and stored in S3.
              For use with an estimator for an Amazon algorithm.
            * (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of
              :class:~`sagemaker.amazon.amazon_estimator.RecordSet` objects, where each
              instance is a different channel of training data.

        job_name (str): Specify a tuning job name if needed.

    Returns (dict):
        Tuning config that can be directly used by SageMakerTuningOperator in Airflow.
    """
    train_config = training_base_config(tuner.estimator, inputs)
    hyperparameters = train_config.pop('HyperParameters', None)
    s3_operations = train_config.pop('S3Operations', None)

    if hyperparameters and len(hyperparameters) > 0:
        tuner.static_hyperparameters = \
            {utils.to_str(k): utils.to_str(v) for (k, v) in hyperparameters.items()}
    elif getattr(tuner, 'static_hyperparameters', None) is None:
        # Without this guard, an estimator that produced no hyperparameters would
        # leave tuner.static_hyperparameters as None and the pop() loop below
        # would raise instead of exporting an (empty) static set.
        tuner.static_hyperparameters = {}

    if job_name is not None:
        tuner._current_job_name = job_name
    else:
        base_name = tuner.base_tuning_job_name or utils.base_name_from_image(tuner.estimator.train_image())
        tuner._current_job_name = utils.airflow_name_from_base(base_name, tuner.TUNING_JOB_NAME_MAX_LENGTH, True)

    # Hyperparameters being tuned are supplied via ParameterRanges, so they must
    # not also appear in the static set.
    for hyperparameter_name in tuner._hyperparameter_ranges.keys():
        tuner.static_hyperparameters.pop(hyperparameter_name, None)

    train_config['StaticHyperParameters'] = tuner.static_hyperparameters

    tune_config = {
        'HyperParameterTuningJobName': tuner._current_job_name,
        'HyperParameterTuningJobConfig': {
            'Strategy': tuner.strategy,
            'HyperParameterTuningJobObjective': {
                'Type': tuner.objective_type,
                'MetricName': tuner.objective_metric_name,
            },
            'ResourceLimits': {
                'MaxNumberOfTrainingJobs': tuner.max_jobs,
                'MaxParallelTrainingJobs': tuner.max_parallel_jobs,
            },
            'ParameterRanges': tuner.hyperparameter_ranges(),
        },
        'TrainingJobDefinition': train_config
    }

    if tuner.metric_definitions is not None:
        tune_config['TrainingJobDefinition']['AlgorithmSpecification']['MetricDefinitions'] = \
            tuner.metric_definitions

    if tuner.tags is not None:
        tune_config['Tags'] = tuner.tags

    if s3_operations is not None:
        tune_config['S3Operations'] = s3_operations

    return tune_config
0 commit comments