Skip to content

Commit b71d8fc

Browse files
author
Dan
authored
add local integration tests (#10)
1 parent f7ca908 commit b71d8fc

29 files changed

+642
-1
lines changed

test/__init__.py

Whitespace-only changes.

test/conftest.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# A copy of the License is located at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# or in the "license" file accompanying this file. This file is distributed
10+
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
11+
# express or implied. See the License for the specific language governing
12+
# permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import logging
16+
import os
17+
18+
import boto3
19+
import pytest
20+
from sagemaker import LocalSession, Session
21+
from sagemaker.mxnet import MXNet
22+
23+
logger = logging.getLogger(__name__)
24+
logging.getLogger('boto').setLevel(logging.INFO)
25+
logging.getLogger('botocore').setLevel(logging.INFO)
26+
logging.getLogger('factory.py').setLevel(logging.INFO)
27+
logging.getLogger('auth.py').setLevel(logging.INFO)
28+
logging.getLogger('connectionpool.py').setLevel(logging.INFO)
29+
30+
SCRIPT_PATH = os.path.dirname(os.path.realpath(__file__))
31+
32+
33+
def pytest_addoption(parser):
34+
parser.addoption('--docker-base-name', default='preprod-mxnet-serving')
35+
parser.addoption('--region', default='us-west-2')
36+
parser.addoption('--framework-version', default=MXNet.LATEST_VERSION)
37+
parser.addoption('--py-version', default='3', choices=['2', '3'])
38+
parser.addoption('--processor', default='cpu', choices=['gpu', 'cpu'])
39+
parser.addoption('--aws-id', default=None)
40+
parser.addoption('--instance-type', default=None)
41+
parser.addoption('--accelerator-type', default=None)
42+
# If not specified, will default to {framework-version}-{processor}-py{py-version}
43+
parser.addoption('--tag', default=None)
44+
45+
46+
@pytest.fixture(scope='session')
47+
def docker_base_name(request):
48+
return request.config.getoption('--docker-base-name')
49+
50+
51+
@pytest.fixture(scope='session')
52+
def region(request):
53+
return request.config.getoption('--region')
54+
55+
56+
@pytest.fixture(scope='session')
57+
def framework_version(request):
58+
return request.config.getoption('--framework-version')
59+
60+
61+
@pytest.fixture(scope='session')
62+
def py_version(request):
63+
return int(request.config.getoption('--py-version'))
64+
65+
66+
@pytest.fixture(scope='session')
67+
def processor(request):
68+
return request.config.getoption('--processor')
69+
70+
71+
@pytest.fixture(scope='session')
72+
def aws_id(request):
73+
return request.config.getoption('--aws-id')
74+
75+
76+
@pytest.fixture(scope='session')
77+
def tag(request, framework_version, processor, py_version):
78+
provided_tag = request.config.getoption('--tag')
79+
default_tag = '{}-{}-py{}'.format(framework_version, processor, py_version)
80+
return provided_tag if provided_tag is not None else default_tag
81+
82+
83+
@pytest.fixture(scope='session')
84+
def instance_type(request, processor):
85+
provided_instance_type = request.config.getoption('--instance-type')
86+
default_instance_type = 'ml.c4.xlarge' if processor == 'cpu' else 'ml.p2.xlarge'
87+
return provided_instance_type if provided_instance_type is not None else default_instance_type
88+
89+
90+
@pytest.fixture(scope='session')
91+
def accelerator_type(request):
92+
return request.config.getoption('--accelerator-type')
93+
94+
95+
@pytest.fixture(scope='session')
96+
def docker_image(docker_base_name, tag):
97+
return '{}:{}'.format(docker_base_name, tag)
98+
99+
100+
@pytest.fixture(scope='session')
101+
def ecr_image(aws_id, docker_base_name, tag, region):
102+
return '{}.dkr.ecr.{}.amazonaws.com/{}:{}'.format(aws_id, region, docker_base_name, tag)
103+
104+
105+
@pytest.fixture(scope='session')
106+
def sagemaker_session(region):
107+
return Session(boto_session=boto3.Session(region_name=region))
108+
109+
110+
@pytest.fixture(scope='session')
111+
def sagemaker_local_session(region):
112+
return LocalSession(boto_session=boto3.Session(region_name=region))
113+
114+
115+
@pytest.fixture(scope='session')
116+
def local_instance_type(processor):
117+
return 'local' if processor == 'cpu' else 'local_gpu'

test/integration/__init__.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import os
16+
17+
RESOURCE_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'resources'))
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
from contextlib import contextmanager
16+
import fcntl
17+
import os
18+
import tarfile
19+
import time
20+
21+
from test.integration import RESOURCE_PATH
22+
23+
LOCK_PATH = os.path.join(RESOURCE_PATH, 'local_mode_lock')
24+
25+
26+
@contextmanager
27+
def lock():
28+
# Since Local Mode uses the same port for serving, we need a lock in order
29+
# to allow concurrent test execution.
30+
local_mode_lock_fd = open(LOCK_PATH, 'w')
31+
local_mode_lock = local_mode_lock_fd.fileno()
32+
33+
fcntl.lockf(local_mode_lock, fcntl.LOCK_EX)
34+
35+
try:
36+
yield
37+
finally:
38+
time.sleep(5)
39+
fcntl.lockf(local_mode_lock, fcntl.LOCK_UN)
40+
41+
42+
def assert_output_files_exist(output_path, directory, files):
43+
with tarfile.open(os.path.join(output_path, '{}.tar.gz'.format(directory))) as tar:
44+
for f in files:
45+
tar.getmember(f)
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# A copy of the License is located at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# or in the "license" file accompanying this file. This file is distributed
10+
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
11+
# express or implied. See the License for the specific language governing
12+
# permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import os
16+
17+
import pytest
18+
import requests
19+
from sagemaker.mxnet.model import MXNetModel
20+
21+
import local_mode_utils
22+
from test.integration import RESOURCE_PATH
23+
24+
DEFAULT_HANDLER_PATH = os.path.join(RESOURCE_PATH, 'default_handlers')
25+
MODEL_PATH = os.path.join(DEFAULT_HANDLER_PATH, 'model')
26+
SCRIPT_PATH = os.path.join(MODEL_PATH, 'code', 'empty_module.py')
27+
28+
29+
@pytest.fixture(scope='module')
30+
def predictor(docker_image, sagemaker_local_session, local_instance_type):
31+
model = MXNetModel('file://{}'.format(MODEL_PATH),
32+
'SageMakerRole',
33+
SCRIPT_PATH,
34+
image=docker_image,
35+
sagemaker_session=sagemaker_local_session)
36+
37+
with local_mode_utils.lock():
38+
try:
39+
predictor = model.deploy(1, local_instance_type)
40+
yield predictor
41+
finally:
42+
sagemaker_local_session.delete_endpoint(model.endpoint_name)
43+
44+
45+
def test_default_model_fn(predictor):
46+
input = [[1, 2]]
47+
output = predictor.predict(input)
48+
assert [[4.9999918937683105]] == output
49+
50+
51+
def test_default_model_fn_via_requests(predictor):
52+
r = requests.post('http://localhost:8080/invocations', json=[[1, 2]])
53+
assert [[4.9999918937683105]] == r.json()
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# A copy of the License is located at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# or in the "license" file accompanying this file. This file is distributed
10+
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
11+
# express or implied. See the License for the specific language governing
12+
# permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import json
16+
import os
17+
18+
from sagemaker.mxnet.model import MXNetModel
19+
20+
import local_mode_utils
21+
from test.integration import RESOURCE_PATH
22+
23+
GLUON_PATH = os.path.join(RESOURCE_PATH, 'gluon_hosting')
24+
MODEL_PATH = os.path.join(GLUON_PATH, 'model')
25+
SCRIPT_PATH = os.path.join(MODEL_PATH, 'code', 'gluon.py')
26+
27+
28+
# The image should support serving Gluon-created models.
29+
def test_gluon_hosting(docker_image, sagemaker_local_session, local_instance_type):
30+
model = MXNetModel('file://{}'.format(MODEL_PATH),
31+
'SageMakerRole',
32+
SCRIPT_PATH,
33+
image=docker_image,
34+
sagemaker_session=sagemaker_local_session)
35+
36+
with open(os.path.join(RESOURCE_PATH, 'mnist_images', '04.json'), 'r') as f:
37+
input = json.load(f)
38+
39+
with local_mode_utils.lock():
40+
try:
41+
predictor = model.deploy(1, local_instance_type)
42+
output = predictor.predict(input)
43+
assert [4.0] == output
44+
finally:
45+
sagemaker_local_session.delete_endpoint(model.endpoint_name)
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# A copy of the License is located at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# or in the "license" file accompanying this file. This file is distributed
10+
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
11+
# express or implied. See the License for the specific language governing
12+
# permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import json
16+
import os
17+
18+
from sagemaker.mxnet.model import MXNetModel
19+
20+
import local_mode_utils
21+
from test.integration import RESOURCE_PATH
22+
23+
HOSTING_RESOURCE_PATH = os.path.join(RESOURCE_PATH, 'dummy_hosting')
24+
MODEL_PATH = os.path.join(HOSTING_RESOURCE_PATH, 'code')
25+
SCRIPT_PATH = os.path.join(HOSTING_RESOURCE_PATH, 'code', 'dummy_hosting_module.py')
26+
27+
28+
# The image should use the model_fn and transform_fn defined
29+
# in the user-provided script when serving.
30+
def test_hosting(docker_image, sagemaker_local_session, local_instance_type):
31+
model = MXNetModel('file://{}'.format(MODEL_PATH),
32+
'SageMakerRole',
33+
SCRIPT_PATH,
34+
image=docker_image,
35+
sagemaker_session=sagemaker_local_session)
36+
37+
input = json.dumps({'some': 'json'})
38+
39+
with local_mode_utils.lock():
40+
try:
41+
predictor = model.deploy(1, local_instance_type)
42+
output = predictor.predict(input)
43+
assert input == output
44+
finally:
45+
sagemaker_local_session.delete_endpoint(model.endpoint_name)

