Skip to content

Commit 0c5e32b

Browse files
author
Yue Tu
committed
merge aws master
2 parents e6a01f0 + f1d34ad commit 0c5e32b

File tree

6 files changed

+256
-73
lines changed

6 files changed

+256
-73
lines changed

doc/overview.rst

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -619,6 +619,33 @@ Likewise, when you create ``Transformer`` from the ``Estimator`` using ``transfo
619619
# Transform Job container instances will run in your VPC
620620
mxnet_vpc_transformer.transform('s3://my-bucket/batch-transform-input')
621621
622+
Secure Training with Network Isolation (Internet-Free) Mode
623+
-------------------------------------------------------------------------
624+
You can enable network isolation mode when running training and inference on Amazon SageMaker.
625+
626+
For more information about Amazon SageMaker network isolation mode, see the `SageMaker documentation on network isolation or internet-free mode <https://docs.aws.amazon.com/sagemaker/latest/dg/mkt-algo-model-internet-free.html>`__.
627+
628+
To train a model in network isolation mode, set the optional parameter ``enable_network_isolation`` to ``True`` in any network isolation supported Framework Estimator.
629+
630+
.. code:: python
631+
632+
# set the enable_network_isolation parameter to True
633+
sklearn_estimator = SKLearn('sklearn-train.py',
634+
train_instance_type='ml.m4.xlarge',
635+
framework_version='0.20.0',
636+
hyperparameters = {'epochs': 20, 'batch-size': 64, 'learning-rate': 0.1},
637+
enable_network_isolation=True)
638+
639+
# SageMaker Training Job will in the container without any inbound or outbound network calls during runtime
640+
sklearn_estimator.fit({'train': 's3://my-data-bucket/path/to/my/training/data',
641+
'test': 's3://my-data-bucket/path/to/my/test/data'})
642+
643+
When this training job is created, the SageMaker Python SDK will upload the files in ``entry_point``, ``source_dir``, and ``dependencies`` to S3 as a compressed ``sourcedir.tar.gz`` file (``'s3://mybucket/sourcedir.tar.gz'``).
644+
645+
A new training job channel, named ``code``, will be added with that S3 URI. Before the training docker container is initialized, the ``sourcedir.tar.gz`` will be downloaded from S3 to the ML storage volume like any other offline input channel.
646+
647+
Once the training job begins, the training container will look at the offline input ``code`` channel to install dependencies and run the entry script. This isolates the training container, so no inbound or outbound network calls can be made.
648+
622649
623650
FAQ
624651
---

src/sagemaker/estimator.py

