Skip to content

Commit c758362

Browse files
authored
Allow Local Serving of Models in S3 (#217)
If a model is located in S3, download and extract it and then start serving as usual. By overriding the SageMaker Session on any Model Object, local serving is possible on pre-trained models.
1 parent 2e2e397 commit c758362

File tree

7 files changed

+202
-24
lines changed

7 files changed

+202
-24
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@ CHANGELOG
55
1.4.2
66
=====
77

8-
* bug-fix: Unit Tests: Improve unit test runtime
98
* bug-fix: Estimators: Fix attach for LDA
109
* bug-fix: Estimators: allow code_location to have no key prefix
1110
* bug-fix: Local Mode: Fix s3 training data download when there is a trailing slash
11+
* feature: Allow Local Serving of Models in S3
1212

1313

1414
1.4.1

src/sagemaker/local/image.py

Lines changed: 47 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import string
2525
import subprocess
2626
import sys
27+
import tarfile
2728
import tempfile
2829
from fcntl import fcntl, F_GETFL, F_SETFL
2930
from six.moves.urllib.parse import urlparse
@@ -137,7 +138,7 @@ def serve(self, primary_container):
137138
Args:
138139
primary_container (dict): dictionary containing the container runtime settings
139140
for serving. Expected keys:
140-
- 'ModelDataUrl' pointing to a local file
141+
- 'ModelDataUrl' pointing to a file or s3:// location.
141142
- 'Environment' a dictionary of environment variables to be passed to the hosting container.
142143
143144
"""
@@ -147,22 +148,17 @@ def serve(self, primary_container):
147148
logger.info('creating hosting dir in {}'.format(self.container_root))
148149

149150
model_dir = primary_container['ModelDataUrl']
150-
if not model_dir.lower().startswith("s3://"):
151-
for h in self.hosts:
152-
host_dir = os.path.join(self.container_root, h)
153-
os.makedirs(host_dir)
154-
shutil.copytree(model_dir, os.path.join(self.container_root, h, 'model'))
155-
151+
volumes = self._prepare_serving_volumes(model_dir)
156152
env_vars = ['{}={}'.format(k, v) for k, v in primary_container['Environment'].items()]
157153

158-
_ecr_login_if_needed(self.sagemaker_session.boto_session, self.image)
159-
160154
# If the user script was passed as a file:// mount it to the container.
161-
script_dir = primary_container['Environment'][sagemaker.estimator.DIR_PARAM_NAME.upper()]
162-
parsed_uri = urlparse(script_dir)
163-
volumes = []
164-
if parsed_uri.scheme == 'file':
165-
volumes.append(_Volume(parsed_uri.path, '/opt/ml/code'))
155+
if sagemaker.estimator.DIR_PARAM_NAME.upper() in primary_container['Environment']:
156+
script_dir = primary_container['Environment'][sagemaker.estimator.DIR_PARAM_NAME.upper()]
157+
parsed_uri = urlparse(script_dir)
158+
if parsed_uri.scheme == 'file':
159+
volumes.append(_Volume(parsed_uri.path, '/opt/ml/code'))
160+
161+
_ecr_login_if_needed(self.sagemaker_session.boto_session, self.image)
166162

167163
self._generate_compose_file('serve',
168164
additional_env_vars=env_vars,
@@ -278,9 +274,20 @@ def _download_folder(self, bucket_name, prefix, target):
278274
pass
279275
obj.download_file(file_path)
280276

277+
def _download_file(self, bucket_name, path, target):
278+
path = path.lstrip('/')
279+
boto_session = self.sagemaker_session.boto_session
280+
281+
s3 = boto_session.resource('s3')
282+
bucket = s3.Bucket(bucket_name)
283+
bucket.download_file(path, target)
284+
281285
def _prepare_training_volumes(self, data_dir, input_data_config, hyperparameters):
282286
shared_dir = os.path.join(self.container_root, 'shared')
287+
model_dir = os.path.join(self.container_root, 'model')
283288
volumes = []
289+
290+
volumes.append(_Volume(model_dir, '/opt/ml/model'))
284291
# Set up the channels for the containers. For local data we will
285292
# mount the local directory to the container. For S3 Data we will download the S3 data
286293
# first.
@@ -321,6 +328,32 @@ def _prepare_training_volumes(self, data_dir, input_data_config, hyperparameters
321328

322329
return volumes
323330

331+
def _prepare_serving_volumes(self, model_location):
332+
volumes = []
333+
host = self.hosts[0]
334+
# Make the model available to the container. If this is a local file just mount it to
335+
# the container as a volume. If it is an S3 location download it and extract the tar file.
336+
host_dir = os.path.join(self.container_root, host)
337+
os.makedirs(host_dir)
338+
339+
if model_location.startswith('s3'):
340+
container_model_dir = os.path.join(self.container_root, host, 'model')
341+
os.makedirs(container_model_dir)
342+
343+
parsed_uri = urlparse(model_location)
344+
filename = os.path.basename(parsed_uri.path)
345+
tar_location = os.path.join(container_model_dir, filename)
346+
self._download_file(parsed_uri.netloc, parsed_uri.path, tar_location)
347+
348+
if tarfile.is_tarfile(tar_location):
349+
with tarfile.open(tar_location) as tar:
350+
tar.extractall(path=container_model_dir)
351+
volumes.append(_Volume(container_model_dir, '/opt/ml/model'))
352+
else:
353+
volumes.append(_Volume(model_location, '/opt/ml/model'))
354+
355+
return volumes
356+
324357
def _generate_compose_file(self, command, additional_volumes=None, additional_env_vars=None):
325358
"""Writes a config file describing a training/hosting environment.
326359
@@ -452,10 +485,6 @@ def _build_optml_volumes(self, host, subdirs):
452485
"""
453486
volumes = []
454487

455-
# Ensure that model is in the subdirs
456-
if 'model' not in subdirs:
457-
subdirs.add('model')
458-
459488
for subdir in subdirs:
460489
host_dir = os.path.join(self.container_root, host, subdir)
461490
container_dir = '/opt/ml/{}'.format(subdir)

src/sagemaker/model.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,14 @@
1616

1717
import sagemaker
1818

19+
from sagemaker.local import LocalSession
1920
from sagemaker.fw_utils import tar_and_upload_dir, parse_s3_url
2021
from sagemaker.session import Session
2122
from sagemaker.utils import name_from_image, get_config_value
2223

2324

2425
class Model(object):
25-
"""An SageMaker ``Model`` that can be deployed to an ``Endpoint``."""
26+
"""A SageMaker ``Model`` that can be deployed to an ``Endpoint``."""
2627

2728
def __init__(self, model_data, image, role, predictor_cls=None, env=None, name=None, sagemaker_session=None):
2829
"""Initialize an SageMaker ``Model``.
@@ -48,7 +49,7 @@ def __init__(self, model_data, image, role, predictor_cls=None, env=None, name=N
4849
self.predictor_cls = predictor_cls
4950
self.env = env or {}
5051
self.name = name
51-
self.sagemaker_session = sagemaker_session or Session()
52+
self.sagemaker_session = sagemaker_session
5253
self._model_name = None
5354

5455
def prepare_container_def(self, instance_type):
@@ -86,6 +87,12 @@ def deploy(self, initial_instance_count, instance_type, endpoint_name=None):
8687
callable[string, sagemaker.session.Session] or None: Invocation of ``self.predictor_cls`` on
8788
the created endpoint name, if ``self.predictor_cls`` is not None. Otherwise, return None.
8889
"""
90+
if not self.sagemaker_session:
91+
if instance_type in ('local', 'local_gpu'):
92+
self.sagemaker_session = LocalSession()
93+
else:
94+
self.sagemaker_session = Session()
95+
8996
container_def = self.prepare_container_def(instance_type)
9097
model_name = self.name or name_from_image(container_def['Image'])
9198
self.sagemaker_session.create_model(model_name, self.role, container_def)

src/sagemaker/predictor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,11 @@ def predict(self, data):
9393
response_body.close()
9494
return data
9595

96+
def delete_endpoint(self):
97+
"""Delete the Amazon SageMaker endpoint backing this predictor.
98+
"""
99+
self.sagemaker_session.delete_endpoint(self.endpoint)
100+
96101

97102
class _CsvSerializer(object):
98103
def __init__(self):

tests/integ/test_local_mode.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@
1818

1919
import boto3
2020
import numpy
21+
import pytest
2122

2223
from sagemaker.local import LocalSession, LocalSagemakerRuntimeClient, LocalSagemakerClient
23-
from sagemaker.mxnet import MXNet
24+
from sagemaker.mxnet import MXNet, MXNetModel
2425
from sagemaker.tensorflow import TensorFlow
26+
from sagemaker.fw_utils import tar_and_upload_dir
2527
from tests.integ import DATA_DIR
2628
from tests.integ.timeout import timeout
2729

@@ -54,6 +56,25 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
5456
self.local_mode = True
5557

5658

59+
@pytest.fixture(scope='module')
60+
def mxnet_model(sagemaker_local_session):
61+
script_path = os.path.join(DATA_DIR, 'mxnet_mnist', 'mnist.py')
62+
data_path = os.path.join(DATA_DIR, 'mxnet_mnist')
63+
64+
mx = MXNet(entry_point=script_path, role='SageMakerRole',
65+
train_instance_count=1, train_instance_type='local',
66+
sagemaker_session=sagemaker_local_session)
67+
68+
train_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'),
69+
key_prefix='integ-test-data/mxnet_mnist/train')
70+
test_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'test'),
71+
key_prefix='integ-test-data/mxnet_mnist/test')
72+
73+
mx.fit({'train': train_input, 'test': test_input})
74+
model = mx.create_model(1)
75+
return model
76+
77+
5778
def test_tf_local_mode(tf_full_version, sagemaker_local_session):
5879
local_mode_lock_fd = open(LOCK_PATH, 'w')
5980
local_mode_lock = local_mode_lock_fd.fileno()
@@ -230,6 +251,57 @@ def test_tf_local_data_local_script():
230251
fcntl.lockf(local_mode_lock, fcntl.LOCK_UN)
231252

232253

254+
def test_local_mode_serving_from_s3_model(sagemaker_local_session, mxnet_model):
255+
local_mode_lock_fd = open(LOCK_PATH, 'w')
256+
local_mode_lock = local_mode_lock_fd.fileno()
257+
258+
model_data = mxnet_model.model_data
259+
boto_session = sagemaker_local_session.boto_session
260+
default_bucket = sagemaker_local_session.default_bucket()
261+
uploaded_data = tar_and_upload_dir(boto_session, default_bucket,
262+
'test_mxnet_local_mode', '', model_data)
263+
264+
s3_model = MXNetModel(model_data=uploaded_data.s3_prefix, role='SageMakerRole',
265+
entry_point=mxnet_model.entry_point, image=mxnet_model.image,
266+
sagemaker_session=sagemaker_local_session)
267+
268+
predictor = None
269+
try:
270+
# Since Local Mode uses the same port for serving, we need a lock in order
271+
# to allow concurrent test execution. The serving test is really fast so it still
272+
# makes sense to allow this behavior.
273+
fcntl.lockf(local_mode_lock, fcntl.LOCK_EX)
274+
predictor = s3_model.deploy(initial_instance_count=1, instance_type='local')
275+
data = numpy.zeros(shape=(1, 1, 28, 28))
276+
predictor.predict(data)
277+
finally:
278+
if predictor:
279+
predictor.delete_endpoint()
280+
time.sleep(5)
281+
fcntl.lockf(local_mode_lock, fcntl.LOCK_UN)
282+
283+
284+
def test_local_mode_serving_from_local_model(sagemaker_local_session, mxnet_model):
285+
local_mode_lock_fd = open(LOCK_PATH, 'w')
286+
local_mode_lock = local_mode_lock_fd.fileno()
287+
predictor = None
288+
289+
try:
290+
# Since Local Mode uses the same port for serving, we need a lock in order
291+
# to allow concurrent test execution. The serving test is really fast so it still
292+
# makes sense to allow this behavior.
293+
fcntl.lockf(local_mode_lock, fcntl.LOCK_EX)
294+
mxnet_model.sagemaker_session = sagemaker_local_session
295+
predictor = mxnet_model.deploy(initial_instance_count=1, instance_type='local')
296+
data = numpy.zeros(shape=(1, 1, 28, 28))
297+
predictor.predict(data)
298+
finally:
299+
if predictor:
300+
predictor.delete_endpoint()
301+
time.sleep(5)
302+
fcntl.lockf(local_mode_lock, fcntl.LOCK_UN)
303+
304+
233305
def test_mxnet_local_mode(sagemaker_local_session):
234306
local_mode_lock_fd = open(LOCK_PATH, 'w')
235307
local_mode_lock = local_mode_lock_fd.fileno()

tests/unit/test_image.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
import pytest
2121
import yaml
22-
from mock import call, patch, Mock
22+
from mock import call, patch, Mock, MagicMock
2323

2424
import sagemaker
2525
from sagemaker.local.image import _SageMakerContainer
@@ -338,6 +338,42 @@ def test_serve_local_code(up, copy, copytree, tmpdir, sagemaker_session):
338338
assert '%s:/opt/ml/code' % '/tmp/code' in volumes
339339

340340

341+
@patch('sagemaker.local.image._SageMakerContainer._download_file')
342+
@patch('tarfile.is_tarfile')
343+
@patch('tarfile.open', MagicMock())
344+
@patch('os.makedirs', Mock())
345+
def test_prepare_serving_volumes_with_s3_model(is_tarfile, _download_file, sagemaker_session):
346+
347+
sagemaker_container = _SageMakerContainer('local', 1, 'some-image', sagemaker_session=sagemaker_session)
348+
sagemaker_container.container_root = '/tmp/container_root'
349+
container_model_dir = os.path.join('/tmp/container_root/', sagemaker_container.hosts[0], 'model')
350+
351+
is_tarfile.return_value = True
352+
353+
volumes = sagemaker_container._prepare_serving_volumes('s3://bucket/my_model.tar.gz')
354+
355+
tar_location = os.path.join(container_model_dir, 'my_model.tar.gz')
356+
_download_file.assert_called_with('bucket', '/my_model.tar.gz', tar_location)
357+
is_tarfile.assert_called_with(tar_location)
358+
359+
assert len(volumes) == 1
360+
assert volumes[0].container_dir == '/opt/ml/model'
361+
assert volumes[0].host_dir == container_model_dir
362+
363+
364+
@patch('os.makedirs', Mock())
365+
def test_prepare_serving_volumes_with_local_model(sagemaker_session):
366+
367+
sagemaker_container = _SageMakerContainer('local', 1, 'some-image', sagemaker_session=sagemaker_session)
368+
sagemaker_container.container_root = '/tmp/container_root'
369+
370+
volumes = sagemaker_container._prepare_serving_volumes('/path/to/my_model')
371+
372+
assert len(volumes) == 1
373+
assert volumes[0].container_dir == '/opt/ml/model'
374+
assert volumes[0].host_dir == '/path/to/my_model'
375+
376+
341377
@patch('os.makedirs')
342378
def test_download_folder(makedirs):
343379
boto_mock = Mock(name='boto_session')
@@ -377,6 +413,19 @@ def test_download_folder(makedirs):
377413
obj_mock.download_file.assert_has_calls(calls)
378414

379415

416+
def test_download_file():
417+
boto_mock = Mock(name='boto_session')
418+
boto_mock.client('sts').get_caller_identity.return_value = {'Account': '123'}
419+
bucket_mock = Mock()
420+
boto_mock.resource('s3').Bucket.return_value = bucket_mock
421+
session = sagemaker.Session(boto_session=boto_mock, sagemaker_client=Mock())
422+
423+
sagemaker_container = _SageMakerContainer('local', 2, 'my-image', sagemaker_session=session)
424+
sagemaker_container._download_file(BUCKET_NAME, '/prefix/path/file.tar.gz', '/tmp/file.tar.gz')
425+
426+
bucket_mock.download_file.assert_called_with('prefix/path/file.tar.gz', '/tmp/file.tar.gz')
427+
428+
380429
def test_ecr_login_non_ecr():
381430
session_mock = Mock()
382431
sagemaker.local.image._ecr_login_if_needed(session_mock, 'ubuntu')

tests/unit/test_model.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from sagemaker.predictor import RealTimePredictor
1717
import os
1818
import pytest
19-
from mock import Mock, patch
19+
from mock import MagicMock, Mock, patch
2020

2121
MODEL_DATA = "s3://bucket/model.tar.gz"
2222
MODEL_IMAGE = "mi"
@@ -115,3 +115,19 @@ def test_deploy_endpoint_name(tfo, time, sagemaker_session):
115115
'InstanceType': INSTANCE_TYPE,
116116
'InitialInstanceCount': 55,
117117
'VariantName': 'AllTraffic'}])
118+
119+
120+
@patch('sagemaker.model.Session')
121+
@patch('sagemaker.model.LocalSession')
122+
@patch('tarfile.open', MagicMock())
123+
def test_deploy_creates_correct_session(local_session, session):
124+
125+
# We expect a LocalSession when deploying to instance_type = 'local'
126+
model = DummyFrameworkModel(sagemaker_session=None)
127+
model.deploy(endpoint_name='blah', instance_type='local', initial_instance_count=1)
128+
assert model.sagemaker_session == local_session.return_value
129+
130+
# We expect a real Session when deploying to instance_type != local/local_gpu
131+
model = DummyFrameworkModel(sagemaker_session=None)
132+
model.deploy(endpoint_name='remote_endpoint', instance_type='ml.m4.4xlarge', initial_instance_count=2)
133+
assert model.sagemaker_session == session.return_value

0 commit comments

Comments
 (0)