@@ -383,6 +383,38 @@ def to_input_req(self):
383
383
}
384
384
385
385
386
class InstanceConfig:
    """Instance configuration for training jobs started by hyperparameter tuning.

    Contains the configuration(s) for one or more resources for processing hyperparameter jobs.
    These resources include compute instances and storage volumes to use in model training jobs
    launched by hyperparameter tuning jobs.
    """

    def __init__(
        self,
        instance_count: Union[int, "PipelineVariable"] = None,
        instance_type: Union[str, "PipelineVariable"] = None,
        volume_size: Union[int, "PipelineVariable"] = 30,
    ):
        """Creates a ``InstanceConfig`` instance.

        It takes instance configuration information for training
        jobs that are created as the result of a hyperparameter tuning job.

        Args:
            instance_count (int or PipelineVariable): The number of compute instances of type
                ``instance_type`` to use. For distributed training, select a value
                greater than 1.
            instance_type (str or PipelineVariable): The instance type used to run
                hyperparameter optimization tuning jobs.
            volume_size (int or PipelineVariable): The volume size in GB of the data to be
                processed for hyperparameter optimization (default: 30).
        """
        # Plain value holder: attributes may be concrete values or pipeline
        # variables resolved at pipeline execution time.
        self.instance_count = instance_count
        self.instance_type = instance_type
        self.volume_size = volume_size
