Skip to content

fix: update TrainingInputMode with s3_input InputMode #776

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 8, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,12 @@ def start_new(cls, estimator, inputs):
train_args['tags'] = estimator.tags
train_args['metric_definitions'] = estimator.metric_definitions

if isinstance(inputs, s3_input):
if 'InputMode' in inputs.config:
logging.debug('Selecting s3_input\'s input_mode ({}) for TrainingInputMode.'
.format(inputs.config['InputMode']))
train_args['input_mode'] = inputs.config['InputMode']

if estimator.enable_network_isolation():
train_args['enable_network_isolation'] = True

Expand Down
8 changes: 8 additions & 0 deletions src/sagemaker/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import importlib
import inspect
import json
import logging
from enum import Enum

import sagemaker
Expand All @@ -26,6 +27,7 @@
from sagemaker.parameter import (CategoricalParameter, ContinuousParameter,
IntegerParameter, ParameterRange)
from sagemaker.session import Session
from sagemaker.session import s3_input
from sagemaker.utils import base_name_from_image, name_from_base, to_str

AMAZON_ESTIMATOR_MODULE = 'sagemaker'
Expand Down Expand Up @@ -640,6 +642,12 @@ def start_new(cls, tuner, inputs):
tuner_args['warm_start_config'] = warm_start_config_req
tuner_args['early_stopping_type'] = tuner.early_stopping_type

if isinstance(inputs, s3_input):
if 'InputMode' in inputs.config:
logging.debug('Selecting s3_input\'s input_mode ({}) for TrainingInputMode.'
.format(inputs.config['InputMode']))
tuner_args['input_mode'] = inputs.config['InputMode']

if isinstance(tuner.estimator, sagemaker.algorithm.AlgorithmEstimator):
tuner_args['algorithm_arn'] = tuner.estimator.algorithm_arn
else:
Expand Down
4 changes: 2 additions & 2 deletions tests/scripts/run-notebook-test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ aws s3 --region us-west-2 cp ./dist/sagemaker-*.tar.gz s3://sagemaker-python-sdk
aws s3 cp s3://sagemaker-mead-cli/mead-nb-test.tar.gz mead-nb-test.tar.gz
tar -xzf mead-nb-test.tar.gz
git clone --depth 1 https://github.com/awslabs/amazon-sagemaker-examples.git
JAVA_HOME=$(get-java-home)
export JAVA_HOME=$(get-java-home)
echo "set JAVA_HOME=$JAVA_HOME"
SAGEMAKER_ROLE_ARN=$(get-sagemaker-role-arn)
export SAGEMAKER_ROLE_ARN=$(get-sagemaker-role-arn)
echo "set SAGEMAKER_ROLE_ARN=$SAGEMAKER_ROLE_ARN"
./runtime/bin/mead-run-nb-test \
--instance-type ml.c4.8xlarge \
Expand Down
11 changes: 11 additions & 0 deletions tests/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,17 @@ def test_augmented_manifest(sagemaker_session):
assert s3_data_source['AttributeNames'] == ['foo', 'bar']


def test_s3_input_mode(sagemaker_session):
expected_input_mode = 'Pipe'
fw = DummyFramework(entry_point=SCRIPT_PATH, role='DummyRole', sagemaker_session=sagemaker_session,
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
enable_cloudwatch_metrics=True)
fw.fit(inputs=s3_input('s3://mybucket/train_manifest', input_mode=expected_input_mode))

actual_input_mode = sagemaker_session.method_calls[1][2]['input_mode']
assert actual_input_mode == expected_input_mode


def test_shuffle_config(sagemaker_session):
fw = DummyFramework(entry_point=SCRIPT_PATH, role='DummyRole', sagemaker_session=sagemaker_session,
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
Expand Down
26 changes: 26 additions & 0 deletions tests/unit/test_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from sagemaker.tuner import (_TuningJob, create_identical_dataset_and_algorithm_tuner,
create_transfer_learning_tuner, HyperparameterTuner, WarmStartConfig,
WarmStartTypes)
from sagemaker.session import s3_input

DATA_DIR = os.path.join(os.path.dirname(__file__), '..', 'data')
MODEL_DATA = "s3://bucket/model.tar.gz"
Expand Down Expand Up @@ -286,6 +287,31 @@ def test_fit_mxnet_with_vpc_config(sagemaker_session, tuner):
assert tune_kwargs['vpc_config'] == {'Subnets': subnets, 'SecurityGroupIds': security_group_ids}


def test_s3_input_mode(sagemaker_session, tuner):
expected_input_mode = 'Pipe'

script_path = os.path.join(DATA_DIR, 'mxnet_mnist', 'failure_script.py')
mxnet = MXNet(entry_point=script_path,
role=ROLE,
framework_version=FRAMEWORK_VERSION,
train_instance_count=TRAIN_INSTANCE_COUNT,
train_instance_type=TRAIN_INSTANCE_TYPE,
sagemaker_session=sagemaker_session)
tuner.estimator = mxnet

tags = [{'Name': 'some-tag-without-a-value'}]
tuner.tags = tags

hyperparameter_ranges = {'num_components': IntegerParameter(2, 4),
'algorithm_mode': CategoricalParameter(['regular', 'randomized'])}
tuner._hyperparameter_ranges = hyperparameter_ranges

tuner.fit(inputs=s3_input('s3://mybucket/train_manifest', input_mode=expected_input_mode))

actual_input_mode = sagemaker_session.method_calls[1][2]['input_mode']
assert actual_input_mode == expected_input_mode


def test_fit_pca_with_inter_container_traffic_encryption_flag(sagemaker_session, tuner):
pca = PCA(ROLE, TRAIN_INSTANCE_COUNT, TRAIN_INSTANCE_TYPE, NUM_COMPONENTS,
base_job_name='pca', sagemaker_session=sagemaker_session,
Expand Down