Lines changed: 46 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,8 @@ def __init__(self, role, train_instance_count, train_instance_type,
118118
self.metric_definitions = metric_definitions
119119
self.model_uri = model_uri
120120
self.model_channel_name = model_channel_name
121+
self.code_uri = None
122+
self.code_channel_name = 'code'
121123

122124
if self.train_instance_type in ('local', 'local_gpu'):
123125
if self.train_instance_type == 'local_gpu' and self.train_instance_count > 1:
@@ -774,10 +776,11 @@ class Framework(EstimatorBase):
774776
LAUNCH_MPI_ENV_NAME = 'sagemaker_mpi_enabled'
775777
MPI_NUM_PROCESSES_PER_HOST = 'sagemaker_mpi_num_of_processes_per_host'
776778
MPI_CUSTOM_MPI_OPTIONS = 'sagemaker_mpi_custom_mpi_options'
779+
CONTAINER_CODE_CHANNEL_SOURCEDIR_PATH = '/opt/ml/input/data/code/sourcedir.tar.gz'
777780

778-
def __init__(self, entry_point, source_dir=None, hyperparameters=None,
779-
enable_cloudwatch_metrics=False, container_log_level=logging.INFO, code_location=None,
780-
image_name=None, dependencies=None, git_config=None, **kwargs):
781+
def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cloudwatch_metrics=False,
782+
container_log_level=logging.INFO, code_location=None, image_name=None, dependencies=None,
783+
enable_network_isolation=False, **kwargs):
781784
"""Base class initializer. Subclasses which override ``__init__`` should invoke ``super()``
782785
783786
Args:
@@ -811,6 +814,7 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None,
811814
the specified commit.
812815
source_dir (str): Path (absolute or relative) to a directory with any other training
813816
source code dependencies aside from the entry point file (default: None). Structure within this
817+
<<<<<<< HEAD
814818
directory are preserved when training on Amazon SageMaker. If 'git_config' is provided,
815819
source_dir should be a relative location to a directory in the Git repo.
816820
Example:
@@ -824,6 +828,24 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None,
824828
825829
and you need 'train.py' as entry point and 'test.py' as training source code as well, you can
826830
assign entry_point='train.py', source_dir='src'.
831+
=======
832+
directory are preserved when training on Amazon SageMaker.
833+
hyperparameters (dict): Hyperparameters that will be used for training (default: None).
834+
The hyperparameters are made accessible as a dict[str, str] to the training code on SageMaker.
835+
For convenience, this accepts other types for keys and values, but ``str()`` will be called
836+
to convert them before training.
837+
enable_cloudwatch_metrics (bool): [DEPRECATED] Now there are cloudwatch metrics emitted by all SageMaker
838+
training jobs. This will be ignored for now and removed in a further release.
839+
container_log_level (int): Log level to use within the container (default: logging.INFO).
840+
Valid values are defined in the Python logging module.
841+
code_location (str): The S3 prefix URI where custom code will be uploaded (default: None).
842+
The code file uploaded in S3 is 'code_location/source/sourcedir.tar.gz'.
843+
If not specified, the default code location is s3://default_bucket/job-name/. And code file
844+
uploaded to S3 is s3://default_bucket/job-name/source/sourcedir.tar.gz
845+
image_name (str): An alternate image name to use instead of the official Sagemaker image
846+
for the framework. This is useful to run one of the Sagemaker supported frameworks
847+
with an image containing custom dependencies.
848+
>>>>>>> f1d34ad4073f8d856ef9c596b491f8a4cd8ef31f
827849
dependencies (list[str]): A list of paths to directories (absolute or relative) with
828850
any additional libraries that will be exported to the container (default: []).
829851
The library folders will be copied to SageMaker in the same folder where the entrypoint is copied.
@@ -840,21 +862,11 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None,
840862
>>> |------ common
841863
>>> |------ virtual-env
842864
843-
hyperparameters (dict): Hyperparameters that will be used for training (default: None).
844-
The hyperparameters are made accessible as a dict[str, str] to the training code on SageMaker.
845-
For convenience, this accepts other types for keys and values, but ``str()`` will be called
846-
to convert them before training.
847-
enable_cloudwatch_metrics (bool): [DEPRECATED] Now there are cloudwatch metrics emitted by all SageMaker
848-
training jobs. This will be ignored for now and removed in a further release.
849-
container_log_level (int): Log level to use within the container (default: logging.INFO).
850-
Valid values are defined in the Python logging module.
851-
code_location (str): The S3 prefix URI where custom code will be uploaded (default: None).
852-
The code file uploaded in S3 is 'code_location/source/sourcedir.tar.gz'.
853-
If not specified, the default code location is s3://default_bucket/job-name/. And code file
854-
uploaded to S3 is s3://default_bucket/job-name/source/sourcedir.tar.gz
855-
image_name (str): An alternate image name to use instead of the official Sagemaker image
856-
for the framework. This is useful to run one of the Sagemaker supported frameworks
857-
with an image containing custom dependencies.
865+
enable_network_isolation (bool): Specifies whether container will run in network isolation mode. Network
866+
isolation mode restricts the container access to outside networks (such as the internet). The container
867+
does not make any inbound or outbound network calls. If True, a channel named "code" will be created
868+
for any user entry script for training. The user entry script, files in source_dir (if specified), and
869+
dependencies will be uploaded in a tar to S3. Also known as internet-free mode (default: `False`).
858870
**kwargs: Additional kwargs passed to the ``EstimatorBase`` constructor.
859871
"""
860872
super(Framework, self).__init__(**kwargs)
@@ -872,9 +884,18 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None,
872884
self.container_log_level = container_log_level
873885
self.code_location = code_location
874886
self.image_name = image_name
887+
self._enable_network_isolation = enable_network_isolation
875888

876889
self._hyperparameters = hyperparameters or {}
877890

