@@ -383,6 +383,37 @@ def to_input_req(self):
         }


+class InstanceConfig:
+    """Instance configuration for training jobs started by hyperparameter tuning.
+
+    Contains the configuration(s) for one or more resources for processing hyperparameter
+    tuning jobs. These resources include compute instances and storage volumes to use in model
+    training jobs launched by hyperparameter tuning jobs.
+    """
+
+    def __init__(
+        self,
+        instance_count: Union[int, PipelineVariable] = None,
+        instance_type: Union[str, PipelineVariable] = None,
+        volume_size: Union[int, PipelineVariable] = 30,
+    ):
+        """Creates an ``InstanceConfig`` instance.
+
+        It takes instance configuration information for training
+        jobs that are created as the result of a hyperparameter tuning job.
+        Args:
+            instance_count (int or PipelineVariable): The number of compute instances of type
+                ``instance_type`` to use. For distributed training, select a value greater than 1.
+            instance_type (str or PipelineVariable): The instance type used to run
+                hyperparameter optimization tuning jobs.
+            volume_size (int or PipelineVariable): The size in GB of the storage volume to use
+                for the training jobs launched by hyperparameter optimization (default: 30).
+        """
+        self.instance_count = instance_count
+        self.instance_type = instance_type
+        self.volume_size = volume_size
+
+
 class HyperparameterTuner(object):
     """Defines interaction with Amazon SageMaker hyperparameter tuning jobs.
@@ -419,7 +450,6 @@ def __init__(

         It takes an estimator to obtain configuration information for training
         jobs that are created as the result of a hyperparameter tuning job.
-
         Args:
             estimator (sagemaker.estimator.EstimatorBase): An estimator object
                 that has been initialized with the desired configuration. There
@@ -489,6 +519,7 @@ def __init__(
             self.metric_definitions_dict = (
                 {estimator_name: metric_definitions} if metric_definitions is not None else {}
             )
+            self.instance_configs_dict = {}
             self.static_hyperparameters = None
         else:
             self.estimator = estimator
@@ -500,6 +531,7 @@ def __init__(
             self._hyperparameter_ranges_dict = None
             self.metric_definitions_dict = None
             self.static_hyperparameters_dict = None
+            self.instance_configs_dict = None

         self._validate_parameter_ranges(estimator, hyperparameter_ranges)
@@ -521,6 +553,30 @@ def __init__(
         self.warm_start_config = warm_start_config
         self.early_stopping_type = early_stopping_type
         self.random_seed = random_seed
+        self.instance_configs = None
+
+    def override_resource_config(
+        self, instance_configs: Union[List[InstanceConfig], Dict[str, List[InstanceConfig]]]
+    ):
+        """Override the instance configuration of the estimators used by the tuner.
+
+        Args:
+            instance_configs (List[InstanceConfig] or Dict[str, List[InstanceConfig]]):
+                The InstanceConfigs to use as an override for the instance configuration
+                of the estimator. ``None`` will remove the override.
+        """
+        if isinstance(instance_configs, dict):
+            self._validate_dict_argument(
+                name="instance_configs",
+                value=instance_configs,
+                allowed_keys=list(self.estimator_dict.keys()),
+            )
+            self.instance_configs_dict = instance_configs
+        else:
+            self.instance_configs = instance_configs
+            if self.estimator_dict and self.estimator_dict.keys():
+                estimator_names = list(self.estimator_dict.keys())
+                self.instance_configs_dict = {estimator_names[0]: instance_configs}

     def _prepare_for_tuning(self, job_name=None, include_cls_metadata=False):
         """Prepare the tuner instance for tuning (fit)."""
@@ -1817,6 +1873,7 @@ def _get_tuner_args(cls, tuner, inputs):
                 estimator=tuner.estimator,
                 static_hyperparameters=tuner.static_hyperparameters,
                 metric_definitions=tuner.metric_definitions,
+                instance_configs=tuner.instance_configs,
             )

         if tuner.estimator_dict is not None:
@@ -1830,12 +1887,49 @@ def _get_tuner_args(cls, tuner, inputs):
                     tuner.objective_type,
                     tuner.objective_metric_name_dict[estimator_name],
                     tuner.hyperparameter_ranges_dict()[estimator_name],
+                    tuner.instance_configs_dict.get(estimator_name, None),
                 )
                 for estimator_name in sorted(tuner.estimator_dict.keys())
             ]

         return tuner_args

+    @staticmethod
+    def _prepare_hp_resource_config(
+        instance_configs: List[InstanceConfig],
+        instance_count: int,
+        instance_type: str,
+        volume_size: int,
+        volume_kms_key: str,
+    ):
+        """Prepare the HPO resource config for one estimator of the tuner."""
+        resource_config = {}
+        if volume_kms_key is not None:
+            resource_config["VolumeKmsKeyId"] = volume_kms_key
+
+        if instance_configs is None:
+            resource_config["InstanceCount"] = instance_count
+            resource_config["InstanceType"] = instance_type
+            resource_config["VolumeSizeInGB"] = volume_size
+        else:
+            resource_config["InstanceConfigs"] = _TuningJob._prepare_instance_configs(
+                instance_configs
+            )
+
+        return resource_config
+
+    @staticmethod
+    def _prepare_instance_configs(instance_configs):
+        """Prepare the instance configs for the create-tuning-job request."""
+        return [
+            {
+                "InstanceCount": config.instance_count,
+                "InstanceType": config.instance_type,
+                "VolumeSizeInGB": config.volume_size,
+            }
+            for config in instance_configs
+        ]
+
     @staticmethod
     def _prepare_training_config(
         inputs,
@@ -1846,10 +1940,20 @@ def _prepare_training_config(
         objective_type=None,
         objective_metric_name=None,
         parameter_ranges=None,
+        instance_configs=None,
     ):
         """Prepare training config for one estimator."""
         training_config = _Job._load_config(inputs, estimator)

+        del training_config["resource_config"]
+        training_config["hpo_resource_config"] = _TuningJob._prepare_hp_resource_config(
+            instance_configs,
+            estimator.instance_count,
+            estimator.instance_type,
+            estimator.volume_size,
+            estimator.volume_kms_key,
+        )
+
         training_config["input_mode"] = estimator.input_mode
         training_config["metric_definitions"] = metric_definitions