
Commit ed43726

pengk19 authored and knakad committed
Multi-Algorithm Hyperparameter Tuning Support
1 parent f8ac704 commit ed43726

12 files changed: 3,071 additions and 407 deletions


.flake8

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,5 @@
 [flake8]
 application_import_names = sagemaker, tests
 import-order-style = google
+per-file-ignores =
+    tests/unit/test_tuner.py: F405

src/sagemaker/analytics.py

Lines changed: 53 additions & 3 deletions
@@ -142,6 +142,8 @@ def reshape(training_summary):
             out["TrainingEndTime"] = end_time
             if start_time and end_time:
                 out["TrainingElapsedTimeSeconds"] = (end_time - start_time).total_seconds()
+            if "TrainingJobDefinitionName" in training_summary:
+                out["TrainingJobDefinitionName"] = training_summary["TrainingJobDefinitionName"]
             return out

         # Run that helper over all the summaries.
@@ -152,11 +154,59 @@ def reshape(training_summary):
     def tuning_ranges(self):
         """A dictionary describing the ranges of all tuned hyperparameters. The
         keys are the names of the hyperparameter, and the values are the ranges.
+
+        The output can take one of two forms:
+
+        * If the 'TrainingJobDefinition' field is present in the job description, the output
+            is a dictionary constructed from 'ParameterRanges' in
+            'HyperParameterTuningJobConfig' of the job description. The keys are the
+            parameter names, while the values are the parameter ranges.
+            Example:
+            >>> {
+            >>>     "eta": {"MaxValue": "1", "MinValue": "0", "Name": "eta"},
+            >>>     "gamma": {"MaxValue": "10", "MinValue": "0", "Name": "gamma"},
+            >>>     "iterations": {"MaxValue": "100", "MinValue": "50", "Name": "iterations"},
+            >>>     "num_layers": {"MaxValue": "30", "MinValue": "5", "Name": "num_layers"},
+            >>> }
+        * If the 'TrainingJobDefinitions' field (list) is present in the job description,
+            the output is a dictionary with keys as the 'DefinitionName' values from
+            all items in 'TrainingJobDefinitions', and each value would be a dictionary
+            constructed from 'HyperParameterRanges' in each item in 'TrainingJobDefinitions',
+            in the same format as above.
+            Example:
+            >>> {
+            >>>     "estimator_1": {
+            >>>         "eta": {"MaxValue": "1", "MinValue": "0", "Name": "eta"},
+            >>>         "gamma": {"MaxValue": "10", "MinValue": "0", "Name": "gamma"},
+            >>>     },
+            >>>     "estimator_2": {
+            >>>         "framework": {"Values": ["TF", "MXNet"], "Name": "framework"},
+            >>>         "gamma": {"MaxValue": "1.0", "MinValue": "0.2", "Name": "gamma"}
+            >>>     }
+            >>> }
+
+        For more details about the 'TrainingJobDefinition' and 'TrainingJobDefinitions' fields
+        in the job description, see
+        https://botocore.readthedocs.io/en/latest/reference/services/sagemaker.html#SageMaker.Client.create_hyper_parameter_tuning_job
         """
+        description = self.description()
+
+        if "TrainingJobDefinition" in description:
+            return self._prepare_parameter_ranges(
+                description["HyperParameterTuningJobConfig"]["ParameterRanges"]
+            )
+
+        return {
+            training_job_definition["DefinitionName"]: self._prepare_parameter_ranges(
+                training_job_definition["HyperParameterRanges"]
+            )
+            for training_job_definition in description["TrainingJobDefinitions"]
+        }
+
+    def _prepare_parameter_ranges(self, parameter_ranges):
+        """Convert parameter ranges to a dictionary using the parameter range names as the keys"""
         out = {}
-        for _, ranges in self.description()["HyperParameterTuningJobConfig"][
-            "ParameterRanges"
-        ].items():
+        for _, ranges in parameter_ranges.items():
             for param in ranges:
                 out[param["Name"]] = param
         return out
src/sagemaker/session.py

Lines changed: 308 additions & 40 deletions
Large diffs are not rendered by default.

src/sagemaker/tensorflow/predictor.py

Lines changed: 0 additions & 1 deletion
@@ -18,7 +18,6 @@
 import google.protobuf.json_format as json_format
 from google.protobuf.message import DecodeError
 from protobuf_to_dict import protobuf_to_dict
-
 from sagemaker.content_types import CONTENT_TYPE_JSON, CONTENT_TYPE_OCTET_STREAM, CONTENT_TYPE_CSV
 from sagemaker.predictor import json_serializer, csv_serializer

src/sagemaker/tuner.py

Lines changed: 795 additions & 126 deletions
Large diffs are not rendered by default.
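
The full HyperparameterTuner changes live in this un-rendered tuner.py diff, but the airflow.py changes below rely on the new factory method HyperparameterTuner.create() and its per-estimator dictionaries (estimator_dict, objective_metric_name_dict, hyperparameter_ranges_dict, metric_definitions_dict). A rough sketch of constructing and fitting a multi-algorithm tuner follows; the estimator settings, metric names, ranges, and any create()/fit() arguments not referenced in this diff are assumptions, not the definitive API:

from sagemaker.estimator import Estimator
from sagemaker.tuner import ContinuousParameter, HyperparameterTuner, IntegerParameter

# Placeholder estimators; training images, role, and instance settings are illustrative.
estimator_1 = Estimator(image_name="<training-image-1>", role="<role-arn>",
                        train_instance_count=1, train_instance_type="ml.m4.xlarge")
estimator_2 = Estimator(image_name="<training-image-2>", role="<role-arn>",
                        train_instance_count=1, train_instance_type="ml.m4.xlarge")

# Multi-algorithm tuner: every dictionary is keyed by the estimator names chosen here.
tuner = HyperparameterTuner.create(
    estimator_dict={"estimator_1": estimator_1, "estimator_2": estimator_2},
    objective_metric_name_dict={"estimator_1": "validation:auc", "estimator_2": "validation:auc"},
    hyperparameter_ranges_dict={
        "estimator_1": {"eta": ContinuousParameter(0, 1)},
        "estimator_2": {"num_layers": IntegerParameter(5, 30)},
    },
    metric_definitions_dict={
        "estimator_2": [{"Name": "validation:auc", "Regex": "auc: ([0-9\\.]+)"}],
    },
    max_jobs=10,
    max_parallel_jobs=2,
)

# Inputs (and the include_cls_metadata flag) use the same per-estimator keys.
tuner.fit(
    inputs={"estimator_1": "s3://bucket/data-1", "estimator_2": "s3://bucket/data-2"},
    include_cls_metadata={},
)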

src/sagemaker/workflow/airflow.py

Lines changed: 164 additions & 41 deletions
@@ -239,8 +239,8 @@ def training_config(estimator, inputs=None, job_name=None, mini_batch_size=None)
     return train_config


-def tuning_config(tuner, inputs, job_name=None):
-    """Export Airflow tuning config from an estimator
+def tuning_config(tuner, inputs, job_name=None, include_cls_metadata=False, mini_batch_size=None):
+    """Export Airflow tuning config from a HyperparameterTuner

     Args:
         tuner (sagemaker.tuner.HyperparameterTuner): The tuner to export tuning
@@ -266,64 +266,187 @@ def tuning_config(tuner, inputs, job_name=None):
             * (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of
               :class:~`sagemaker.amazon.amazon_estimator.RecordSet` objects,
               where each instance is a different channel of training data.
+
+            * (dict[str, one of the forms above]) - Required only by tuners created via
+                the factory method ``HyperparameterTuner.create()``. The keys should be the
+                same estimator names as keys for the ``estimator_dict`` argument of the
+                ``HyperparameterTuner.create()`` method.
         job_name (str): Specify a tuning job name if needed.
+        include_cls_metadata: It can take one of the following two forms.
+
+            * (bool) - Whether or not the hyperparameter tuning job should include information
+                about the estimator class (default: False). This information is passed as a
+                hyperparameter, so if the algorithm you are using cannot handle unknown
+                hyperparameters (e.g. an Amazon SageMaker built-in algorithm that does not
+                have a custom estimator in the Python SDK), then set ``include_cls_metadata``
+                to ``False``.
+            * (dict[str, bool]) - This version should be used for tuners created via the factory
+                method ``HyperparameterTuner.create()``, to specify the flag for individual
+                estimators provided in the ``estimator_dict`` argument of the method. The keys
+                would be the same estimator names as in ``estimator_dict``. If one estimator
+                doesn't need the flag set, then no need to include it in the dictionary. If none
+                of the estimators need the flag set, then an empty dictionary ``{}`` must be used.
+
+        mini_batch_size: It can take one of the following two forms.
+
+            * (int) - Specify this argument only when estimator is a built-in estimator of an
+                Amazon algorithm. For other estimators, batch size should be specified in the
+                estimator.
+            * (dict[str, int]) - This version should be used for tuners created via the factory
+                method ``HyperparameterTuner.create()``, to specify the value for individual
+                estimators provided in the ``estimator_dict`` argument of the method. The keys
+                would be the same estimator names as in ``estimator_dict``. If one estimator
+                doesn't need the value set, then no need to include it in the dictionary. If
+                none of the estimators need the value set, then an empty dictionary ``{}``
+                must be used.

     Returns:
         dict: Tuning config that can be directly used by SageMakerTuningOperator in Airflow.
     """
-    train_config = training_base_config(tuner.estimator, inputs)
-    hyperparameters = train_config.pop("HyperParameters", None)
-    s3_operations = train_config.pop("S3Operations", None)

-    if hyperparameters and len(hyperparameters) > 0:
-        tuner.static_hyperparameters = {
-            utils.to_str(k): utils.to_str(v) for (k, v) in hyperparameters.items()
-        }
+    tuner._prepare_job_name_for_tuning(job_name=job_name)

-    if job_name is not None:
-        tuner._current_job_name = job_name
-    else:
-        base_name = tuner.base_tuning_job_name or utils.base_name_from_image(
-            tuner.estimator.train_image()
+    tune_config = {
+        "HyperParameterTuningJobName": tuner._current_job_name,
+        "HyperParameterTuningJobConfig": _extract_tuning_job_config(tuner),
+    }
+
+    if tuner.estimator:
+        tune_config[
+            "TrainingJobDefinition"
+        ], s3_operations = _extract_training_config_from_estimator(
+            tuner, inputs, include_cls_metadata, mini_batch_size
         )
-        tuner._current_job_name = utils.name_from_base(
-            base_name, tuner.TUNING_JOB_NAME_MAX_LENGTH, True
+    else:
+        tune_config[
+            "TrainingJobDefinitions"
+        ], s3_operations = _extract_training_config_list_from_estimator_dict(
+            tuner, inputs, include_cls_metadata, mini_batch_size
         )

-    for hyperparameter_name in tuner._hyperparameter_ranges.keys():
-        tuner.static_hyperparameters.pop(hyperparameter_name, None)
+    if s3_operations:
+        tune_config["S3Operations"] = s3_operations

-    train_config["StaticHyperParameters"] = tuner.static_hyperparameters
+    if tuner.tags:
+        tune_config["Tags"] = tuner.tags

-    tune_config = {
-        "HyperParameterTuningJobName": tuner._current_job_name,
-        "HyperParameterTuningJobConfig": {
-            "Strategy": tuner.strategy,
-            "HyperParameterTuningJobObjective": {
-                "Type": tuner.objective_type,
-                "MetricName": tuner.objective_metric_name,
-            },
-            "ResourceLimits": {
-                "MaxNumberOfTrainingJobs": tuner.max_jobs,
-                "MaxParallelTrainingJobs": tuner.max_parallel_jobs,
-            },
-            "ParameterRanges": tuner.hyperparameter_ranges(),
+    if tuner.warm_start_config:
+        tune_config["WarmStartConfig"] = tuner.warm_start_config.to_input_req()
+
+    return tune_config
+
+
+def _extract_tuning_job_config(tuner):
+    """Extract tuning job config from a HyperparameterTuner"""
+    tuning_job_config = {
+        "Strategy": tuner.strategy,
+        "ResourceLimits": {
+            "MaxNumberOfTrainingJobs": tuner.max_jobs,
+            "MaxParallelTrainingJobs": tuner.max_parallel_jobs,
         },
-        "TrainingJobDefinition": train_config,
+        "TrainingJobEarlyStoppingType": tuner.early_stopping_type,
     }

-    if tuner.metric_definitions is not None:
-        tune_config["TrainingJobDefinition"]["AlgorithmSpecification"][
+    if tuner.objective_metric_name:
+        tuning_job_config["HyperParameterTuningJobObjective"] = {
+            "Type": tuner.objective_type,
+            "MetricName": tuner.objective_metric_name,
+        }
+
+    parameter_ranges = tuner.hyperparameter_ranges()
+    if parameter_ranges:
+        tuning_job_config["ParameterRanges"] = parameter_ranges
+
+    if tuner.training_instance_pools:
+        tuning_job_config["TrainingJobInstancePools"] = [
+            {
+                "InstanceType": instance_type,
+                "PoolSize": tuner.training_instance_pools[instance_type],
+            }
+            for instance_type in sorted(tuner.training_instance_pools.keys())
+        ]
+
+    return tuning_job_config
+
+
+def _extract_training_config_from_estimator(tuner, inputs, include_cls_metadata, mini_batch_size):
+    """Extract training job config from a HyperparameterTuner that uses the ``estimator`` field"""
+    train_config = training_base_config(tuner.estimator, inputs, mini_batch_size)
+    train_config.pop("HyperParameters", None)
+
+    tuner._prepare_static_hyperparameters_for_tuning(include_cls_metadata=include_cls_metadata)
+    train_config["StaticHyperParameters"] = tuner.static_hyperparameters
+
+    if tuner.metric_definitions:
+        train_config["AlgorithmSpecification"]["MetricDefinitions"] = tuner.metric_definitions
+
+    s3_operations = train_config.pop("S3Operations", None)
+    return train_config, s3_operations
+
+
+def _extract_training_config_list_from_estimator_dict(
+    tuner, inputs, include_cls_metadata, mini_batch_size
+):
+    """
+    Extract a list of training job configs from a HyperparameterTuner that uses the
+    ``estimator_dict`` field
+    """
+    estimator_names = sorted(tuner.estimator_dict.keys())
+    tuner._validate_dict_argument(name="inputs", value=inputs, allowed_keys=estimator_names)
+    tuner._validate_dict_argument(
+        name="include_cls_metadata", value=include_cls_metadata, allowed_keys=estimator_names
+    )
+    tuner._validate_dict_argument(
+        name="mini_batch_size", value=mini_batch_size, allowed_keys=estimator_names
+    )
+
+    train_config_dict = {}
+    for (estimator_name, estimator) in tuner.estimator_dict.items():
+        train_config_dict[estimator_name] = training_base_config(
+            estimator=estimator,
+            inputs=inputs.get(estimator_name) if inputs else None,
+            mini_batch_size=mini_batch_size.get(estimator_name) if mini_batch_size else None,
+        )
+
+    tuner._prepare_static_hyperparameters_for_tuning(include_cls_metadata=include_cls_metadata)
+
+    train_config_list = []
+    s3_operations_list = []
+
+    for estimator_name in sorted(train_config_dict.keys()):
+        train_config = train_config_dict[estimator_name]
+        train_config.pop("HyperParameters", None)
+        train_config["StaticHyperParameters"] = tuner.static_hyperparameters_dict[estimator_name]
+
+        train_config["AlgorithmSpecification"][
             "MetricDefinitions"
-        ] = tuner.metric_definitions
+        ] = tuner.metric_definitions_dict.get(estimator_name)

-    if tuner.tags is not None:
-        tune_config["Tags"] = tuner.tags
+        train_config["DefinitionName"] = estimator_name
+        train_config["TuningObjective"] = {
+            "Type": tuner.objective_type,
+            "MetricName": tuner.objective_metric_name_dict[estimator_name],
+        }
+        train_config["HyperParameterRanges"] = tuner.hyperparameter_ranges_dict()[estimator_name]

-    if s3_operations is not None:
-        tune_config["S3Operations"] = s3_operations
+        s3_operations_list.append(train_config.pop("S3Operations", {}))

-    return tune_config
+        train_config_list.append(train_config)
+
+    return train_config_list, _merge_s3_operations(s3_operations_list)
+
+
+def _merge_s3_operations(s3_operations_list):
+    """Merge a list of S3 operation dictionaries into one"""
+    s3_operations_merged = {}
+    for s3_operations in s3_operations_list:
+        for (key, operations) in s3_operations.items():
+            if key not in s3_operations_merged:
+                s3_operations_merged[key] = []
+            for operation in operations:
+                if operation not in s3_operations_merged[key]:
+                    s3_operations_merged[key].append(operation)
+    return s3_operations_merged


 def update_submit_s3_uri(estimator, job_name):
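
With the updated signature, exporting an Airflow tuning config from a multi-algorithm tuner passes the per-estimator dictionary forms described in the docstring above. A minimal sketch, assuming `tuner` was built with HyperparameterTuner.create() as in the earlier sketch and that the S3 inputs are placeholders:

from sagemaker.workflow.airflow import tuning_config

# Keys must match the estimator names used in the tuner's estimator_dict.
config = tuning_config(
    tuner=tuner,
    inputs={"estimator_1": "s3://bucket/data-1", "estimator_2": "s3://bucket/data-2"},
    job_name="my-multi-algo-tuning-job",
    include_cls_metadata={},  # none of the estimators needs class metadata
    mini_batch_size={},       # none of the estimators is an Amazon built-in needing a batch size
)

# For a multi-algorithm tuner, the config carries a TrainingJobDefinitions list (one entry per
# estimator) and can be handed to Airflow's SageMakerTuningOperator via its config argument.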
