Skip to content

Commit 5d6aeb4

Browse files
authored
fix: honor source_dir from S3 (#811)
1 parent ccad8c0 commit 5d6aeb4

File tree

7 files changed

+92
-8
lines changed

7 files changed

+92
-8
lines changed

src/sagemaker/utils.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ def repack_model(inference_script, source_directory, model_uri, sagemaker_sessio
323323
tmp_model_dir = os.path.join(tmp, 'model')
324324
os.mkdir(tmp_model_dir)
325325

326-
model_from_s3 = model_uri.startswith('s3://')
326+
model_from_s3 = model_uri.lower().startswith('s3://')
327327
if model_from_s3:
328328
local_model_path = os.path.join(tmp, 'tar_file')
329329
download_file_from_url(model_uri, local_model_path, sagemaker_session)
@@ -340,7 +340,14 @@ def repack_model(inference_script, source_directory, model_uri, sagemaker_sessio
340340
if os.path.exists(code_dir):
341341
shutil.rmtree(code_dir, ignore_errors=True)
342342

343-
if source_directory:
343+
# Populate code_dir from source_directory, which may be an S3 URI pointing to
# a gzipped tarball, a local directory, or unset.
if source_directory and source_directory.lower().startswith('s3://'):
    # Remote source: download the archive to a scratch file, then unpack it.
    local_code_path = os.path.join(tmp, 'local_code.tar.gz')
    download_file_from_url(source_directory, local_code_path, sagemaker_session)

    # Open the archive that was just downloaded (local_code_path).
    # local_model_path is the model artifact; extracting it here would fill
    # code_dir with model files instead of the user's source.
    with tarfile.open(name=local_code_path, mode='r:gz') as t:
        t.extractall(path=code_dir)

elif source_directory:
    # Local source: copy the directory tree as-is.
    shutil.copytree(source_directory, code_dir)
else:
    # No source directory supplied: start from an empty code directory.
    os.mkdir(code_dir)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
asset-file-contents
Binary file not shown.
Binary file not shown.
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
14+
"""Exports a toy TensorFlow model.
15+
Exports a TensorFlow model to /opt/ml/model/
16+
This graph calculates,
17+
y = a*x + b
18+
where a and b are variables with a=0.5 and b=3.
19+
"""
20+
import json
21+
import shutil
22+
23+
24+
def save_model():
    """Copy the directory ``/opt/ml/code/123`` to ``/opt/ml/model/123``.

    Raises:
        FileExistsError: if ``/opt/ml/model/123`` already exists
            (``shutil.copytree`` refuses to overwrite by default).
    """
    shutil.copytree('/opt/ml/code/123', '/opt/ml/model/123')
26+
27+
28+
def input_handler(data, context):
    """Pre-process a request: add 1 to every value under ``'instances'``.

    Args:
        data: readable binary stream whose content is UTF-8 encoded JSON
            containing an ``'instances'`` list of numbers.
        context: request context; not used by this handler.

    Returns:
        str: JSON document ``{"instances": [...]}`` holding the incremented
        values.
    """
    payload = json.loads(data.read().decode('utf-8'))
    new_values = [x + 1 for x in payload['instances']]
    return json.dumps({'instances': new_values})
33+
34+
35+
def output_handler(data, context):
    """Post-process a response: pass the prediction body through unchanged.

    Args:
        data: response object exposing the raw body as ``.content``.
        context: request context exposing the client's ``.accept_header``.

    Returns:
        tuple: ``(prediction, response_content_type)`` where ``prediction`` is
        ``data.content`` and ``response_content_type`` is
        ``context.accept_header``.
    """
    response_content_type = context.accept_header
    prediction = data.content
    return prediction, response_content_type
39+
40+
41+
# When executed as a script, copy the model directory into place.
if __name__ == "__main__":
    save_model()
43+

tests/integ/test_tf_script_mode.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,15 @@
2424
from sagemaker.utils import unique_name_from_base
2525

2626
import tests.integ
27+
from tests.integ import timeout
2728

2829
ROLE = 'SageMakerRole'
2930

30-
RESOURCE_PATH = os.path.join(os.path.dirname(__file__), '..', 'data', 'tensorflow_mnist')
31-
SCRIPT = os.path.join(RESOURCE_PATH, 'mnist.py')
31+
RESOURCE_PATH = os.path.join(os.path.dirname(__file__), '..', 'data')
32+
MNIST_RESOURCE_PATH = os.path.join(RESOURCE_PATH, 'tensorflow_mnist')
33+
TFS_RESOURCE_PATH = os.path.join(RESOURCE_PATH, 'tfs', 'tfs-test-entrypoint-with-handler')
34+
35+
SCRIPT = os.path.join(MNIST_RESOURCE_PATH, 'mnist.py')
3236
PARAMETER_SERVER_DISTRIBUTION = {'parameter_server': {'enabled': True}}
3337
MPI_DISTRIBUTION = {'mpi': {'enabled': True}}
3438
TAGS = [{'Key': 'some-key', 'Value': 'some-value'}]
@@ -57,7 +61,7 @@ def test_mnist(sagemaker_session, instance_type):
5761
metric_definitions=[
5862
{'Name': 'train:global_steps', 'Regex': r'global_step\/sec:\s(.*)'}])
5963
inputs = estimator.sagemaker_session.upload_data(
60-
path=os.path.join(RESOURCE_PATH, 'data'),
64+
path=os.path.join(MNIST_RESOURCE_PATH, 'data'),
6165
key_prefix='scriptmode/mnist')
6266

6367
with tests.integ.timeout.timeout(minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
@@ -88,7 +92,7 @@ def test_server_side_encryption(sagemaker_session):
8892
output_kms_key=kms_key)
8993

9094
inputs = estimator.sagemaker_session.upload_data(
91-
path=os.path.join(RESOURCE_PATH, 'data'),
95+
path=os.path.join(MNIST_RESOURCE_PATH, 'data'),
9296
key_prefix='scriptmode/mnist')
9397

9498
with tests.integ.timeout.timeout(minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
@@ -110,7 +114,7 @@ def test_mnist_distributed(sagemaker_session, instance_type):
110114
framework_version=TensorFlow.LATEST_VERSION,
111115
distributions=PARAMETER_SERVER_DISTRIBUTION)
112116
inputs = estimator.sagemaker_session.upload_data(
113-
path=os.path.join(RESOURCE_PATH, 'data'),
117+
path=os.path.join(MNIST_RESOURCE_PATH, 'data'),
114118
key_prefix='scriptmode/distributed_mnist')
115119

116120
with tests.integ.timeout.timeout(minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
@@ -129,7 +133,7 @@ def test_mnist_async(sagemaker_session):
129133
framework_version=TensorFlow.LATEST_VERSION,
130134
tags=TAGS)
131135
inputs = estimator.sagemaker_session.upload_data(
132-
path=os.path.join(RESOURCE_PATH, 'data'),
136+
path=os.path.join(MNIST_RESOURCE_PATH, 'data'),
133137
key_prefix='scriptmode/mnist')
134138
estimator.fit(inputs=inputs, wait=False, job_name=unique_name_from_base('test-tf-sm-async'))
135139
training_job_name = estimator.latest_training_job.name
@@ -150,6 +154,35 @@ def test_mnist_async(sagemaker_session):
150154
estimator.latest_training_job.name, TAGS)
151155

152156

157+
@pytest.mark.skipif(tests.integ.PYTHON_VERSION != 'py3',
                    reason="Script Mode tests are only configured to run with Python 3")
def test_deploy_with_input_handlers(sagemaker_session, instance_type):
    """Train with ``inference.py`` as entry point, deploy, and check predictions.

    The estimator is built from ``TFS_RESOURCE_PATH``, fitted, and deployed to
    an endpoint named after the training job. The endpoint is deleted when the
    ``timeout_and_delete_endpoint_by_name`` context exits.
    """
    estimator = TensorFlow(entry_point='inference.py',
                           source_dir=TFS_RESOURCE_PATH,
                           role=ROLE,
                           train_instance_count=1,
                           train_instance_type=instance_type,
                           sagemaker_session=sagemaker_session,
                           py_version='py3',
                           framework_version=TensorFlow.LATEST_VERSION,
                           tags=TAGS)

    # Bound the blocking training job the same way the other blocking tests
    # in this module do, so a hung job fails instead of stalling the run.
    with tests.integ.timeout.timeout(minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
        estimator.fit(job_name=unique_name_from_base('test-tf-tfs-deploy'))

    endpoint_name = estimator.latest_training_job.name

    with timeout.timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
        predictor = estimator.deploy(initial_instance_count=1, instance_type=instance_type,
                                     endpoint_name=endpoint_name)

        input_data = {'instances': [1.0, 2.0, 5.0]}
        # NOTE(review): presumably these values reflect the entry point's
        # input_handler shifting each instance before the model sees it;
        # confirm against the model artifact under TFS_RESOURCE_PATH.
        expected_result = {'predictions': [4.0, 4.5, 6.0]}

        result = predictor.predict(input_data)
        assert expected_result == result
184+
185+
153186
def _assert_s3_files_exist(s3_url, files):
154187
parsed_url = urlparse(s3_url)
155188
s3 = boto3.client('s3')

0 commit comments

Comments
 (0)