Skip to content

feature: network isolation mode in training #791

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jun 21, 2019
Merged
71 changes: 52 additions & 19 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import sagemaker
from sagemaker.analytics import TrainingJobAnalytics
from sagemaker.fw_utils import (create_image_uri, tar_and_upload_dir, parse_s3_url, UploadedCode,
from sagemaker.fw_utils import (create_image_uri, tar_and_upload_dir, upload_file, parse_s3_url, UploadedCode,
validate_source_dir)
from sagemaker.job import _Job
from sagemaker.local import LocalSession
Expand Down Expand Up @@ -51,8 +51,9 @@ class EstimatorBase(with_metaclass(ABCMeta, object)):
def __init__(self, role, train_instance_count, train_instance_type,
train_volume_size=30, train_volume_kms_key=None, train_max_run=24 * 60 * 60, input_mode='File',
output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None, tags=None,
subnets=None, security_group_ids=None, model_uri=None, model_channel_name='model',
metric_definitions=None, encrypt_inter_container_traffic=False):
subnets=None, security_group_ids=None, enable_network_isolation=False, model_uri=None,
model_channel_name='model', code_uri=None, code_channel_name='code', metric_definitions=None,
encrypt_inter_container_traffic=False):
"""Initialize an ``EstimatorBase`` instance.

Args:
Expand Down Expand Up @@ -89,6 +90,10 @@ def __init__(self, role, train_instance_count, train_instance_type,
subnets (list[str]): List of subnet ids. If not specified training job will be created without VPC config.
security_group_ids (list[str]): List of security group ids. If not specified training job will be created
without VPC config.
enable_network_isolation (bool): Specifies whether container will run in network isolation mode. Network
isolation mode restricts the container access to outside networks (such as the internet). If True,
a code channel will be created for any user entry script for training. Also known as internet-free
mode (default: `False`).
model_uri (str): URI where a pre-trained model is stored, either locally or in S3 (default: None). If
specified, the estimator will create a channel pointing to the model so the training job can download
it. This model can be a 'model.tar.gz' from a previous training job, or other artifacts coming from a
Expand All @@ -99,6 +104,9 @@ def __init__(self, role, train_instance_count, train_instance_type,

More information: https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-training.html#td-deserialization
model_channel_name (str): Name of the channel where 'model_uri' will be downloaded (default: 'model').
code_uri (str): URI where user entry script is stored, either locally or in S3 (default: None).
code_channel_name (str): Name of the channel where 'code_uri` will be downloaded in network isolation mode
(default: `code`).
metric_definitions (list[dict]): A list of dictionaries that defines the metric(s) used to evaluate the
training jobs. Each dictionary contains two keys: 'Name' for the name of the metric, and 'Regex' for
the regular expression used to extract the metric from the logs. This should be defined only
Expand All @@ -115,8 +123,11 @@ def __init__(self, role, train_instance_count, train_instance_type,
self.input_mode = input_mode
self.tags = tags
self.metric_definitions = metric_definitions
self._enable_network_isolation = enable_network_isolation
self.model_uri = model_uri
self.model_channel_name = model_channel_name
self.code_uri = code_uri
self.code_channel_name = code_channel_name

if self.train_instance_type in ('local', 'local_gpu'):
if self.train_instance_type == 'local_gpu' and self.train_instance_count > 1:
Expand Down Expand Up @@ -170,7 +181,7 @@ def enable_network_isolation(self):
Returns:
bool: Whether this Estimator needs network isolation or not.
"""
return False
return self._enable_network_isolation

def _prepare_for_training(self, job_name=None):
"""Set any values in the estimator that need to be set before training.
Expand Down Expand Up @@ -611,8 +622,8 @@ def __init__(self, image_name, role, train_instance_count, train_instance_type,
train_volume_size=30, train_volume_kms_key=None, train_max_run=24 * 60 * 60,
input_mode='File', output_path=None, output_kms_key=None, base_job_name=None,
sagemaker_session=None, hyperparameters=None, tags=None, subnets=None, security_group_ids=None,
model_uri=None, model_channel_name='model', metric_definitions=None,
encrypt_inter_container_traffic=False):
enable_network_isolation=False, model_uri=None, model_channel_name='model', code_uri=None,
code_channel_name='code', metric_definitions=None, encrypt_inter_container_traffic=False):
"""Initialize an ``Estimator`` instance.