417
+
386
418
class HyperparameterTuner (object ):
387
419
"""Defines interaction with Amazon SageMaker hyperparameter tuning jobs.
388
420
@@ -482,14 +514,14 @@ def __init__(
482
514
self .estimator = None
483
515
self .objective_metric_name = None
484
516
self ._hyperparameter_ranges = None
517
+ self .static_hyperparameters = None
485
518
self .metric_definitions = None
486
519
self .estimator_dict = {estimator_name : estimator }
487
520
self .objective_metric_name_dict = {estimator_name : objective_metric_name }
488
521
self ._hyperparameter_ranges_dict = {estimator_name : hyperparameter_ranges }
489
522
self .metric_definitions_dict = (
490
523
{estimator_name : metric_definitions } if metric_definitions is not None else {}
491
524
)
492
- self .static_hyperparameters = None
493
525
else :
494
526
self .estimator = estimator
495
527
self .objective_metric_name = objective_metric_name
@@ -521,6 +553,31 @@ def __init__(
521
553
self .warm_start_config = warm_start_config
522
554
self .early_stopping_type = early_stopping_type
523
555
self .random_seed = random_seed
556
+ self .instance_configs_dict = None
557
+ self .instance_configs = None
558
+
559
def override_resource_config(
    self, instance_configs: "Union[List[InstanceConfig], Dict[str, List[InstanceConfig]]]"
):
    """Override the instance configuration of the estimators used by the tuner.

    Args:
        instance_configs (List[InstanceConfig] or Dict[str, List[InstanceConfig]]):
            The InstanceConfigs to use as an override for the instance configuration
            of the estimator. ``None`` will remove the override.
    """
    if isinstance(instance_configs, dict):
        # Multi-estimator form: every key must name an estimator already
        # attached to this tuner.
        self._validate_dict_argument(
            name="instance_configs",
            value=instance_configs,
            allowed_keys=list(self.estimator_dict.keys()),
        )
        self.instance_configs_dict = instance_configs
    else:
        self.instance_configs = instance_configs
        # When the tuner was built through the multi-estimator path, mirror the
        # override under the (single) estimator's name so the per-estimator
        # lookup sees it as well.
        if self.estimator_dict is not None and self.estimator_dict.keys():
            estimator_names = list(self.estimator_dict.keys())
            self.instance_configs_dict = {estimator_names[0]: instance_configs}
524
581
525
582
def _prepare_for_tuning (self , job_name = None , include_cls_metadata = False ):
526
583
"""Prepare the tuner instance for tuning (fit)."""
@@ -589,7 +646,6 @@ def _prepare_job_name_for_tuning(self, job_name=None):
589
646
590
647
def _prepare_static_hyperparameters_for_tuning (self , include_cls_metadata = False ):
591
648
"""Prepare static hyperparameters for all estimators before tuning."""
592
- self .static_hyperparameters = None
593
649
if self .estimator is not None :
594
650
self .static_hyperparameters = self ._prepare_static_hyperparameters (
595
651
self .estimator , self ._hyperparameter_ranges , include_cls_metadata
@@ -1817,6 +1873,7 @@ def _get_tuner_args(cls, tuner, inputs):
1817
1873
estimator = tuner .estimator ,
1818
1874
static_hyperparameters = tuner .static_hyperparameters ,
1819
1875
metric_definitions = tuner .metric_definitions ,
1876
+ instance_configs = tuner .instance_configs ,
1820
1877
)
1821
1878
1822
1879
if tuner .estimator_dict is not None :
@@ -1830,12 +1887,51 @@ def _get_tuner_args(cls, tuner, inputs):
1830
1887
tuner .objective_type ,
1831
1888
tuner .objective_metric_name_dict [estimator_name ],
1832
1889
tuner .hyperparameter_ranges_dict ()[estimator_name ],
1890
+ tuner .instance_configs_dict .get (estimator_name , None )
1891
+ if tuner .instance_configs_dict is not None
1892
+ else None ,
1833
1893
)
1834
1894
for estimator_name in sorted (tuner .estimator_dict .keys ())
1835
1895
]
1836
1896
1837
1897
return tuner_args
1838
1898
1899
@staticmethod
def _prepare_hp_resource_config(
    instance_configs: List[InstanceConfig],
    instance_count: int,
    instance_type: str,
    volume_size: int,
    volume_kms_key: str,
):
    """Placeholder hpo resource config for one estimator of the tuner.

    Builds the resource-config dict for the tuning request: either a list of
    per-training-job ``InstanceConfigs`` (when an override is present) or the
    single InstanceCount/InstanceType/VolumeSizeInGB triple from the estimator.
    """
    hpo_resource_config = {}
    if volume_kms_key is not None:
        hpo_resource_config["VolumeKmsKeyId"] = volume_kms_key

    if instance_configs is not None:
        # Override present: emit the per-job instance configurations instead
        # of the single instance triple.
        hpo_resource_config["InstanceConfigs"] = _TuningJob._prepare_instance_configs(
            instance_configs
        )
    else:
        hpo_resource_config["InstanceCount"] = instance_count
        hpo_resource_config["InstanceType"] = instance_type
        hpo_resource_config["VolumeSizeInGB"] = volume_size

    return hpo_resource_config
1923
@staticmethod
def _prepare_instance_configs(instance_configs: List[InstanceConfig]):
    """Prepare instance config for create tuning request.

    Re-wraps each entry into a fresh ``InstanceConfig`` — presumably a
    defensive shallow copy so later mutation of the caller's objects cannot
    affect the request being assembled (TODO confirm downstream expects
    ``InstanceConfig`` objects rather than plain dicts).
    """
    prepared = []
    for config in instance_configs:
        prepared.append(
            InstanceConfig(
                instance_count=config.instance_count,
                instance_type=config.instance_type,
                volume_size=config.volume_size,
            )
        )
    return prepared
1934
+
1839
1935
@staticmethod
1840
1936
def _prepare_training_config (
1841
1937
inputs ,
@@ -1846,10 +1942,20 @@ def _prepare_training_config(
1846
1942
objective_type = None ,
1847
1943
objective_metric_name = None ,
1848
1944
parameter_ranges = None ,
1945
+ instance_configs = None ,
1849
1946
):
1850
1947
"""Prepare training config for one estimator."""
1851
1948
training_config = _Job ._load_config (inputs , estimator )
1852
1949
1950
+ del training_config ["resource_config" ]
1951
+ training_config ["hpo_resource_config" ] = _TuningJob ._prepare_hp_resource_config (
1952
+ instance_configs ,
1953
+ estimator .instance_count ,
1954
+ estimator .instance_type ,
1955
+ estimator .volume_size ,
1956
+ estimator .volume_kms_key ,
1957
+ )
1958
+
1853
1959
training_config ["input_mode" ] = estimator .input_mode
1854
1960
training_config ["metric_definitions" ] = metric_definitions
1855
1961
0 commit comments