Skip to content

add sagemaker inference as a dependency in setup.py and fix unit tests #14

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 23, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def read(fname):

# We don't declare our dependency on mxnet here because we build with
# different packages for different variants (e.g. mxnet-mkl and mxnet-cu90).
install_requires=['sagemaker-inference==1.0.0'],
extras_require={
'test': ['tox', 'flake8', 'pytest', 'pytest-cov', 'pytest-xdist', 'mock', 'sagemaker',
'docker-compose', 'mxnet==1.4.0']
Expand Down
37 changes: 25 additions & 12 deletions test/unit/test_handler_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,16 @@
from mock import Mock, patch
import mxnet as mx
import pytest
from sagemaker_inference import environment
from sagemaker_inference.default_inference_handler import DefaultInferenceHandler
from sagemaker_inference.transformer import Transformer

from sagemaker_mxnet_serving_container.default_inference_handler import DefaultGluonBlockInferenceHandler
from sagemaker_mxnet_serving_container.handler_service import HandlerService
from sagemaker_mxnet_serving_container.mxnet_module_transformer import MXNetModuleTransformer

MODULE_NAME = 'module_name'


@patch('sagemaker_mxnet_serving_container.handler_service.HandlerService._user_module_transformer')
def test_handler_service(user_module_transformer):
Expand All @@ -33,11 +38,14 @@ def __init__(self):
self.transform_fn = Mock()


@patch('sagemaker_inference.environment.Environment')
@patch('importlib.import_module', return_value=UserModuleTransformFn())
def test_user_module_transform_fn(import_module):
def test_user_module_transform_fn(import_module, env):
env.return_value.module_name = MODULE_NAME
transformer = HandlerService._user_module_transformer()

assert transformer._transform_fn == import_module.return_value.transform_fn
import_module.assert_called_once_with(MODULE_NAME)
assert isinstance(transformer._default_inference_handler, DefaultInferenceHandler)
assert isinstance(transformer, Transformer)


Expand All @@ -46,35 +54,40 @@ def __init__(self):
self.model_fn = Mock()


@patch('sagemaker_mxnet_serving_container.default_inference_handler.DefaultModuleInferenceHandler.default_predict_fn')
@patch('sagemaker_mxnet_serving_container.default_inference_handler.DefaultModuleInferenceHandler.default_input_fn')
@patch('sagemaker_inference.environment.Environment')
@patch('importlib.import_module', return_value=UserModuleModelFn())
def test_user_module_mxnet_module_transformer(import_module, input_fn, predict_fn):
def test_user_module_mxnet_module_transformer(import_module, env):
env.return_value.module_name = MODULE_NAME
import_module.return_value.model_fn.return_value = mx.module.BaseModule()

transformer = HandlerService._user_module_transformer()

import_module.assert_called_once_with(MODULE_NAME)
assert isinstance(transformer, MXNetModuleTransformer)
assert transformer._input_fn == input_fn
assert transformer._predict_fn == predict_fn


@patch('sagemaker_inference.environment.Environment')
@patch('sagemaker_mxnet_serving_container.default_inference_handler.DefaultMXNetInferenceHandler.default_model_fn')
@patch('sagemaker_mxnet_serving_container.default_inference_handler.DefaultGluonBlockInferenceHandler.default_predict_fn') # noqa E501
@patch('importlib.import_module', return_value=object())
def test_user_module_mxnet_gluon_transformer(import_module, predict_fn, model_fn):
def test_default_inference_handler_mxnet_gluon_transformer(import_module, model_fn, env):
env.return_value.module_name = MODULE_NAME
model_fn.return_value = mx.gluon.block.Block()

transformer = HandlerService._user_module_transformer()

import_module.assert_called_once_with(MODULE_NAME)
model_fn.assert_called_once_with(environment.model_dir)
assert isinstance(transformer, Transformer)
assert transformer._predict_fn == predict_fn
assert transformer._model_fn == model_fn
assert isinstance(transformer._default_inference_handler, DefaultGluonBlockInferenceHandler)


@patch('sagemaker_inference.environment.Environment')
@patch('importlib.import_module', return_value=UserModuleModelFn())
def test_user_module_unsupported(import_module):
def test_user_module_unsupported(import_module, env):
env.return_value.module_name = MODULE_NAME

with pytest.raises(ValueError) as e:
HandlerService._user_module_transformer()

import_module.assert_called_once_with(MODULE_NAME)
assert 'Unsupported model type' in str(e)