Args:
Expand Down Expand Up @@ -653,6 +664,10 @@ def __init__(self, image_name, role, train_instance_count, train_instance_type,
subnets (list[str]): List of subnet ids. If not specified training job will be created without VPC config.
security_group_ids (list[str]): List of security group ids. If not specified training job will be created
without VPC config.
enable_network_isolation (bool): Specifies whether container will run in network isolation mode. Network
isolation mode restricts the container access to outside networks (such as the internet). If True,
a code channel will be created for any user entry script for training. Also known as internet-free
mode (default: `False`).
model_uri (str): URI where a pre-trained model is stored, either locally or in S3 (default: None). If
specified, the estimator will create a channel pointing to the model so the training job can download
it. This model can be a 'model.tar.gz' from a previous training job, or other artifacts coming from a
Expand All @@ -663,6 +678,9 @@ def __init__(self, image_name, role, train_instance_count, train_instance_type,

More information: https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-training.html#td-deserialization
model_channel_name (str): Name of the channel where 'model_uri' will be downloaded (default: 'model').
code_uri (str): URI where user entry script is stored, either locally or in S3 (default: None).
code_channel_name (str): Name of the channel where 'code_uri` will be downloaded in network isolation mode
(default: `code`).
metric_definitions (list[dict]): A list of dictionaries that defines the metric(s) used to evaluate the
training jobs. Each dictionary contains two keys: 'Name' for the name of the metric, and 'Regex' for
the regular expression used to extract the metric from the logs. This should be defined only
Expand All @@ -675,8 +693,10 @@ def __init__(self, image_name, role, train_instance_count, train_instance_type,
super(Estimator, self).__init__(role, train_instance_count, train_instance_type,
train_volume_size, train_volume_kms_key, train_max_run, input_mode,
output_path, output_kms_key, base_job_name, sagemaker_session,
tags, subnets, security_group_ids, model_uri=model_uri,
model_channel_name=model_channel_name, metric_definitions=metric_definitions,
tags, subnets, security_group_ids,
enable_network_isolation=enable_network_isolation, model_uri=model_uri,
model_channel_name=model_channel_name, code_uri=code_uri,
code_channel_name=code_channel_name, metric_definitions=metric_definitions,
encrypt_inter_container_traffic=encrypt_inter_container_traffic)

def train_image(self):
Expand Down Expand Up @@ -852,6 +872,12 @@ def _prepare_for_training(self, job_name=None):

code_dir = 'file://' + self.source_dir
script = self.entry_point
elif self.enable_network_isolation() and self.entry_point:
relative_code_location = 'input/data/code'
self.uploaded_code = self._stage_user_code_in_s3(tar=False, s3_location=relative_code_location)
code_dir = "/opt/ml/{}".format(relative_code_location)
script = self.uploaded_code.script_name
self.code_uri = self.uploaded_code.s3_prefix
else:
self.uploaded_code = self._stage_user_code_in_s3()
code_dir = self.uploaded_code.s3_prefix
Expand All @@ -865,7 +891,7 @@ def _prepare_for_training(self, job_name=None):
self._hyperparameters[JOB_NAME_PARAM_NAME] = self._current_job_name
self._hyperparameters[SAGEMAKER_REGION_PARAM_NAME] = self.sagemaker_session.boto_region_name

def _stage_user_code_in_s3(self):
def _stage_user_code_in_s3(self, tar=True, s3_location='source'):
"""Upload the user training script to s3 and return the location.

Returns: s3 uri
Expand All @@ -875,27 +901,34 @@ def _stage_user_code_in_s3(self):

if self.code_location is None and local_mode:
code_bucket = self.sagemaker_session.default_bucket()
code_s3_prefix = '{}/source'.format(self._current_job_name)
code_s3_prefix = '{}/{}'.format(self._current_job_name, s3_location)
kms_key = None

elif self.code_location is None:
code_bucket, _ = parse_s3_url(self.output_path)
code_s3_prefix = '{}/source'.format(self._current_job_name)
code_s3_prefix = '{}/{}'.format(self._current_job_name, s3_location)
kms_key = self.output_kms_key
else:
code_bucket, key_prefix = parse_s3_url(self.code_location)
code_s3_prefix = '/'.join(filter(None, [key_prefix, self._current_job_name, 'source']))
code_s3_prefix = '/'.join(filter(None, [key_prefix, self._current_job_name, s3_location]))

output_bucket, _ = parse_s3_url(self.output_path)
kms_key = self.output_kms_key if code_bucket == output_bucket else None

return tar_and_upload_dir(session=self.sagemaker_session.boto_session,
bucket=code_bucket,
s3_key_prefix=code_s3_prefix,
script=self.entry_point,
directory=self.source_dir,
dependencies=self.dependencies,
kms_key=kms_key)
if tar:
return tar_and_upload_dir(session=self.sagemaker_session.boto_session,
bucket=code_bucket,
s3_key_prefix=code_s3_prefix,
script=self.entry_point,
directory=self.source_dir,
dependencies=self.dependencies,
kms_key=kms_key)
else:
return upload_file(session=self.sagemaker_session.boto_session,
bucket=code_bucket,
s3_key_prefix=code_s3_prefix,
file=self.entry_point,
kms_key=kms_key)

def _model_source_dir(self):
"""Get the appropriate value to pass as source_dir to model constructor on deploying
Expand Down
13 changes: 13 additions & 0 deletions src/sagemaker/fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,19 @@ def tar_and_upload_dir(session, bucket, s3_key_prefix, script,
return UploadedCode(s3_prefix='s3://%s/%s' % (bucket, key), script_name=script_name)


def upload_file(session, bucket, s3_key_prefix, file, kms_key=None):
file_name = os.path.basename(file)
key = '{}/{}'.format(s3_key_prefix, file_name)
if kms_key:
extra_args = {'ServerSideEncryption': 'aws:kms', 'SSEKMSKeyId': kms_key}
else:
extra_args = None

session.resource('s3').Object(bucket, key).upload_file(file, ExtraArgs=extra_args)

return UploadedCode(s3_prefix='s3://%s/%s' % (bucket, key), script_name=file_name)


def _list_files_to_compress(script, directory):
if directory is None:
return [script]
Expand Down
44 changes: 27 additions & 17 deletions src/sagemaker/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,21 @@ def _load_config(inputs, estimator, expand_role=True, validate_uri=True):
stop_condition = _Job._prepare_stop_condition(estimator.train_max_run)
vpc_config = estimator.get_vpc_config()

model_channel = _Job._prepare_model_channel(input_config, estimator.model_uri, estimator.model_channel_name,
validate_uri)
model_channel = _Job._prepare_channel(input_config, estimator.model_uri, estimator.model_channel_name,
validate_uri, content_type='application/x-sagemaker-model',
input_mode='File')
if model_channel:
input_config = [] if input_config is None else input_config
input_config.append(model_channel)

if estimator.enable_network_isolation():
code_channel = _Job._prepare_channel(input_config, estimator.code_uri, estimator.code_channel_name,
validate_uri)

if code_channel:
input_config = [] if input_config is None else input_config
input_config.append(code_channel)

return {'input_config': input_config,
'role': role,
'output_config': output_config,
Expand Down Expand Up @@ -110,16 +119,16 @@ def _convert_input_to_channel(channel_name, channel_s3_input):
return channel_config

@staticmethod
def _format_string_uri_input(uri_input, validate_uri=True):
def _format_string_uri_input(uri_input, validate_uri=True, content_type=None, input_mode=None):
if isinstance(uri_input, str) and validate_uri and uri_input.startswith('s3://'):
return s3_input(uri_input)
return s3_input(uri_input, content_type=content_type, input_mode=input_mode)
elif isinstance(uri_input, str) and validate_uri and uri_input.startswith('file://'):
return file_input(uri_input)
elif isinstance(uri_input, str) and validate_uri:
raise ValueError('Training input data must be a valid S3 or FILE URI: must start with "s3://" or '
'"file://"')
raise ValueError('URI input {} must be a valid S3 or FILE URI: must start with "s3://" or '
'"file://"'.format(uri_input))
elif isinstance(uri_input, str):
return s3_input(uri_input)
return s3_input(uri_input, content_type=content_type, input_mode=input_mode)
elif isinstance(uri_input, s3_input):
return uri_input
elif isinstance(uri_input, file_input):
Expand All @@ -128,21 +137,22 @@ def _format_string_uri_input(uri_input, validate_uri=True):
raise ValueError('Cannot format input {}. Expecting one of str, s3_input, or file_input'.format(uri_input))

@staticmethod
def _prepare_model_channel(input_config, model_uri=None, model_channel_name=None, validate_uri=True):
if not model_uri:
def _prepare_channel(input_config, channel_uri=None, channel_name=None, validate_uri=True, content_type=None,
input_mode=None):
if not channel_uri:
return
elif not model_channel_name:
raise ValueError('Expected a pre-trained model channel name if a model URL is specified.')
elif not channel_name:
raise ValueError('Expected a channel name if a channel URI {} is specified'.format(channel_uri))

if input_config:
for channel in input_config:
if channel['ChannelName'] == model_channel_name:
raise ValueError('Duplicate channels not allowed.')
for existing_channel in input_config:
if existing_channel['ChannelName'] == channel_name:
raise ValueError('Duplicate channel {} not allowed.'.format(channel_name))

model_input = _Job._format_model_uri_input(model_uri, validate_uri)
model_channel = _Job._convert_input_to_channel(model_channel_name, model_input)
channel_input = _Job._format_string_uri_input(channel_uri, validate_uri, content_type, input_mode)
channel = _Job._convert_input_to_channel(channel_name, channel_input)

return model_channel
return channel

@staticmethod
def _format_model_uri_input(model_uri, validate_uri=True):
Expand Down
25 changes: 25 additions & 0 deletions tests/integ/test_sklearn_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,31 @@ def test_training_with_additional_hyperparameters(sagemaker_session, sklearn_ful
return sklearn.latest_training_job.name


@pytest.mark.skipif(PYTHON_VERSION != 'py3', reason="Scikit-learn image supports only python 3.")
def test_training_with_network_isolation(sagemaker_session, sklearn_full_version):
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
script_path = os.path.join(DATA_DIR, 'sklearn_mnist', 'mnist.py')
data_path = os.path.join(DATA_DIR, 'sklearn_mnist')

sklearn = SKLearn(entry_point=script_path,
role='SageMakerRole',
train_instance_type="ml.c4.xlarge",
framework_version=sklearn_full_version,
py_version=PYTHON_VERSION,
sagemaker_session=sagemaker_session,
hyperparameters={'epochs': 1},
enable_network_isolation=True)

train_input = sklearn.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'),
key_prefix='integ-test-data/sklearn_mnist/train')
test_input = sklearn.sagemaker_session.upload_data(path=os.path.join(data_path, 'test'),
key_prefix='integ-test-data/sklearn_mnist/test')
job_name = unique_name_from_base('test-sklearn-hp')

sklearn.fit({'train': train_input, 'test': test_input}, job_name=job_name)
return sklearn.latest_training_job.name


@pytest.mark.canary_quick
@pytest.mark.regional_testing
@pytest.mark.skipif(PYTHON_VERSION != 'py3', reason="Scikit-learn image supports only python 3.")
Expand Down
Loading