File tree Expand file tree Collapse file tree 3 files changed +34
-12
lines changed Expand file tree Collapse file tree 3 files changed +34
-12
lines changed Original file line number Diff line number Diff line change @@ -6,6 +6,7 @@ CHANGELOG
6
6
======
7
7
8
8
* feature: ``PipelineModel ``: Create a Transformer from a PipelineModel
9
+ * bug-fix: ``AlgorithmEstimator ``: Make SupportedHyperParameters optional
9
10
10
11
1.18.4
11
12
======
Original file line number Diff line number Diff line change @@ -375,20 +375,23 @@ def _validate_and_set_default_hyperparameters(self):
375
375
raise ValueError ('Required hyperparameter: %s is not set' % name )
376
376
377
377
def _parse_hyperparameters (self ):
378
- hyperparameters = self .algorithm_spec ['TrainingSpecification' ]['SupportedHyperParameters' ]
379
378
definitions = {}
380
- for h in hyperparameters :
381
- parameter_type = h ['Type' ]
382
- name = h ['Name' ]
383
- parameter_class , parameter_range = self ._hyperparameter_range_and_class (
384
- parameter_type , h
385
- )
386
379
387
- definitions [name ] = {'spec' : h }
388
- if parameter_range :
389
- definitions [name ]['range' ] = parameter_range
390
- if parameter_class :
391
- definitions [name ]['class' ] = parameter_class
380
+ training_spec = self .algorithm_spec ['TrainingSpecification' ]
381
+ if 'SupportedHyperParameters' in training_spec :
382
+ hyperparameters = training_spec ['SupportedHyperParameters' ]
383
+ for h in hyperparameters :
384
+ parameter_type = h ['Type' ]
385
+ name = h ['Name' ]
386
+ parameter_class , parameter_range = self ._hyperparameter_range_and_class (
387
+ parameter_type , h
388
+ )
389
+
390
+ definitions [name ] = {'spec' : h }
391
+ if parameter_range :
392
+ definitions [name ]['range' ] = parameter_range
393
+ if parameter_class :
394
+ definitions [name ]['class' ] = parameter_class
392
395
393
396
return definitions
394
397
Original file line number Diff line number Diff line change @@ -913,3 +913,21 @@ def test_algorithm_encrypt_inter_container_traffic(sagemaker_session):
913
913
914
914
encrypt_inter_container_traffic = estimator .encrypt_inter_container_traffic
915
915
assert encrypt_inter_container_traffic is True
916
+
917
+
918
+ def test_algorithm_no_required_hyperparameters (sagemaker_session ):
919
+ some_algo = copy .deepcopy (DESCRIBE_ALGORITHM_RESPONSE )
920
+ del some_algo ['TrainingSpecification' ]['SupportedHyperParameters' ]
921
+
922
+ sagemaker_session .sagemaker_client .describe_algorithm = Mock (return_value = some_algo )
923
+
924
+ # Calling AlgorithmEstimator() with unset required hyperparameters
925
+ # should fail if they are required.
926
+ # Pass training and hyperparameters channels. This should work
927
+ assert AlgorithmEstimator (
928
+ algorithm_arn = 'arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees' ,
929
+ role = 'SageMakerRole' ,
930
+ train_instance_type = 'ml.m4.2xlarge' ,
931
+ train_instance_count = 1 ,
932
+ sagemaker_session = sagemaker_session ,
933
+ )
You can’t perform that action at this time.
0 commit comments