Skip to content

Commit bae98db

Browse files
authored
Support VPC config for hyperparameter tuning and bump version to 1.17.2 (#598)
1 parent 0cd7aa1 commit bae98db

File tree

7 files changed

+82
-6
lines changed

7 files changed

+82
-6
lines changed

CHANGELOG.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22
CHANGELOG
33
=========
44

5+
1.17.2
6+
======
7+
8+
* feature: HyperparameterTuner: support VPC config
9+
510
1.17.1
611
======
712

doc/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def __getattr__(cls, name):
3232
'numpy', 'scipy', 'scipy.sparse']
3333
sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)
3434

35-
version = '1.17.1'
35+
version = '1.17.2'
3636
project = u'sagemaker'
3737

3838
# Add any Sphinx extension module names here, as strings. They can be extensions

src/sagemaker/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,4 @@
3939
from sagemaker.session import s3_input # noqa: F401
4040
from sagemaker.session import get_execution_role # noqa: F401
4141

42-
__version__ = '1.17.1'
42+
__version__ = '1.17.2'

src/sagemaker/session.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ def tune(self, job_name, strategy, objective_type, objective_metric_name,
358358
static_hyperparameters, input_mode, metric_definitions,
359359
role, input_config, output_config, resource_config, stop_condition, tags,
360360
warm_start_config, enable_network_isolation=False, image=None, algorithm_arn=None,
361-
early_stopping_type='Off', encrypt_inter_container_traffic=False):
361+
early_stopping_type='Off', encrypt_inter_container_traffic=False, vpc_config=None):
362362
"""Create an Amazon SageMaker hyperparameter tuning job
363363
364364
Args:
@@ -408,8 +408,14 @@ def tune(self, job_name, strategy, objective_type, objective_metric_name,
408408
Can be either 'Auto' or 'Off'. If set to 'Off', early stopping will not be attempted.
409409
If set to 'Auto', early stopping of some training jobs may happen, but is not guaranteed to.
410410
encrypt_inter_container_traffic (bool): Specifies whether traffic between training containers
411-
is encrypted for the training jobs started for this hyperparameter tuning job. Set to ``False``
412-
by default.
411+
is encrypted for the training jobs started for this hyperparameter tuning job (default: ``False``).
412+
vpc_config (dict): Contains values for VpcConfig (default: None):
413+
414+
* subnets (list[str]): List of subnet ids.
415+
The key in vpc_config is 'Subnets'.
416+
* security_group_ids (list[str]): List of security group ids.
417+
The key in vpc_config is 'SecurityGroupIds'.
418+
413419
"""
414420
tune_request = {
415421
'HyperParameterTuningJobName': job_name,
@@ -457,6 +463,9 @@ def tune(self, job_name, strategy, objective_type, objective_metric_name,
457463
if tags is not None:
458464
tune_request['Tags'] = tags
459465

466+
if vpc_config is not None:
467+
tune_request['TrainingJobDefinition']['VpcConfig'] = vpc_config
468+
460469
if enable_network_isolation:
461470
tune_request['TrainingJobDefinition']['EnableNetworkIsolation'] = True
462471

src/sagemaker/tuner.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -634,7 +634,6 @@ def start_new(cls, tuner, inputs):
634634
tuner_args['warm_start_config'] = warm_start_config_req
635635
tuner_args['early_stopping_type'] = tuner.early_stopping_type
636636

637-
del tuner_args['vpc_config']
638637
if isinstance(tuner.estimator, sagemaker.algorithm.AlgorithmEstimator):
639638
tuner_args['algorithm_arn'] = tuner.estimator.algorithm_arn
640639
else:

tests/integ/test_tuner.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from tests.integ import DATA_DIR, PYTHON_VERSION, TUNING_DEFAULT_TIMEOUT_MINUTES
2626
from tests.integ.record_set import prepare_record_set_from_local_files
2727
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
28+
from tests.integ import vpc_test_utils
2829

2930
from sagemaker import KMeans, LDA, RandomCutForest
3031
from sagemaker.amazon.amazon_estimator import registry
@@ -491,6 +492,52 @@ def test_tuning_tf(sagemaker_session):
491492
assert dict_result == list_result
492493

493494

495+
@pytest.mark.skipif(PYTHON_VERSION != 'py2', reason="TensorFlow image supports only python 2.")
496+
def test_tuning_tf_vpc_multi(sagemaker_session):
497+
"""Test Tensorflow multi-instance using the same VpcConfig for training and inference"""
498+
instance_type = 'ml.c4.xlarge'
499+
instance_count = 2
500+
501+
script_path = os.path.join(DATA_DIR, 'iris', 'iris-dnn-classifier.py')
502+
503+
ec2_client = sagemaker_session.boto_session.client('ec2')
504+
subnet_ids, security_group_id = vpc_test_utils.get_or_create_vpc_resources(ec2_client,
505+
sagemaker_session.boto_region_name)
506+
vpc_test_utils.setup_security_group_for_encryption(ec2_client, security_group_id)
507+
508+
estimator = TensorFlow(entry_point=script_path,
509+
role='SageMakerRole',
510+
training_steps=1,
511+
evaluation_steps=1,
512+
hyperparameters={'input_tensor_name': 'inputs'},
513+
train_instance_count=instance_count,
514+
train_instance_type=instance_type,
515+
sagemaker_session=sagemaker_session,
516+
base_job_name='test-vpc-tf',
517+
subnets=subnet_ids,
518+
security_group_ids=[security_group_id],
519+
encrypt_inter_container_traffic=True)
520+
521+
inputs = sagemaker_session.upload_data(path=DATA_PATH, key_prefix='integ-test-data/tf_iris')
522+
hyperparameter_ranges = {'learning_rate': ContinuousParameter(0.05, 0.2)}
523+
524+
objective_metric_name = 'loss'
525+
metric_definitions = [{'Name': 'loss', 'Regex': 'loss = ([0-9\\.]+)'}]
526+
527+
tuner = HyperparameterTuner(estimator, objective_metric_name, hyperparameter_ranges,
528+
metric_definitions,
529+
objective_type='Minimize', max_jobs=2, max_parallel_jobs=2)
530+
531+
tuning_job_name = unique_name_from_base('tune-tf', max_length=32)
532+
with timeout(minutes=TUNING_DEFAULT_TIMEOUT_MINUTES):
533+
tuner.fit(inputs, job_name=tuning_job_name)
534+
535+
print('Started hyperparameter tuning job with name:' + tuning_job_name)
536+
537+
time.sleep(15)
538+
tuner.wait()
539+
540+
494541
@pytest.mark.continuous_testing
495542
def test_tuning_chainer(sagemaker_session):
496543
with timeout(minutes=TUNING_DEFAULT_TIMEOUT_MINUTES):

tests/unit/test_tuner.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,22 @@ def test_fit_pca_with_early_stopping(sagemaker_session, tuner):
256256
assert tune_kwargs['early_stopping_type'] == 'Auto'
257257

258258

259+
def test_fit_mxnet_with_vpc_config(sagemaker_session, tuner):
260+
subnets = ['foo']
261+
security_group_ids = ['bar']
262+
263+
pca = PCA(ROLE, TRAIN_INSTANCE_COUNT, TRAIN_INSTANCE_TYPE, NUM_COMPONENTS,
264+
base_job_name='pca', sagemaker_session=sagemaker_session,
265+
subnets=subnets, security_group_ids=security_group_ids)
266+
tuner.estimator = pca
267+
268+
records = RecordSet(s3_data=INPUTS, num_records=1, feature_dim=1)
269+
tuner.fit(records, mini_batch_size=9999)
270+
271+
_, _, tune_kwargs = sagemaker_session.tune.mock_calls[0]
272+
assert tune_kwargs['vpc_config'] == {'Subnets': subnets, 'SecurityGroupIds': security_group_ids}
273+
274+
259275
def test_fit_pca_with_inter_container_traffic_encryption_flag(sagemaker_session, tuner):
260276
pca = PCA(ROLE, TRAIN_INSTANCE_COUNT, TRAIN_INSTANCE_TYPE, NUM_COMPONENTS,
261277
base_job_name='pca', sagemaker_session=sagemaker_session,

0 commit comments

Comments
 (0)