891+
def enable_network_isolation(self):
892+
"""Return True if this Estimator can use network isolation to run.
893+
894+
Returns:
895+
bool: Whether this Estimator can use network isolation or not.
896+
"""
897+
return self._enable_network_isolation
898+
878899
def _prepare_for_training(self, job_name=None):
879900
"""Set hyperparameters needed for training. This method will also validate ``source_dir``.
880901
@@ -907,6 +928,11 @@ def _prepare_for_training(self, job_name=None):
907928

908929
code_dir = 'file://' + self.source_dir
909930
script = self.entry_point
931+
elif self.enable_network_isolation() and self.entry_point:
932+
self.uploaded_code = self._stage_user_code_in_s3()
933+
code_dir = self.CONTAINER_CODE_CHANNEL_SOURCEDIR_PATH
934+
script = self.uploaded_code.script_name
935+
self.code_uri = self.uploaded_code.s3_prefix
910936
else:
911937
self.uploaded_code = self._stage_user_code_in_s3()
912938
code_dir = self.uploaded_code.s3_prefix
@@ -930,12 +956,12 @@ def _stage_user_code_in_s3(self):
930956

931957
if self.code_location is None and local_mode:
932958
code_bucket = self.sagemaker_session.default_bucket()
933-
code_s3_prefix = '{}/source'.format(self._current_job_name)
959+
code_s3_prefix = '{}/{}'.format(self._current_job_name, 'source')
934960
kms_key = None
935961

936962
elif self.code_location is None:
937963
code_bucket, _ = parse_s3_url(self.output_path)
938-
code_s3_prefix = '{}/source'.format(self._current_job_name)
964+
code_s3_prefix = '{}/{}'.format(self._current_job_name, 'source')
939965
kms_key = self.output_kms_key
940966
else:
941967
code_bucket, key_prefix = parse_s3_url(self.code_location)

src/sagemaker/job.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,21 @@ def _load_config(inputs, estimator, expand_role=True, validate_uri=True):
6060
stop_condition = _Job._prepare_stop_condition(estimator.train_max_run)
6161
vpc_config = estimator.get_vpc_config()
6262

63-
model_channel = _Job._prepare_model_channel(input_config, estimator.model_uri, estimator.model_channel_name,
64-
validate_uri)
63+
model_channel = _Job._prepare_channel(input_config, estimator.model_uri, estimator.model_channel_name,
64+
validate_uri, content_type='application/x-sagemaker-model',
65+
input_mode='File')
6566
if model_channel:
6667
input_config = [] if input_config is None else input_config
6768
input_config.append(model_channel)
6869

70+
if estimator.enable_network_isolation():
71+
code_channel = _Job._prepare_channel(input_config, estimator.code_uri, estimator.code_channel_name,
72+
validate_uri)
73+
74+
if code_channel:
75+
input_config = [] if input_config is None else input_config
76+
input_config.append(code_channel)
77+
6978
return {'input_config': input_config,
7079
'role': role,
7180
'output_config': output_config,
@@ -110,16 +119,16 @@ def _convert_input_to_channel(channel_name, channel_s3_input):
110119
return channel_config
111120

112121
@staticmethod
113-
def _format_string_uri_input(uri_input, validate_uri=True):
122+
def _format_string_uri_input(uri_input, validate_uri=True, content_type=None, input_mode=None):
114123
if isinstance(uri_input, str) and validate_uri and uri_input.startswith('s3://'):
115-
return s3_input(uri_input)
124+
return s3_input(uri_input, content_type=content_type, input_mode=input_mode)
116125
elif isinstance(uri_input, str) and validate_uri and uri_input.startswith('file://'):
117126
return file_input(uri_input)
118127
elif isinstance(uri_input, str) and validate_uri:
119-
raise ValueError('Training input data must be a valid S3 or FILE URI: must start with "s3://" or '
120-
'"file://"')
128+
raise ValueError('URI input {} must be a valid S3 or FILE URI: must start with "s3://" or '
129+
'"file://"'.format(uri_input))
121130
elif isinstance(uri_input, str):
122-
return s3_input(uri_input)
131+
return s3_input(uri_input, content_type=content_type, input_mode=input_mode)
123132
elif isinstance(uri_input, s3_input):
124133
return uri_input
125134
elif isinstance(uri_input, file_input):
@@ -128,21 +137,22 @@ def _format_string_uri_input(uri_input, validate_uri=True):
128137
raise ValueError('Cannot format input {}. Expecting one of str, s3_input, or file_input'.format(uri_input))
129138

130139
@staticmethod
131-
def _prepare_model_channel(input_config, model_uri=None, model_channel_name=None, validate_uri=True):
132-
if not model_uri:
140+
def _prepare_channel(input_config, channel_uri=None, channel_name=None, validate_uri=True, content_type=None,
141+
input_mode=None):
142+
if not channel_uri:
133143
return
134-
elif not model_channel_name:
135-
raise ValueError('Expected a pre-trained model channel name if a model URL is specified.')
144+
elif not channel_name:
145+
raise ValueError('Expected a channel name if a channel URI {} is specified'.format(channel_uri))
136146

137147
if input_config:
138-
for channel in input_config:
139-
if channel['ChannelName'] == model_channel_name:
140-
raise ValueError('Duplicate channels not allowed.')
148+
for existing_channel in input_config:
149+
if existing_channel['ChannelName'] == channel_name:
150+
raise ValueError('Duplicate channel {} not allowed.'.format(channel_name))
141151

142-
model_input = _Job._format_model_uri_input(model_uri, validate_uri)
143-
model_channel = _Job._convert_input_to_channel(model_channel_name, model_input)
152+
channel_input = _Job._format_string_uri_input(channel_uri, validate_uri, content_type, input_mode)
153+
channel = _Job._convert_input_to_channel(channel_name, channel_input)
144154

145-
return model_channel
155+
return channel
146156

147157
@staticmethod
148158
def _format_model_uri_input(model_uri, validate_uri=True):

tests/integ/test_sklearn_train.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,33 @@ def test_training_with_additional_hyperparameters(sagemaker_session, sklearn_ful
5555
return sklearn.latest_training_job.name
5656

5757

58+
@pytest.mark.skipif(PYTHON_VERSION != 'py3', reason="Scikit-learn image supports only python 3.")
59+
def test_training_with_network_isolation(sagemaker_session, sklearn_full_version):
60+
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
61+
script_path = os.path.join(DATA_DIR, 'sklearn_mnist', 'mnist.py')
62+
data_path = os.path.join(DATA_DIR, 'sklearn_mnist')
63+
64+
sklearn = SKLearn(entry_point=script_path,
65+
role='SageMakerRole',
66+
train_instance_type="ml.c4.xlarge",
67+
framework_version=sklearn_full_version,
68+
py_version=PYTHON_VERSION,
69+
sagemaker_session=sagemaker_session,
70+
hyperparameters={'epochs': 1},
71+
enable_network_isolation=True)
72+
73+
train_input = sklearn.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'),
74+
key_prefix='integ-test-data/sklearn_mnist/train')
75+
test_input = sklearn.sagemaker_session.upload_data(path=os.path.join(data_path, 'test'),
76+
key_prefix='integ-test-data/sklearn_mnist/test')
77+
job_name = unique_name_from_base('test-sklearn-hp')
78+
79+
sklearn.fit({'train': train_input, 'test': test_input}, job_name=job_name)
80+
assert sagemaker_session.sagemaker_client \
81+
.describe_training_job(TrainingJobName=job_name)['EnableNetworkIsolation']
82+
return sklearn.latest_training_job.name
83+
84+
5885
@pytest.mark.canary_quick
5986
@pytest.mark.regional_testing
6087
@pytest.mark.skipif(PYTHON_VERSION != 'py3', reason="Scikit-learn image supports only python 3.")

tests/unit/test_fw_utils.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -175,27 +175,26 @@ def test_unoptimized_gpu_family():
175175

176176

177177
def test_tar_and_upload_dir_s3(sagemaker_session):
178-
bucket = 'mybucker'
178+
bucket = 'mybucket'
179179
s3_key_prefix = 'something/source'
180180
script = 'mnist.py'
181181
directory = 's3://m'
182182
result = fw_utils.tar_and_upload_dir(sagemaker_session, bucket, s3_key_prefix, script, directory)
183+
183184
assert result == fw_utils.UploadedCode('s3://m', 'mnist.py')
184185

185186

186187
@patch('sagemaker.utils')
187188
def test_tar_and_upload_dir_s3_with_kms(utils, sagemaker_session):
189+
bucket = 'mybucket'
190+
s3_key_prefix = 'something/source'
191+
script = 'mnist.py'
192+
kms_key = 'kms-key'
193+
result = fw_utils.tar_and_upload_dir(sagemaker_session, bucket, s3_key_prefix, script, kms_key=kms_key)
188194

189-
result = fw_utils.tar_and_upload_dir(sagemaker_session,
190-
'mybucker',
191-
'something/source',
192-
'mnist.py',
193-
kms_key='kms-key')
194-
195-
assert result == fw_utils.UploadedCode('s3://mybucker/something/source/sourcedir.tar.gz',
196-
'mnist.py')
195+
assert result == fw_utils.UploadedCode('s3://{}/{}/sourcedir.tar.gz'.format(bucket, s3_key_prefix), script)
197196

198-
extra_args = {'ServerSideEncryption': 'aws:kms', 'SSEKMSKeyId': 'kms-key'}
197+
extra_args = {'ServerSideEncryption': 'aws:kms', 'SSEKMSKeyId': kms_key}
199198
obj = sagemaker_session.resource('s3').Object('', '')
200199
obj.upload_file.assert_called_with(utils.create_tar_file(), ExtraArgs=extra_args)
201200

0 commit comments

Comments
 (0)