Skip to content

Commit faccfb2

Browse files
nadiayalaurenyu
authored andcommitted
Make inputs optional for hyperparameter tuning jobs (#490)
1 parent c32dec0 commit faccfb2

File tree

5 files changed

+27
-3
lines changed

5 files changed

+27
-3
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ CHANGELOG
1111
* doc-fix: Fix typos in tensorflow serving documentation
1212
* doc-fix: Add estimator base classes to API docs
1313
* feature: HyperparameterTuner: add support for Automatic Model Tuning's Warm Start Jobs
14+
* feature: HyperparameterTuner: Make input channels optional
1415

1516
1.14.2
1617
======

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def read(fname):
5353
],
5454

5555
# Declare minimal set for installation
56-
install_requires=['boto3>=1.9.38', 'numpy>=1.9.0', 'protobuf>=3.1', 'scipy>=0.19.0',
56+
install_requires=['boto3>=1.9.45', 'numpy>=1.9.0', 'protobuf>=3.1', 'scipy>=0.19.0',
5757
'urllib3 >=1.21', 'PyYAML>=3.2', 'protobuf3-to-dict>=0.1.5',
5858
'docker-compose>=1.23.0'],
5959

src/sagemaker/session.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -348,13 +348,15 @@ def tune(self, job_name, strategy, objective_type, objective_metric_name,
348348
'TrainingInputMode': input_mode,
349349
},
350350
'RoleArn': role,
351-
'InputDataConfig': input_config,
352351
'OutputDataConfig': output_config,
353352
'ResourceConfig': resource_config,
354353
'StoppingCondition': stop_condition,
355354
}
356355
}
357356

357+
if input_config is not None:
358+
tune_request['TrainingJobDefinition']['InputDataConfig'] = input_config
359+
358360
if warm_start_config:
359361
tune_request['WarmStartConfig'] = warm_start_config
360362

src/sagemaker/tuner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ def _prepare_for_training(self, job_name=None, include_cls_metadata=True):
327327
self.estimator.__class__.__name__)
328328
self.static_hyperparameters[self.SAGEMAKER_ESTIMATOR_MODULE] = json.dumps(self.estimator.__module__)
329329

330-
def fit(self, inputs, job_name=None, include_cls_metadata=True, **kwargs):
330+
def fit(self, inputs=None, job_name=None, include_cls_metadata=True, **kwargs):
331331
"""Start a hyperparameter tuning job.
332332
333333
Args:

tests/unit/test_tuner.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import copy
1616
import json
1717

18+
import os
1819
import pytest
1920
from mock import Mock
2021

@@ -26,6 +27,8 @@
2627
HyperparameterTuner, _TuningJob, WarmStartConfig, create_identical_dataset_and_algorithm_tuner, \
2728
create_transfer_learning_tuner, WarmStartTypes
2829
from sagemaker.mxnet import MXNet
30+
31+
DATA_DIR = os.path.join(os.path.dirname(__file__), '..', 'data')
2932
MODEL_DATA = "s3://bucket/model.tar.gz"
3033

3134
JOB_NAME = 'tuning_job'
@@ -488,6 +491,22 @@ def test_delete_endpoint(tuner):
488491
tuner.sagemaker_session.delete_endpoint.assert_called_with(JOB_NAME)
489492

490493

494+
def test_fit_no_inputs(tuner, sagemaker_session):
495+
script_path = os.path.join(DATA_DIR, 'mxnet_mnist', 'failure_script.py')
496+
tuner.estimator = MXNet(entry_point=script_path,
497+
role=ROLE,
498+
framework_version=FRAMEWORK_VERSION,
499+
train_instance_count=TRAIN_INSTANCE_COUNT,
500+
train_instance_type=TRAIN_INSTANCE_TYPE,
501+
sagemaker_session=sagemaker_session)
502+
503+
tuner.fit()
504+
505+
_, _, tune_kwargs = sagemaker_session.tune.mock_calls[0]
506+
507+
assert tune_kwargs['input_config'] is None
508+
509+
491510
def test_identical_dataset_and_algorithm_tuner(sagemaker_session):
492511
job_details = copy.deepcopy(TUNING_JOB_DETAILS)
493512
sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(name='describe_tuning_job',
@@ -523,6 +542,8 @@ def test_transfer_learning_tuner(sagemaker_session):
523542
assert parent_tuner.warm_start_config.type == WarmStartTypes.TRANSFER_LEARNING
524543
assert parent_tuner.warm_start_config.parents == {tuner.latest_tuning_job.name, "p1", "p2"}
525544
assert parent_tuner.estimator == tuner.estimator
545+
546+
526547
#################################################################################
527548
# _ParameterRange Tests
528549

0 commit comments

Comments
 (0)