
Commit 146cd83

Author: Anton Repushko (committed)
feature: add support of the intelligent stopping in the tuner
1 parent e2f3888 commit 146cd83

File tree

4 files changed: +188 −1 lines changed


src/sagemaker/session.py

Lines changed: 20 additions & 0 deletions
@@ -2189,7 +2189,9 @@ def tune( # noqa: C901
         stop_condition,
         tags,
         warm_start_config,
+        max_runtime_in_seconds=None,
         strategy_config=None,
+        completion_criteria_config=None,
         enable_network_isolation=False,
         image_uri=None,
         algorithm_arn=None,
@@ -2255,6 +2257,10 @@ def tune( # noqa: C901
                 https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
             warm_start_config (dict): Configuration defining the type of warm start and
                 other required configurations.
+            max_runtime_in_seconds (int or PipelineVariable): The maximum time in seconds
+                that a training job launched by a hyperparameter tuning job can run.
+            completion_criteria_config (sagemaker.tuner.TuningJobCompletionCriteriaConfig): A
+                configuration for the completion criteria.
             early_stopping_type (str): Specifies whether early stopping is enabled for the job.
                 Can be either 'Auto' or 'Off'. If set to 'Off', early stopping will not be
                 attempted. If set to 'Auto', early stopping of some training jobs may happen, but
@@ -2294,12 +2300,14 @@ def tune( # noqa: C901
                 strategy=strategy,
                 max_jobs=max_jobs,
                 max_parallel_jobs=max_parallel_jobs,
+                max_runtime_in_seconds=max_runtime_in_seconds,
                 objective_type=objective_type,
                 objective_metric_name=objective_metric_name,
                 parameter_ranges=parameter_ranges,
                 early_stopping_type=early_stopping_type,
                 random_seed=random_seed,
                 strategy_config=strategy_config,
+                completion_criteria_config=completion_criteria_config,
             ),
             "TrainingJobDefinition": self._map_training_config(
                 static_hyperparameters=static_hyperparameters,
@@ -2452,12 +2460,14 @@ def _map_tuning_config(
         strategy,
         max_jobs,
         max_parallel_jobs,
+        max_runtime_in_seconds=None,
         early_stopping_type="Off",
         objective_type=None,
         objective_metric_name=None,
         parameter_ranges=None,
         random_seed=None,
         strategy_config=None,
+        completion_criteria_config=None,
     ):
         """Construct tuning job configuration dictionary.
 
@@ -2466,6 +2476,8 @@ def _map_tuning_config(
             max_jobs (int): Maximum total number of training jobs to start for the hyperparameter
                 tuning job.
             max_parallel_jobs (int): Maximum number of parallel training jobs to start.
+            max_runtime_in_seconds (int or PipelineVariable): The maximum time in seconds
+                that a training job launched by a hyperparameter tuning job can run.
             early_stopping_type (str): Specifies whether early stopping is enabled for the job.
                 Can be either 'Auto' or 'Off'. If set to 'Off', early stopping will not be
                 attempted. If set to 'Auto', early stopping of some training jobs may happen,
@@ -2480,6 +2492,8 @@ def _map_tuning_config(
                 produce more consistent configurations for the same tuning job.
             strategy_config (dict): A configuration for the hyperparameter tuning job optimisation
                 strategy.
+            completion_criteria_config (dict): A configuration
+                for the completion criteria.
 
         Returns:
             A dictionary of tuning job configuration. For format details, please refer to
@@ -2496,6 +2510,9 @@ def _map_tuning_config(
             "TrainingJobEarlyStoppingType": early_stopping_type,
         }
 
+        if max_runtime_in_seconds is not None:
+            tuning_config["ResourceLimits"]["MaxRuntimeInSeconds"] = max_runtime_in_seconds
+
         if random_seed is not None:
             tuning_config["RandomSeed"] = random_seed
 
@@ -2508,6 +2525,9 @@ def _map_tuning_config(
 
         if strategy_config is not None:
             tuning_config["StrategyConfig"] = strategy_config
+
+        if completion_criteria_config is not None:
+            tuning_config["TuningJobCompletionCriteria"] = completion_criteria_config
         return tuning_config
 
     @classmethod
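
Taken together, the session.py hunks map the two new arguments into the CreateHyperParameterTuningJob request. The sketch below shows the shape of the dictionary _map_tuning_config would return with both arguments set; it is illustrative only, with made-up values, and assumes the surrounding keys follow the AWS API field names:

    # Hypothetical output of _map_tuning_config(strategy="Bayesian", max_jobs=10,
    # max_parallel_jobs=2, max_runtime_in_seconds=3600, completion_criteria_config={...}).
    tuning_config = {
        "Strategy": "Bayesian",
        "ResourceLimits": {
            "MaxNumberOfTrainingJobs": 10,
            "MaxParallelTrainingJobs": 2,
            # Added only when max_runtime_in_seconds is not None
            "MaxRuntimeInSeconds": 3600,
        },
        "TrainingJobEarlyStoppingType": "Off",
        # Added only when completion_criteria_config is not None
        "TuningJobCompletionCriteria": {
            "BestObjectiveNotImproving": {"MaxNumberOfTrainingJobsNotImproving": 5},
            "ConvergenceDetected": {"CompleteOnConvergence": "Enabled"},
            "TargetObjectiveMetricValue": 0.42,
        },
    }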

src/sagemaker/tuner.py

Lines changed: 153 additions & 0 deletions
@@ -72,6 +72,12 @@
 HYPERBAND_MIN_RESOURCE = "MinResource"
 HYPERBAND_MAX_RESOURCE = "MaxResource"
 GRID_SEARCH = "GridSearch"
+MAX_NUMBER_OF_TRAINING_JOBS_NOT_IMPROVING = "MaxNumberOfTrainingJobsNotImproving"
+BEST_OBJECTIVE_NOT_IMPROVING = "BestObjectiveNotImproving"
+CONVERGENCE_DETECTED = "ConvergenceDetected"
+COMPLETE_ON_CONVERGENCE_DETECTED = "CompleteOnConvergence"
+TARGET_OBJECTIVE_METRIC_VALUE = "TargetObjectiveMetricValue"
+MAX_RUNTIME_IN_SECONDS = "MaxRuntimeInSeconds"
 
 logger = logging.getLogger(__name__)
 
@@ -383,6 +389,116 @@ def to_input_req(self):
         }
 
 
+class TuningJobCompletionCriteriaConfig(object):
+    """The configuration for a job completion criteria."""
+
+    def __init__(
+        self,
+        max_number_of_training_jobs_not_improving: int = None,
+        complete_on_convergence: bool = None,
+        target_objective_metric_value: float = None,
+    ):
+        """Creates a ``TuningJobCompletionCriteriaConfig`` with the provided criteria.
+
+        Args:
+            max_number_of_training_jobs_not_improving (int): The number of training jobs that
+                do not improve the best objective, after which the tuning job will stop.
+            complete_on_convergence (bool): A flag to stop your hyperparameter tuning job if
+                automatic model tuning (AMT) has detected that your model has converged as
+                evaluated against your objective function.
+            target_objective_metric_value (float): The value of the objective metric.
+        """
+        self.max_number_of_training_jobs_not_improving = max_number_of_training_jobs_not_improving
+        self.complete_on_convergence = complete_on_convergence
+        self.target_objective_metric_value = target_objective_metric_value
+
+    @classmethod
+    def from_job_desc(cls, completion_criteria_config):
+        """Creates a ``TuningJobCompletionCriteriaConfig`` from a configuration response.
+
+        This is the completion criteria configuration from the DescribeTuningJob response.
+
+        Args:
+            completion_criteria_config (dict): The expected format of the
+                ``completion_criteria_config`` contains three first-class fields:
+                ``BestObjectiveNotImproving``, ``ConvergenceDetected`` and
+                ``TargetObjectiveMetricValue``.
+
+        Returns:
+            sagemaker.tuner.TuningJobCompletionCriteriaConfig: De-serialized instance of
+            ``TuningJobCompletionCriteriaConfig`` containing the completion criteria.
+        """
+        complete_on_convergence = None
+        if CONVERGENCE_DETECTED in completion_criteria_config:
+            if completion_criteria_config[CONVERGENCE_DETECTED][COMPLETE_ON_CONVERGENCE_DETECTED]:
+                complete_on_convergence = bool(
+                    completion_criteria_config[CONVERGENCE_DETECTED][
+                        COMPLETE_ON_CONVERGENCE_DETECTED
+                    ]
+                    == "Enabled"
+                )
+
+        max_number_of_training_jobs_not_improving = None
+        if BEST_OBJECTIVE_NOT_IMPROVING in completion_criteria_config:
+            if completion_criteria_config[BEST_OBJECTIVE_NOT_IMPROVING][
+                MAX_NUMBER_OF_TRAINING_JOBS_NOT_IMPROVING
+            ]:
+                max_number_of_training_jobs_not_improving = completion_criteria_config[
+                    BEST_OBJECTIVE_NOT_IMPROVING
+                ][MAX_NUMBER_OF_TRAINING_JOBS_NOT_IMPROVING]
+
+        target_objective_metric_value = None
+        if TARGET_OBJECTIVE_METRIC_VALUE in completion_criteria_config:
+            target_objective_metric_value = completion_criteria_config[
+                TARGET_OBJECTIVE_METRIC_VALUE
+            ]
+
+        return cls(
+            max_number_of_training_jobs_not_improving=max_number_of_training_jobs_not_improving,
+            complete_on_convergence=complete_on_convergence,
+            target_objective_metric_value=target_objective_metric_value,
+        )
+
+    def to_input_req(self):
+        """Converts the ``self`` instance to the desired input request format.
+
+        Examples:
+            >>> completion_criteria_config = TuningJobCompletionCriteriaConfig(
+            ...     max_number_of_training_jobs_not_improving=5,
+            ...     complete_on_convergence=True,
+            ...     target_objective_metric_value=0.42,
+            ... )
+            >>> completion_criteria_config.to_input_req()
+            {
+                "BestObjectiveNotImproving": {
+                    "MaxNumberOfTrainingJobsNotImproving": 5
+                },
+                "ConvergenceDetected": {
+                    "CompleteOnConvergence": "Enabled",
+                },
+                "TargetObjectiveMetricValue": 0.42
+            }
+
+        Returns:
+            dict: Containing the completion criteria configurations.
+        """
+        completion_criteria_config = {}
+        if self.max_number_of_training_jobs_not_improving is not None:
+            completion_criteria_config[BEST_OBJECTIVE_NOT_IMPROVING] = {
+                MAX_NUMBER_OF_TRAINING_JOBS_NOT_IMPROVING: (
+                    self.max_number_of_training_jobs_not_improving
+                )
+            }
+
+        if self.target_objective_metric_value is not None:
+            completion_criteria_config[
+                TARGET_OBJECTIVE_METRIC_VALUE
+            ] = self.target_objective_metric_value
+
+        if self.complete_on_convergence is not None:
+            completion_criteria_config[CONVERGENCE_DETECTED] = {
+                COMPLETE_ON_CONVERGENCE_DETECTED: (
+                    "Enabled" if self.complete_on_convergence else "Disabled"
+                )
+            }
+
+        return completion_criteria_config
+
+
 class HyperparameterTuner(object):
     """Defines interaction with Amazon SageMaker hyperparameter tuning jobs.
 
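
A minimal usage sketch for the new class (not part of the commit), round-tripping a configuration through to_input_req and from_job_desc:

    from sagemaker.tuner import TuningJobCompletionCriteriaConfig

    config = TuningJobCompletionCriteriaConfig(
        max_number_of_training_jobs_not_improving=5,
        complete_on_convergence=True,
        target_objective_metric_value=0.42,
    )
    request = config.to_input_req()
    # request == {
    #     "BestObjectiveNotImproving": {"MaxNumberOfTrainingJobsNotImproving": 5},
    #     "TargetObjectiveMetricValue": 0.42,
    #     "ConvergenceDetected": {"CompleteOnConvergence": "Enabled"},
    # }

    # from_job_desc accepts the same shape, as returned in the describe-job response.
    restored = TuningJobCompletionCriteriaConfig.from_job_desc(request)
    assert restored.complete_on_convergence is True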
@@ -407,10 +523,12 @@ def __init__(
         objective_type: Union[str, PipelineVariable] = "Maximize",
         max_jobs: Union[int, PipelineVariable] = None,
         max_parallel_jobs: Union[int, PipelineVariable] = 1,
+        max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None,
         tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
         base_tuning_job_name: Optional[str] = None,
         warm_start_config: Optional[WarmStartConfig] = None,
         strategy_config: Optional[StrategyConfig] = None,
+        completion_criteria_config: Optional[TuningJobCompletionCriteriaConfig] = None,
         early_stopping_type: Union[str, PipelineVariable] = "Off",
         estimator_name: Optional[str] = None,
         random_seed: Optional[int] = None,
@@ -450,6 +568,8 @@ def __init__(
                 strategy and the default value is 1 for all other strategies (default: None).
             max_parallel_jobs (int or PipelineVariable): Maximum number of parallel training jobs to
                 start (default: 1).
+            max_runtime_in_seconds (int or PipelineVariable): The maximum time in seconds
+                that a training job launched by a hyperparameter tuning job can run.
             tags (list[dict[str, str]] or list[dict[str, PipelineVariable]]): List of tags for
                 labeling the tuning job (default: None). For more, see
                 https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
@@ -463,6 +583,8 @@ def __init__(
                 configuration defining the nature of warm start tuning job.
             strategy_config (sagemaker.tuner.StrategyConfig): A configuration for "Hyperparameter"
                 tuning job optimisation strategy.
+            completion_criteria_config (sagemaker.tuner.TuningJobCompletionCriteriaConfig): A
+                configuration for the completion criteria.
             early_stopping_type (str or PipelineVariable): Specifies whether early stopping is
                 enabled for the job. Can be either 'Auto' or 'Off' (default:
                 'Off'). If set to 'Off', early stopping will not be attempted.
@@ -505,6 +627,7 @@ def __init__(
 
         self.strategy = strategy
         self.strategy_config = strategy_config
+        self.completion_criteria_config = completion_criteria_config
         self.objective_type = objective_type
         # For the GridSearch strategy we expect the max_jobs equals None and recalculate it later.
         # For all other strategies for the backward compatibility we keep
@@ -513,6 +636,7 @@ def __init__(
         if max_jobs is None and strategy is not GRID_SEARCH:
             self.max_jobs = 1
         self.max_parallel_jobs = max_parallel_jobs
+        self.max_runtime_in_seconds = max_runtime_in_seconds
 
         self.tags = tags
         self.base_tuning_job_name = base_tuning_job_name
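
With the constructor changes above, a caller can cap tuning runtime and attach completion criteria when building a tuner. A hedged sketch, assuming `estimator` is an already-configured SageMaker estimator and the metric and range names are illustrative:

    from sagemaker.tuner import (
        ContinuousParameter,
        HyperparameterTuner,
        TuningJobCompletionCriteriaConfig,
    )

    tuner = HyperparameterTuner(
        estimator=estimator,  # assumed to exist
        objective_metric_name="validation:rmse",
        objective_type="Minimize",
        hyperparameter_ranges={"eta": ContinuousParameter(0.1, 0.5)},
        max_jobs=10,
        max_parallel_jobs=2,
        max_runtime_in_seconds=3600,  # new: runtime cap, in seconds
        completion_criteria_config=TuningJobCompletionCriteriaConfig(
            max_number_of_training_jobs_not_improving=5,
            complete_on_convergence=True,
        ),
    )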
@@ -1227,6 +1351,16 @@ def _prepare_init_params_from_job_description(cls, job_details):
             "base_tuning_job_name": base_from_name(job_details["HyperParameterTuningJobName"]),
         }
 
+        if "TuningJobCompletionCriteria" in tuning_config:
+            params["completion_criteria_config"] = TuningJobCompletionCriteriaConfig.from_job_desc(
+                tuning_config["TuningJobCompletionCriteria"]
+            )
+
+        if MAX_RUNTIME_IN_SECONDS in tuning_config["ResourceLimits"]:
+            params["max_runtime_in_seconds"] = tuning_config["ResourceLimits"][
+                MAX_RUNTIME_IN_SECONDS
+            ]
+
         if "RandomSeed" in tuning_config:
             params["random_seed"] = tuning_config["RandomSeed"]
 
@@ -1484,9 +1618,11 @@ def _create_warm_start_tuner(self, additional_parents, warm_start_type, estimato
             hyperparameter_ranges=self._hyperparameter_ranges,
             strategy=self.strategy,
             strategy_config=self.strategy_config,
+            completion_criteria_config=self.completion_criteria_config,
             objective_type=self.objective_type,
             max_jobs=self.max_jobs,
             max_parallel_jobs=self.max_parallel_jobs,
+            max_runtime_in_seconds=self.max_runtime_in_seconds,
             warm_start_config=WarmStartConfig(
                 warm_start_type=warm_start_type, parents=all_parents
             ),
@@ -1512,9 +1648,11 @@ def _create_warm_start_tuner(self, additional_parents, warm_start_type, estimato
             metric_definitions_dict=self.metric_definitions_dict,
             strategy=self.strategy,
             strategy_config=self.strategy_config,
+            completion_criteria_config=self.completion_criteria_config,
             objective_type=self.objective_type,
             max_jobs=self.max_jobs,
             max_parallel_jobs=self.max_parallel_jobs,
+            max_runtime_in_seconds=self.max_runtime_in_seconds,
             warm_start_config=WarmStartConfig(warm_start_type=warm_start_type, parents=all_parents),
             early_stopping_type=self.early_stopping_type,
             random_seed=self.random_seed,
@@ -1530,9 +1668,11 @@ def create(
         base_tuning_job_name=None,
         strategy="Bayesian",
         strategy_config=None,
+        completion_criteria_config=None,
         objective_type="Maximize",
         max_jobs=None,
         max_parallel_jobs=1,
+        max_runtime_in_seconds=None,
         tags=None,
         warm_start_config=None,
         early_stopping_type="Off",
@@ -1581,13 +1721,16 @@ def create(
                 (default: 'Bayesian').
             strategy_config (dict): The configuration for the hyperparameter tuning job
                 optimisation strategy.
+            completion_criteria_config (dict): The configuration for tuning job completion criteria.
             objective_type (str): The type of the objective metric for evaluating training jobs.
                 This value can be either 'Minimize' or 'Maximize' (default: 'Maximize').
             max_jobs (int): Maximum total number of training jobs to start for the hyperparameter
                 tuning job. The default value is unspecified for the GridSearch strategy
                 and the value is 1 for all other strategies (default: None).
             max_parallel_jobs (int): Maximum number of parallel training jobs to start
                 (default: 1).
+            max_runtime_in_seconds (int): The maximum time in seconds
+                that a training job launched by a hyperparameter tuning job can run.
             tags (list[dict]): List of tags for labeling the tuning job (default: None). For more,
                 see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
             warm_start_config (sagemaker.tuner.WarmStartConfig): A ``WarmStartConfig`` object that
@@ -1632,9 +1775,11 @@ def create(
             metric_definitions=metric_definitions,
             strategy=strategy,
             strategy_config=strategy_config,
+            completion_criteria_config=completion_criteria_config,
             objective_type=objective_type,
             max_jobs=max_jobs,
             max_parallel_jobs=max_parallel_jobs,
+            max_runtime_in_seconds=max_runtime_in_seconds,
             tags=tags,
             warm_start_config=warm_start_config,
             early_stopping_type=early_stopping_type,
@@ -1790,6 +1935,9 @@ def _get_tuner_args(cls, tuner, inputs):
             "early_stopping_type": tuner.early_stopping_type,
         }
 
+        if tuner.max_runtime_in_seconds is not None:
+            tuning_config["max_runtime_in_seconds"] = tuner.max_runtime_in_seconds
+
         if tuner.random_seed is not None:
             tuning_config["random_seed"] = tuner.random_seed
 
@@ -1804,6 +1952,11 @@ def _get_tuner_args(cls, tuner, inputs):
         if parameter_ranges is not None:
             tuning_config["parameter_ranges"] = parameter_ranges
 
+        if tuner.completion_criteria_config is not None:
+            tuning_config[
+                "completion_criteria_config"
+            ] = tuner.completion_criteria_config.to_input_req()
+
         tuner_args = {
             "job_name": tuner._current_job_name,
             "tuning_config": tuning_config,

tests/unit/test_tuner.py

Lines changed: 5 additions & 0 deletions
@@ -543,12 +543,17 @@ def test_attach_tuning_job_with_estimator_from_hyperparameters(sagemaker_session
     assert tuner.objective_metric_name == OBJECTIVE_METRIC_NAME
     assert tuner.max_jobs == 1
     assert tuner.max_parallel_jobs == 1
+    assert tuner.max_runtime_in_seconds == 1
     assert tuner.metric_definitions == METRIC_DEFINITIONS
     assert tuner.strategy == "Bayesian"
     assert tuner.objective_type == "Minimize"
     assert tuner.early_stopping_type == "Off"
     assert tuner.random_seed == 0
 
+    assert tuner.completion_criteria_config.complete_on_convergence is True
+    assert tuner.completion_criteria_config.target_objective_metric_value == 0.42
+    assert tuner.completion_criteria_config.max_number_of_training_jobs_not_improving == 5
+
     assert isinstance(tuner.estimator, PCA)
     assert tuner.estimator.role == ROLE
     assert tuner.estimator.instance_count == 1
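
The assertions above imply describe-job fixture values of 1 for MaxRuntimeInSeconds, 0.42 for TargetObjectiveMetricValue, and 5 for MaxNumberOfTrainingJobsNotImproving. A complementary, hypothetical check of the request-mapping path (not part of this commit; it assumes _map_tuning_config is callable as a classmethod, as the surrounding decorators suggest):

    from sagemaker.session import Session

    def test_map_tuning_config_intelligent_stopping():
        config = Session._map_tuning_config(
            strategy="Bayesian",
            max_jobs=1,
            max_parallel_jobs=1,
            max_runtime_in_seconds=1,
            completion_criteria_config={
                "ConvergenceDetected": {"CompleteOnConvergence": "Enabled"}
            },
        )
        assert config["ResourceLimits"]["MaxRuntimeInSeconds"] == 1
        assert config["TuningJobCompletionCriteria"] == {
            "ConvergenceDetected": {"CompleteOnConvergence": "Enabled"}
        }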
