
Commit 4d7dda5

Remove extra hyperparameters for 1P (aws#38)
Parent: aa14097

3 files changed: +9 −6 lines changed


src/sagemaker/tuner.py

Lines changed: 7 additions & 3 deletions
@@ -94,9 +94,13 @@ def prepare_for_training(self):
         for hyperparameter_name in self._hyperparameter_ranges.keys():
             self.static_hyperparameters.pop(hyperparameter_name, None)
 
-        # For attach() to know what estimator to use
-        self.static_hyperparameters[self.SAGEMAKER_ESTIMATOR_CLASS_NAME] = self.estimator.__class__.__name__
-        self.static_hyperparameters[self.SAGEMAKER_ESTIMATOR_MODULE] = self.estimator.__module__
+        # For attach() to know what estimator to use for non-1P algorithms
+        # (1P algorithms don't accept extra hyperparameters)
+        from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase
+
+        if not isinstance(self.estimator, AmazonAlgorithmEstimatorBase):
+            self.static_hyperparameters[self.SAGEMAKER_ESTIMATOR_CLASS_NAME] = self.estimator.__class__.__name__
+            self.static_hyperparameters[self.SAGEMAKER_ESTIMATOR_MODULE] = self.estimator.__module__
 
     def fit(self, inputs, job_name=None, **kwargs):
         """Start a hyperparameter tuning job.

tests/integ/test_tuner.py

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ def test_fit_1p(sagemaker_session):
     # specify which hp you want to optimize over
     hyperparameter_ranges = {'extra_center_factor': IntegerParameter(1, 10),
                              'mini_batch_size': IntegerParameter(10, 100),
-                             'local_lloyd_tol': ContinuousParameter(1.0, 2.0),
+                             'local_lloyd_tol': ContinuousParameter(0.5, 0.75),
                              'local_lloyd_init_method': CategoricalParameter(['kmeans++', 'random'])}
     objective_metric_name = 'test:msd'
 
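The old 1.0–2.0 range was invalid input for KMeans: the SDK-side hyperparameter definition restricts local_lloyd_tol to [0, 1], so the range was moved inside that interval. For reference, each range object serializes into the structure the tuning API expects; a small sketch (output shown as a comment, assuming as_tuning_range() behaves as in the v1 SDK):

    from sagemaker.tuner import ContinuousParameter

    # Render the updated range as the API-level parameter-range structure.
    print(ContinuousParameter(0.5, 0.75).as_tuning_range('local_lloyd_tol'))
    # {'Name': 'local_lloyd_tol', 'MinValue': '0.5', 'MaxValue': '0.75'}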

tests/unit/test_tuner.py

Lines changed: 1 addition & 2 deletions
@@ -121,8 +121,7 @@ def test_fit_1p(sagemaker_session, tuner):
 
     _, _, tune_kwargs = sagemaker_session.tune.mock_calls[0]
 
-    assert len(tune_kwargs['static_hyperparameters']) == 6
-    assert tune_kwargs['static_hyperparameters']['sagemaker_estimator_module'] == pca.__module__
+    assert len(tune_kwargs['static_hyperparameters']) == 4
     assert tune_kwargs['static_hyperparameters']['extra_components'] == '5'
     assert len(tune_kwargs['parameter_ranges']['IntegerParameterRanges']) == 1
     assert tune_kwargs['job_name'].startswith('pca')
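The expected count drops from 6 to 4 because PCA is a 1P estimator, so the two sagemaker_estimator_* entries are no longer added to the static hyperparameters. A quick sketch (not part of the test) confirming PCA falls under the new isinstance() guard:

    from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase
    from sagemaker.amazon.pca import PCA

    # PCA derives from AmazonAlgorithmEstimatorBase, so prepare_for_training()
    # now skips the two estimator-metadata hyperparameters for it.
    assert issubclass(PCA, AmazonAlgorithmEstimatorBase)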
