Skip to content

Commit 3f7bd59

Browse files
authored
Make SupportedHyperParameters optional. (#695)
1 parent 1787f78 commit 3f7bd59

File tree

3 files changed

+34
-12
lines changed

3 files changed

+34
-12
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ CHANGELOG
66
======
77

88
* feature: ``PipelineModel``: Create a Transformer from a PipelineModel
9+
* bug-fix: ``AlgorithmEstimator``: Make SupportedHyperParameters optional
910

1011
1.18.4
1112
======

src/sagemaker/algorithm.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -375,20 +375,23 @@ def _validate_and_set_default_hyperparameters(self):
375375
raise ValueError('Required hyperparameter: %s is not set' % name)
376376

377377
def _parse_hyperparameters(self):
378-
hyperparameters = self.algorithm_spec['TrainingSpecification']['SupportedHyperParameters']
379378
definitions = {}
380-
for h in hyperparameters:
381-
parameter_type = h['Type']
382-
name = h['Name']
383-
parameter_class, parameter_range = self._hyperparameter_range_and_class(
384-
parameter_type, h
385-
)
386379

387-
definitions[name] = {'spec': h}
388-
if parameter_range:
389-
definitions[name]['range'] = parameter_range
390-
if parameter_class:
391-
definitions[name]['class'] = parameter_class
380+
training_spec = self.algorithm_spec['TrainingSpecification']
381+
if 'SupportedHyperParameters' in training_spec:
382+
hyperparameters = training_spec['SupportedHyperParameters']
383+
for h in hyperparameters:
384+
parameter_type = h['Type']
385+
name = h['Name']
386+
parameter_class, parameter_range = self._hyperparameter_range_and_class(
387+
parameter_type, h
388+
)
389+
390+
definitions[name] = {'spec': h}
391+
if parameter_range:
392+
definitions[name]['range'] = parameter_range
393+
if parameter_class:
394+
definitions[name]['class'] = parameter_class
392395

393396
return definitions
394397

tests/unit/test_algorithm.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -913,3 +913,21 @@ def test_algorithm_encrypt_inter_container_traffic(sagemaker_session):
913913

914914
encrypt_inter_container_traffic = estimator.encrypt_inter_container_traffic
915915
assert encrypt_inter_container_traffic is True
916+
917+
918+
def test_algorithm_no_required_hyperparameters(sagemaker_session):
919+
some_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE)
920+
del some_algo['TrainingSpecification']['SupportedHyperParameters']
921+
922+
sagemaker_session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo)
923+
924+
# Calling AlgorithmEstimator() with unset required hyperparameters
925+
# should fail if they are required.
926+
# Pass training and hyperparameters channels. This should work
927+
assert AlgorithmEstimator(
928+
algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees',
929+
role='SageMakerRole',
930+
train_instance_type='ml.m4.2xlarge',
931+
train_instance_count=1,
932+
sagemaker_session=sagemaker_session,
933+
)

0 commit comments

Comments
 (0)