3 files changed, +9 -6 lines changed

@@ -94,9 +94,13 @@ def prepare_for_training(self):
         for hyperparameter_name in self._hyperparameter_ranges.keys():
             self.static_hyperparameters.pop(hyperparameter_name, None)
 
-        # For attach() to know what estimator to use
-        self.static_hyperparameters[self.SAGEMAKER_ESTIMATOR_CLASS_NAME] = self.estimator.__class__.__name__
-        self.static_hyperparameters[self.SAGEMAKER_ESTIMATOR_MODULE] = self.estimator.__module__
+        # For attach() to know what estimator to use for non-1P algorithms
+        # (1P algorithms don't accept extra hyperparameters)
+        from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase
+
+        if not isinstance(self.estimator, AmazonAlgorithmEstimatorBase):
+            self.static_hyperparameters[self.SAGEMAKER_ESTIMATOR_CLASS_NAME] = self.estimator.__class__.__name__
+            self.static_hyperparameters[self.SAGEMAKER_ESTIMATOR_MODULE] = self.estimator.__module__
 
     def fit(self, inputs, job_name=None, **kwargs):
         """Start a hyperparameter tuning job.

@@ -52,7 +52,7 @@ def test_fit_1p(sagemaker_session):
     # specify which hp you want to optimize over
     hyperparameter_ranges = {'extra_center_factor': IntegerParameter(1, 10),
                              'mini_batch_size': IntegerParameter(10, 100),
-                             'local_lloyd_tol': ContinuousParameter(1.0, 2.0),
+                             'local_lloyd_tol': ContinuousParameter(0.5, 0.75),
                              'local_lloyd_init_method': CategoricalParameter(['kmeans++', 'random'])}
     objective_metric_name = 'test:msd'
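
In this KMeans integration test, local_lloyd_tol is a convergence tolerance, and the previous (1.0, 2.0) bounds presumably sit outside the interval the algorithm accepts for it, so the tuning range is tightened to (0.5, 0.75). A hedged sketch of the kind of bounds check that would catch such an out-of-range tuning range; the [0, 1] valid interval is an assumption, check the algorithm's documentation.

    # Assumed valid interval, for illustration only.
    VALID_CONTINUOUS_RANGES = {'local_lloyd_tol': (0.0, 1.0)}

    def check_continuous_range(name, min_value, max_value):
        # Reject a tuning range that is not fully contained in the documented valid interval.
        low, high = VALID_CONTINUOUS_RANGES.get(name, (float('-inf'), float('inf')))
        if not low <= min_value <= max_value <= high:
            raise ValueError('%s range (%s, %s) is outside [%s, %s]'
                             % (name, min_value, max_value, low, high))

    check_continuous_range('local_lloyd_tol', 0.5, 0.75)    # new range: passes
    # check_continuous_range('local_lloyd_tol', 1.0, 2.0)   # old range: would raise ValueError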

@@ -121,8 +121,7 @@ def test_fit_1p(sagemaker_session, tuner):
 
     _, _, tune_kwargs = sagemaker_session.tune.mock_calls[0]
 
-    assert len(tune_kwargs['static_hyperparameters']) == 6
-    assert tune_kwargs['static_hyperparameters']['sagemaker_estimator_module'] == pca.__module__
+    assert len(tune_kwargs['static_hyperparameters']) == 4
     assert tune_kwargs['static_hyperparameters']['extra_components'] == '5'
     assert len(tune_kwargs['parameter_ranges']['IntegerParameterRanges']) == 1
     assert tune_kwargs['job_name'].startswith('pca')
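
The unit test now expects four static hyperparameters instead of six because the two sagemaker_estimator_* markers are no longer added for the 1P PCA estimator, and the assertion on sagemaker_estimator_module is dropped accordingly. The test reads the keyword arguments recorded by the mocked session; here is a self-contained sketch of that mock pattern, where the job name and the hyperparameter names other than extra_components are illustrative, not the fixture's actual values.

    from unittest.mock import Mock

    sagemaker_session = Mock()

    # Whatever the tuner would pass to tune() is captured by the Mock.
    sagemaker_session.tune(job_name='pca-test-job',
                           static_hyperparameters={'extra_components': '5',
                                                   'num_components': '1',      # illustrative
                                                   'feature_dim': '10',        # illustrative
                                                   'mini_batch_size': '100'},  # illustrative
                           parameter_ranges={'IntegerParameterRanges': [{'Name': 'num_components'}]})

    # Each entry in mock_calls unpacks to (name, positional args, keyword args).
    _, _, tune_kwargs = sagemaker_session.tune.mock_calls[0]
    assert len(tune_kwargs['static_hyperparameters']) == 4
    assert 'sagemaker_estimator_module' not in tune_kwargs['static_hyperparameters']
    assert tune_kwargs['job_name'].startswith('pca')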