@@ -89,8 +89,9 @@ class HyperparameterTuner(object):
89
89
DEFAULT_ESTIMATOR_MODULE = 'sagemaker.estimator'
90
90
DEFAULT_ESTIMATOR_CLS_NAME = 'Estimator'
91
91
92
- def __init__ (self , estimator , objective_metric_name , hyperparameter_ranges , metric_definitions , strategy = 'Bayesian' ,
93
- objective_type = 'Maximize' , max_jobs = 1 , max_parallel_jobs = 1 , base_tuning_job_name = None ):
92
+ def __init__ (self , estimator , objective_metric_name , hyperparameter_ranges , metric_definitions = None ,
93
+ strategy = 'Bayesian' , objective_type = 'Maximize' , max_jobs = 1 , max_parallel_jobs = 1 ,
94
+ base_tuning_job_name = None ):
94
95
self ._hyperparameter_ranges = hyperparameter_ranges
95
96
if self ._hyperparameter_ranges is None or len (self ._hyperparameter_ranges ) == 0 :
96
97
raise ValueError ('Need to specify hyperparameter ranges' )
@@ -102,7 +103,6 @@ def __init__(self, estimator, objective_metric_name, hyperparameter_ranges, metr
102
103
103
104
self .strategy = strategy
104
105
self .objective_type = objective_type
105
-
106
106
self .max_jobs = max_jobs
107
107
self .max_parallel_jobs = max_parallel_jobs
108
108
@@ -124,7 +124,8 @@ def prepare_for_training(self, job_name=None):
124
124
# For attach() to know what estimator to use for non-1P algorithms
125
125
# (1P algorithms don't accept extra hyperparameters)
126
126
if not isinstance (self .estimator , AmazonAlgorithmEstimatorBase ):
127
- self .static_hyperparameters [self .SAGEMAKER_ESTIMATOR_CLASS_NAME ] = json .dumps (self .estimator .__class__ .__name__ )
127
+ self .static_hyperparameters [self .SAGEMAKER_ESTIMATOR_CLASS_NAME ] = json .dumps (
128
+ self .estimator .__class__ .__name__ )
128
129
self .static_hyperparameters [self .SAGEMAKER_ESTIMATOR_MODULE ] = json .dumps (self .estimator .__module__ )
129
130
130
131
def fit (self , inputs , job_name = None , ** kwargs ):
@@ -150,7 +151,7 @@ def attach(cls, tuning_job_name, sagemaker_session=None, job_details=None, estim
150
151
sagemaker_session = sagemaker_session or Session ()
151
152
152
153
if job_details is None :
153
- job_details = sagemaker_session .sagemaker_client \
154
+ job_details = sagemaker_session .sagemaker_client \
154
155
.describe_hyper_parameter_tuning_job (HyperParameterTuningJobName = tuning_job_name )
155
156
156
157
estimator_cls = cls ._prepare_estimator_cls (estimator_cls , job_details ['TrainingJobDefinition' ])
@@ -249,7 +250,7 @@ def _prepare_estimator_cls(cls, estimator_cls, training_details):
249
250
250
251
# Then try to derive the estimator from the image name for 1P algorithms
251
252
image_name = training_details ['AlgorithmSpecification' ]['TrainingImage' ]
252
- algorithm = image_name [image_name .find ('/' )+ 1 :image_name .find (':' )]
253
+ algorithm = image_name [image_name .find ('/' ) + 1 :image_name .find (':' )]
253
254
if algorithm in AMAZON_ESTIMATOR_CLS_NAMES :
254
255
cls_name = AMAZON_ESTIMATOR_CLS_NAMES [algorithm ]
255
256
return getattr (importlib .import_module (AMAZON_ESTIMATOR_MODULE ), cls_name )
0 commit comments