@@ -239,8 +239,8 @@ def training_config(estimator, inputs=None, job_name=None, mini_batch_size=None):
    return train_config


-def tuning_config(tuner, inputs, job_name=None):
-    """Export Airflow tuning config from an estimator
+def tuning_config(tuner, inputs, job_name=None, include_cls_metadata=False, mini_batch_size=None):
+    """Export Airflow tuning config from a HyperparameterTuner

    Args:
        tuner (sagemaker.tuner.HyperparameterTuner): The tuner to export tuning
@@ -266,64 +266,187 @@ def tuning_config(tuner, inputs, job_name=None):
            * (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of
              :class:`~sagemaker.amazon.amazon_estimator.RecordSet` objects,
              where each instance is a different channel of training data.
+
+            * (dict[str, one of the forms above]): Required only by tuners created via
+              the factory method ``HyperparameterTuner.create()``. The keys should be the
+              same estimator names as the keys of the ``estimator_dict`` argument of the
+              ``HyperparameterTuner.create()`` method.

        job_name (str): Specify a tuning job name if needed.
+        include_cls_metadata: It can take one of the following two forms.
+
+            * (bool) - Whether or not the hyperparameter tuning job should include information
+              about the estimator class (default: False). This information is passed as a
+              hyperparameter, so if the algorithm you are using cannot handle unknown
+              hyperparameters (e.g. an Amazon SageMaker built-in algorithm that does not
+              have a custom estimator in the Python SDK), then set ``include_cls_metadata``
+              to ``False``.
+            * (dict[str, bool]) - This form should be used for tuners created via the factory
+              method ``HyperparameterTuner.create()``, to specify the flag for individual
+              estimators provided in the ``estimator_dict`` argument of the method. The keys
+              are the same estimator names as in ``estimator_dict``. If an estimator doesn't
+              need the flag set, it can be omitted from the dictionary. If none of the
+              estimators need the flag set, an empty dictionary ``{}`` must be used.
+
+        mini_batch_size: It can take one of the following two forms.
+
+            * (int) - Specify this argument only when the estimator is a built-in estimator
+              of an Amazon algorithm. For other estimators, the batch size should be
+              specified in the estimator itself.
+            * (dict[str, int]) - This form should be used for tuners created via the factory
+              method ``HyperparameterTuner.create()``, to specify the value for individual
+              estimators provided in the ``estimator_dict`` argument of the method. The keys
+              are the same estimator names as in ``estimator_dict``. If an estimator doesn't
+              need the value set, it can be omitted from the dictionary. If none of the
+              estimators need the value set, an empty dictionary ``{}`` must be used.

    Returns:
        dict: Tuning config that can be directly used by SageMakerTuningOperator in Airflow.
    """
-    train_config = training_base_config(tuner.estimator, inputs)
-    hyperparameters = train_config.pop("HyperParameters", None)
-    s3_operations = train_config.pop("S3Operations", None)

-    if hyperparameters and len(hyperparameters) > 0:
-        tuner.static_hyperparameters = {
-            utils.to_str(k): utils.to_str(v) for (k, v) in hyperparameters.items()
-        }
+    tuner._prepare_job_name_for_tuning(job_name=job_name)

-    if job_name is not None:
-        tuner._current_job_name = job_name
-    else:
-        base_name = tuner.base_tuning_job_name or utils.base_name_from_image(
-            tuner.estimator.train_image()
+    tune_config = {
+        "HyperParameterTuningJobName": tuner._current_job_name,
+        "HyperParameterTuningJobConfig": _extract_tuning_job_config(tuner),
+    }
+
+    if tuner.estimator:
+        tune_config[
+            "TrainingJobDefinition"
+        ], s3_operations = _extract_training_config_from_estimator(
+            tuner, inputs, include_cls_metadata, mini_batch_size
        )
-        tuner._current_job_name = utils.name_from_base(
-            base_name, tuner.TUNING_JOB_NAME_MAX_LENGTH, True
+    else:
+        tune_config[
+            "TrainingJobDefinitions"
+        ], s3_operations = _extract_training_config_list_from_estimator_dict(
+            tuner, inputs, include_cls_metadata, mini_batch_size
        )

-    for hyperparameter_name in tuner._hyperparameter_ranges.keys():
-        tuner.static_hyperparameters.pop(hyperparameter_name, None)
+    if s3_operations:
+        tune_config["S3Operations"] = s3_operations

-    train_config["StaticHyperParameters"] = tuner.static_hyperparameters
+    if tuner.tags:
+        tune_config["Tags"] = tuner.tags

-    tune_config = {
-        "HyperParameterTuningJobName": tuner._current_job_name,
-        "HyperParameterTuningJobConfig": {
-            "Strategy": tuner.strategy,
-            "HyperParameterTuningJobObjective": {
-                "Type": tuner.objective_type,
-                "MetricName": tuner.objective_metric_name,
-            },
-            "ResourceLimits": {
-                "MaxNumberOfTrainingJobs": tuner.max_jobs,
-                "MaxParallelTrainingJobs": tuner.max_parallel_jobs,
-            },
-            "ParameterRanges": tuner.hyperparameter_ranges(),
+    if tuner.warm_start_config:
+        tune_config["WarmStartConfig"] = tuner.warm_start_config.to_input_req()
+
+    return tune_config
+
+
+def _extract_tuning_job_config(tuner):
+    """Extract tuning job config from a HyperparameterTuner"""
+    tuning_job_config = {
+        "Strategy": tuner.strategy,
+        "ResourceLimits": {
+            "MaxNumberOfTrainingJobs": tuner.max_jobs,
+            "MaxParallelTrainingJobs": tuner.max_parallel_jobs,
        },
-        "TrainingJobDefinition": train_config,
+        "TrainingJobEarlyStoppingType": tuner.early_stopping_type,
    }

-    if tuner.metric_definitions is not None:
-        tune_config["TrainingJobDefinition"]["AlgorithmSpecification"][
+    if tuner.objective_metric_name:
+        tuning_job_config["HyperParameterTuningJobObjective"] = {
+            "Type": tuner.objective_type,
+            "MetricName": tuner.objective_metric_name,
+        }
+
+    parameter_ranges = tuner.hyperparameter_ranges()
+    if parameter_ranges:
+        tuning_job_config["ParameterRanges"] = parameter_ranges
+
+    if tuner.training_instance_pools:
+        tuning_job_config["TrainingJobInstancePools"] = [
+            {
+                "InstanceType": instance_type,
+                "PoolSize": tuner.training_instance_pools[instance_type],
+            }
+            for instance_type in sorted(tuner.training_instance_pools.keys())
+        ]
+
+    return tuning_job_config
+
+
+def _extract_training_config_from_estimator(tuner, inputs, include_cls_metadata, mini_batch_size):
+    """Extract training job config from a HyperparameterTuner that uses the ``estimator`` field"""
+    train_config = training_base_config(tuner.estimator, inputs, mini_batch_size)
+    train_config.pop("HyperParameters", None)
+
+    tuner._prepare_static_hyperparameters_for_tuning(include_cls_metadata=include_cls_metadata)
+    train_config["StaticHyperParameters"] = tuner.static_hyperparameters
+
+    if tuner.metric_definitions:
+        train_config["AlgorithmSpecification"]["MetricDefinitions"] = tuner.metric_definitions
+
+    s3_operations = train_config.pop("S3Operations", None)
+    return train_config, s3_operations
+
+
+def _extract_training_config_list_from_estimator_dict(
+    tuner, inputs, include_cls_metadata, mini_batch_size
+):
+    """
+    Extract a list of training job configs from a HyperparameterTuner that uses the
+    ``estimator_dict`` field
+    """
+    estimator_names = sorted(tuner.estimator_dict.keys())
+    tuner._validate_dict_argument(name="inputs", value=inputs, allowed_keys=estimator_names)
+    tuner._validate_dict_argument(
+        name="include_cls_metadata", value=include_cls_metadata, allowed_keys=estimator_names
+    )
+    tuner._validate_dict_argument(
+        name="mini_batch_size", value=mini_batch_size, allowed_keys=estimator_names
+    )
+
+    train_config_dict = {}
+    for (estimator_name, estimator) in tuner.estimator_dict.items():
+        train_config_dict[estimator_name] = training_base_config(
+            estimator=estimator,
+            inputs=inputs.get(estimator_name) if inputs else None,
+            mini_batch_size=mini_batch_size.get(estimator_name) if mini_batch_size else None,
+        )
+
+    tuner._prepare_static_hyperparameters_for_tuning(include_cls_metadata=include_cls_metadata)
+
+    train_config_list = []
+    s3_operations_list = []
+
+    for estimator_name in sorted(train_config_dict.keys()):
+        train_config = train_config_dict[estimator_name]
+        train_config.pop("HyperParameters", None)
+        train_config["StaticHyperParameters"] = tuner.static_hyperparameters_dict[estimator_name]
+
+        train_config["AlgorithmSpecification"][
            "MetricDefinitions"
-        ] = tuner.metric_definitions
+        ] = tuner.metric_definitions_dict.get(estimator_name)

-    if tuner.tags is not None:
-        tune_config["Tags"] = tuner.tags
+        train_config["DefinitionName"] = estimator_name
+        train_config["TuningObjective"] = {
+            "Type": tuner.objective_type,
+            "MetricName": tuner.objective_metric_name_dict[estimator_name],
+        }
+        train_config["HyperParameterRanges"] = tuner.hyperparameter_ranges_dict()[estimator_name]

-    if s3_operations is not None:
-        tune_config["S3Operations"] = s3_operations
+        s3_operations_list.append(train_config.pop("S3Operations", {}))

-    return tune_config
+        train_config_list.append(train_config)
+
+    return train_config_list, _merge_s3_operations(s3_operations_list)
+
+
+def _merge_s3_operations(s3_operations_list):
+    """Merge a list of S3 operation dictionaries into one"""
+    s3_operations_merged = {}
+    for s3_operations in s3_operations_list:
+        for (key, operations) in s3_operations.items():
+            if key not in s3_operations_merged:
+                s3_operations_merged[key] = []
+            for operation in operations:
+                if operation not in s3_operations_merged[key]:
+                    s3_operations_merged[key].append(operation)
+    return s3_operations_merged


def update_submit_s3_uri(estimator, job_name):
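
For context, here is a minimal usage sketch of the updated ``tuning_config`` with a single-estimator tuner. It is not part of the commit; the image URI, role ARN, bucket, metric name, and hyperparameter name are hypothetical placeholders, and the argument names follow the v1.x SDK API (``image_name``, ``train_instance_type``, etc.).

from sagemaker.estimator import Estimator
from sagemaker.tuner import ContinuousParameter, HyperparameterTuner
from sagemaker.workflow.airflow import tuning_config

# Hypothetical estimator; any SageMaker Estimator works here.
estimator = Estimator(
    image_name="123456789012.dkr.ecr.us-west-2.amazonaws.com/my-image:latest",
    role="arn:aws:iam::123456789012:role/MySageMakerRole",
    train_instance_count=1,
    train_instance_type="ml.m5.xlarge",
)

tuner = HyperparameterTuner(
    estimator=estimator,
    objective_metric_name="validation:accuracy",
    hyperparameter_ranges={"learning_rate": ContinuousParameter(0.01, 0.2)},
    metric_definitions=[{"Name": "validation:accuracy", "Regex": "accuracy=([0-9.]+)"}],
    max_jobs=10,
    max_parallel_jobs=2,
)

# Because tuner.estimator is set, the exported dict carries a single
# "TrainingJobDefinition". It can be passed directly to Airflow's
# SageMakerTuningOperator as its `config` argument.
config = tuning_config(tuner, inputs="s3://my-bucket/train/", job_name="my-tuning-job")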
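And a sketch of the multi-estimator path this change enables, reusing ``estimator`` from the sketch above purely for brevity. The estimator names and S3 channels are hypothetical; the dict-valued arguments are keyed by the names given in ``estimator_dict``, as the docstring describes.

multi_tuner = HyperparameterTuner.create(
    estimator_dict={"estimator-a": estimator, "estimator-b": estimator},
    objective_metric_name_dict={
        "estimator-a": "validation:accuracy",
        "estimator-b": "validation:accuracy",
    },
    hyperparameter_ranges_dict={
        "estimator-a": {"learning_rate": ContinuousParameter(0.01, 0.2)},
        "estimator-b": {"learning_rate": ContinuousParameter(0.001, 0.1)},
    },
    metric_definitions_dict={
        "estimator-a": [{"Name": "validation:accuracy", "Regex": "accuracy=([0-9.]+)"}],
        "estimator-b": [{"Name": "validation:accuracy", "Regex": "accuracy=([0-9.]+)"}],
    },
    max_jobs=20,
    max_parallel_jobs=4,
)

# Keys must match estimator_dict; the empty dicts mean no estimator needs
# class metadata or a mini-batch size.
multi_config = tuning_config(
    multi_tuner,
    inputs={"estimator-a": "s3://my-bucket/a/", "estimator-b": "s3://my-bucket/b/"},
    include_cls_metadata={},
    mini_batch_size={},
)

# With estimator_dict set, the exported config carries "TrainingJobDefinitions"
# (one entry per estimator), each with its own DefinitionName, TuningObjective,
# and HyperParameterRanges, and with the per-estimator S3Operations merged.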