|
| 1 | +# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"). |
| 4 | +# You may not use this file except in compliance with the License. |
| 5 | +# A copy of the License is located at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# or in the "license" file accompanying this file. This file is distributed |
| 10 | +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either |
| 11 | +# express or implied. See the License for the specific language governing |
| 12 | +# permissions and limitations under the License. |
| 13 | +from __future__ import absolute_import |
| 14 | + |
| 15 | +import json |
| 16 | +import os |
| 17 | +from urllib.parse import urlparse |
| 18 | + |
| 19 | +from sagemaker import utils |
| 20 | +from sagemaker.mxnet.model import MXNetModel |
| 21 | + |
| 22 | +from test.integration import RESOURCE_PATH |
| 23 | +import timeout |
| 24 | + |
| 25 | +SCRIPT_PATH = os.path.join(RESOURCE_PATH, 'default_handlers', 'model', 'code', 'empty_module.py') |
| 26 | +MNIST_PATH = os.path.join(RESOURCE_PATH, 'mnist') |
| 27 | +MODEL_PATH = os.path.join(MNIST_PATH, 'model', 'model.tar.gz') |
| 28 | + |
| 29 | +DATA_FILE = '07.csv' |
| 30 | +DATA_PATH = os.path.join(MNIST_PATH, 'images', DATA_FILE) |
| 31 | + |
| 32 | + |
| 33 | +def test_batch_transform(sagemaker_session, ecr_image, instance_type, framework_version): |
| 34 | + s3_prefix = 'mxnet-serving/mnist' |
| 35 | + model_data = sagemaker_session.upload_data(path=MODEL_PATH, key_prefix=s3_prefix) |
| 36 | + model = MXNetModel(model_data, |
| 37 | + 'SageMakerRole', |
| 38 | + SCRIPT_PATH, |
| 39 | + image=ecr_image, |
| 40 | + framework_version=framework_version, |
| 41 | + sagemaker_session=sagemaker_session) |
| 42 | + |
| 43 | + transformer = model.transformer(1, instance_type) |
| 44 | + with timeout.timeout_and_delete_model_with_transformer(transformer, sagemaker_session, minutes=20): |
| 45 | + input_data = sagemaker_session.upload_data(path=DATA_PATH, key_prefix=s3_prefix) |
| 46 | + |
| 47 | + job_name = utils.unique_name_from_base('test-mxnet-serving-batch') |
| 48 | + transformer.transform(input_data, content_type='text/csv', job_name=job_name) |
| 49 | + transformer.wait() |
| 50 | + |
| 51 | + prediction = _transform_result(sagemaker_session.boto_session, transformer.output_path) |
| 52 | + assert prediction == 7 |
| 53 | + |
| 54 | + |
| 55 | +def _transform_result(boto_session, output_path): |
| 56 | + s3 = boto_session.resource('s3', region_name=boto_session.region_name) |
| 57 | + |
| 58 | + parsed_url = urlparse(output_path) |
| 59 | + bucket_name = parsed_url.netloc |
| 60 | + prefix = parsed_url.path[1:] |
| 61 | + |
| 62 | + output_obj = s3.Object(bucket_name, '{}/{}.out'.format(prefix, DATA_FILE)) |
| 63 | + output = output_obj.get()['Body'].read().decode('utf-8') |
| 64 | + |
| 65 | + probabilities = json.loads(output)[0] |
| 66 | + return probabilities.index(max(probabilities)) |
0 commit comments