Skip to content

Commit 94700f5

Browse files
author
Dan
authored
add sagemaker inference as a dependency in setup.py and fix unit tests (#14)
1 parent 849b999 commit 94700f5

File tree

2 files changed

+26
-12
lines changed

2 files changed

+26
-12
lines changed

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def read(fname):
4848

4949
# We don't declare our dependency on mxnet here because we build with
5050
# different packages for different variants (e.g. mxnet-mkl and mxnet-cu90).
51+
install_requires=['sagemaker-inference==1.0.0'],
5152
extras_require={
5253
'test': ['tox', 'flake8', 'pytest', 'pytest-cov', 'pytest-xdist', 'mock', 'sagemaker',
5354
'docker-compose', 'mxnet==1.4.0']

test/unit/test_handler_service.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,16 @@
1515
from mock import Mock, patch
1616
import mxnet as mx
1717
import pytest
18+
from sagemaker_inference import environment
19+
from sagemaker_inference.default_inference_handler import DefaultInferenceHandler
1820
from sagemaker_inference.transformer import Transformer
1921

22+
from sagemaker_mxnet_serving_container.default_inference_handler import DefaultGluonBlockInferenceHandler
2023
from sagemaker_mxnet_serving_container.handler_service import HandlerService
2124
from sagemaker_mxnet_serving_container.mxnet_module_transformer import MXNetModuleTransformer
2225

26+
MODULE_NAME = 'module_name'
27+
2328

2429
@patch('sagemaker_mxnet_serving_container.handler_service.HandlerService._user_module_transformer')
2530
def test_handler_service(user_module_transformer):
@@ -33,11 +38,14 @@ def __init__(self):
3338
self.transform_fn = Mock()
3439

3540

41+
@patch('sagemaker_inference.environment.Environment')
3642
@patch('importlib.import_module', return_value=UserModuleTransformFn())
37-
def test_user_module_transform_fn(import_module):
43+
def test_user_module_transform_fn(import_module, env):
44+
env.return_value.module_name = MODULE_NAME
3845
transformer = HandlerService._user_module_transformer()
3946

40-
assert transformer._transform_fn == import_module.return_value.transform_fn
47+
import_module.assert_called_once_with(MODULE_NAME)
48+
assert isinstance(transformer._default_inference_handler, DefaultInferenceHandler)
4149
assert isinstance(transformer, Transformer)
4250

4351

@@ -46,35 +54,40 @@ def __init__(self):
4654
self.model_fn = Mock()
4755

4856

49-
@patch('sagemaker_mxnet_serving_container.default_inference_handler.DefaultModuleInferenceHandler.default_predict_fn')
50-
@patch('sagemaker_mxnet_serving_container.default_inference_handler.DefaultModuleInferenceHandler.default_input_fn')
57+
@patch('sagemaker_inference.environment.Environment')
5158
@patch('importlib.import_module', return_value=UserModuleModelFn())
52-
def test_user_module_mxnet_module_transformer(import_module, input_fn, predict_fn):
59+
def test_user_module_mxnet_module_transformer(import_module, env):
60+
env.return_value.module_name = MODULE_NAME
5361
import_module.return_value.model_fn.return_value = mx.module.BaseModule()
5462

5563
transformer = HandlerService._user_module_transformer()
5664

65+
import_module.assert_called_once_with(MODULE_NAME)
5766
assert isinstance(transformer, MXNetModuleTransformer)
58-
assert transformer._input_fn == input_fn
59-
assert transformer._predict_fn == predict_fn
6067

6168

69+
@patch('sagemaker_inference.environment.Environment')
6270
@patch('sagemaker_mxnet_serving_container.default_inference_handler.DefaultMXNetInferenceHandler.default_model_fn')
63-
@patch('sagemaker_mxnet_serving_container.default_inference_handler.DefaultGluonBlockInferenceHandler.default_predict_fn') # noqa E501
6471
@patch('importlib.import_module', return_value=object())
65-
def test_user_module_mxnet_gluon_transformer(import_module, predict_fn, model_fn):
72+
def test_default_inference_handler_mxnet_gluon_transformer(import_module, model_fn, env):
73+
env.return_value.module_name = MODULE_NAME
6674
model_fn.return_value = mx.gluon.block.Block()
6775

6876
transformer = HandlerService._user_module_transformer()
6977

78+
import_module.assert_called_once_with(MODULE_NAME)
79+
model_fn.assert_called_once_with(environment.model_dir)
7080
assert isinstance(transformer, Transformer)
71-
assert transformer._predict_fn == predict_fn
72-
assert transformer._model_fn == model_fn
81+
assert isinstance(transformer._default_inference_handler, DefaultGluonBlockInferenceHandler)
7382

7483

84+
@patch('sagemaker_inference.environment.Environment')
7585
@patch('importlib.import_module', return_value=UserModuleModelFn())
76-
def test_user_module_unsupported(import_module):
86+
def test_user_module_unsupported(import_module, env):
87+
env.return_value.module_name = MODULE_NAME
88+
7789
with pytest.raises(ValueError) as e:
7890
HandlerService._user_module_transformer()
7991

92+
import_module.assert_called_once_with(MODULE_NAME)
8093
assert 'Unsupported model type' in str(e)

0 commit comments

Comments
 (0)