test/integration/local/test_onnx.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# A copy of the License is located at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# or in the "license" file accompanying this file. This file is distributed
10+
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
11+
# express or implied. See the License for the specific language governing
12+
# permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import os
16+
17+
import numpy
18+
from sagemaker.mxnet import MXNetModel
19+
20+
import local_mode_utils
21+
from test.integration import RESOURCE_PATH
22+
23+
ONNX_PATH = os.path.join(RESOURCE_PATH, 'onnx')
24+
MODEL_PATH = os.path.join(ONNX_PATH, 'onnx_model')
25+
SCRIPT_PATH = os.path.join(MODEL_PATH, 'code', 'onnx_import.py')
26+
27+
28+
def test_onnx_import(docker_image, sagemaker_local_session, local_instance_type):
29+
model = MXNetModel('file://{}'.format(MODEL_PATH),
30+
'SageMakerRole',
31+
SCRIPT_PATH,
32+
image=docker_image,
33+
sagemaker_session=sagemaker_local_session)
34+
35+
input = numpy.zeros(shape=(1, 1, 28, 28))
36+
37+
with local_mode_utils.lock():
38+
try:
39+
predictor = model.deploy(1, local_instance_type)
40+
output = predictor.predict(input)
41+
finally:
42+
sagemaker_local_session.delete_endpoint(model.endpoint_name)
43+
44+
# Check that there is a probability for each possible class in the prediction
45+
assert len(output[0]) == 10
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# A copy of the License is located at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# or in the "license" file accompanying this file. This file is distributed
10+
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
11+
# express or implied. See the License for the specific language governing
12+
# permissions and limitations under the License.
13+
14+
# nothing here... we are testing default model loading and handlers
Binary file not shown.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
[{"shape": [1, 2], "name": "data"}]

0 commit comments

Comments
 (0)