Skip to content

Commit 10dacec

Browse files
committed
Make inputs optional for hyperparameter tuning job.
1 parent 11d3fcf commit 10dacec

File tree

4 files changed

+24
-3
lines changed

4 files changed

+24
-3
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def read(fname):
5353
],
5454

5555
# Declare minimal set for installation
56-
install_requires=['boto3>=1.9.38', 'numpy>=1.9.0', 'protobuf>=3.1', 'scipy>=0.19.0',
56+
install_requires=['boto3>=1.9.45', 'numpy>=1.9.0', 'protobuf>=3.1', 'scipy>=0.19.0',
5757
'urllib3 >=1.21', 'PyYAML>=3.2', 'protobuf3-to-dict>=0.1.5',
5858
'docker-compose>=1.23.0'],
5959

src/sagemaker/session.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,13 +345,15 @@ def tune(self, job_name, strategy, objective_type, objective_metric_name,
345345
'TrainingInputMode': input_mode,
346346
},
347347
'RoleArn': role,
348-
'InputDataConfig': input_config,
349348
'OutputDataConfig': output_config,
350349
'ResourceConfig': resource_config,
351350
'StoppingCondition': stop_condition,
352351
}
353352
}
354353

354+
if input_config is not None:
355+
tune_request['TrainingJobDefinition']['InputDataConfig'] = input_config
356+
355357
if metric_definitions is not None:
356358
tune_request['TrainingJobDefinition']['AlgorithmSpecification']['MetricDefinitions'] = metric_definitions
357359

src/sagemaker/tuner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def _prepare_for_training(self, job_name=None, include_cls_metadata=True):
213213
self.estimator.__class__.__name__)
214214
self.static_hyperparameters[self.SAGEMAKER_ESTIMATOR_MODULE] = json.dumps(self.estimator.__module__)
215215

216-
def fit(self, inputs, job_name=None, include_cls_metadata=True, **kwargs):
216+
def fit(self, inputs=None, job_name=None, include_cls_metadata=True, **kwargs):
217217
"""Start a hyperparameter tuning job.
218218
219219
Args:

tests/unit/test_tuner.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import copy
1616
import json
1717

18+
import os
1819
import pytest
1920
from mock import Mock
2021

@@ -25,6 +26,8 @@
2526
from sagemaker.tuner import _ParameterRange, ContinuousParameter, IntegerParameter, CategoricalParameter, \
2627
HyperparameterTuner, _TuningJob
2728
from sagemaker.mxnet import MXNet
29+
30+
DATA_DIR = os.path.join(os.path.dirname(__file__), '..', 'data')
2831
MODEL_DATA = "s3://bucket/model.tar.gz"
2932

3033
JOB_NAME = 'tuning_job'
@@ -474,6 +477,22 @@ def test_delete_endpoint(tuner):
474477
tuner.sagemaker_session.delete_endpoint.assert_called_with(JOB_NAME)
475478

476479

480+
def test_fit_no_inputs(tuner, sagemaker_session):
481+
script_path = os.path.join(DATA_DIR, 'mxnet_mnist', 'failure_script.py')
482+
tuner.estimator = MXNet(entry_point=script_path,
483+
role=ROLE,
484+
framework_version=FRAMEWORK_VERSION,
485+
train_instance_count=TRAIN_INSTANCE_COUNT,
486+
train_instance_type=TRAIN_INSTANCE_TYPE,
487+
sagemaker_session=sagemaker_session)
488+
489+
tuner.fit()
490+
491+
_, _, tune_kwargs = sagemaker_session.tune.mock_calls[0]
492+
493+
assert tune_kwargs['input_config'] is None
494+
495+
477496
#################################################################################
478497
# _ParameterRange Tests
479498

0 commit comments

Comments
 (0)