15
15
from mock import Mock , patch
16
16
import mxnet as mx
17
17
import pytest
18
+ from sagemaker_inference import environment
19
+ from sagemaker_inference .default_inference_handler import DefaultInferenceHandler
18
20
from sagemaker_inference .transformer import Transformer
19
21
22
+ from sagemaker_mxnet_serving_container .default_inference_handler import DefaultGluonBlockInferenceHandler
20
23
from sagemaker_mxnet_serving_container .handler_service import HandlerService
21
24
from sagemaker_mxnet_serving_container .mxnet_module_transformer import MXNetModuleTransformer
22
25
26
+ MODULE_NAME = 'module_name'
27
+
23
28
24
29
@patch ('sagemaker_mxnet_serving_container.handler_service.HandlerService._user_module_transformer' )
25
30
def test_handler_service (user_module_transformer ):
@@ -33,11 +38,14 @@ def __init__(self):
33
38
self .transform_fn = Mock ()
34
39
35
40
41
+ @patch ('sagemaker_inference.environment.Environment' )
36
42
@patch ('importlib.import_module' , return_value = UserModuleTransformFn ())
37
- def test_user_module_transform_fn (import_module ):
43
+ def test_user_module_transform_fn (import_module , env ):
44
+ env .return_value .module_name = MODULE_NAME
38
45
transformer = HandlerService ._user_module_transformer ()
39
46
40
- assert transformer ._transform_fn == import_module .return_value .transform_fn
47
+ import_module .assert_called_once_with (MODULE_NAME )
48
+ assert isinstance (transformer ._default_inference_handler , DefaultInferenceHandler )
41
49
assert isinstance (transformer , Transformer )
42
50
43
51
@@ -46,35 +54,40 @@ def __init__(self):
46
54
self .model_fn = Mock ()
47
55
48
56
49
- @patch ('sagemaker_mxnet_serving_container.default_inference_handler.DefaultModuleInferenceHandler.default_predict_fn' )
50
- @patch ('sagemaker_mxnet_serving_container.default_inference_handler.DefaultModuleInferenceHandler.default_input_fn' )
57
+ @patch ('sagemaker_inference.environment.Environment' )
51
58
@patch ('importlib.import_module' , return_value = UserModuleModelFn ())
52
- def test_user_module_mxnet_module_transformer (import_module , input_fn , predict_fn ):
59
+ def test_user_module_mxnet_module_transformer (import_module , env ):
60
+ env .return_value .module_name = MODULE_NAME
53
61
import_module .return_value .model_fn .return_value = mx .module .BaseModule ()
54
62
55
63
transformer = HandlerService ._user_module_transformer ()
56
64
65
+ import_module .assert_called_once_with (MODULE_NAME )
57
66
assert isinstance (transformer , MXNetModuleTransformer )
58
- assert transformer ._input_fn == input_fn
59
- assert transformer ._predict_fn == predict_fn
60
67
61
68
69
+ @patch ('sagemaker_inference.environment.Environment' )
62
70
@patch ('sagemaker_mxnet_serving_container.default_inference_handler.DefaultMXNetInferenceHandler.default_model_fn' )
63
- @patch ('sagemaker_mxnet_serving_container.default_inference_handler.DefaultGluonBlockInferenceHandler.default_predict_fn' ) # noqa E501
64
71
@patch ('importlib.import_module' , return_value = object ())
65
- def test_user_module_mxnet_gluon_transformer (import_module , predict_fn , model_fn ):
72
+ def test_default_inference_handler_mxnet_gluon_transformer (import_module , model_fn , env ):
73
+ env .return_value .module_name = MODULE_NAME
66
74
model_fn .return_value = mx .gluon .block .Block ()
67
75
68
76
transformer = HandlerService ._user_module_transformer ()
69
77
78
+ import_module .assert_called_once_with (MODULE_NAME )
79
+ model_fn .assert_called_once_with (environment .model_dir )
70
80
assert isinstance (transformer , Transformer )
71
- assert transformer ._predict_fn == predict_fn
72
- assert transformer ._model_fn == model_fn
81
+ assert isinstance (transformer ._default_inference_handler , DefaultGluonBlockInferenceHandler )
73
82
74
83
84
+ @patch ('sagemaker_inference.environment.Environment' )
75
85
@patch ('importlib.import_module' , return_value = UserModuleModelFn ())
76
- def test_user_module_unsupported (import_module ):
86
+ def test_user_module_unsupported (import_module , env ):
87
+ env .return_value .module_name = MODULE_NAME
88
+
77
89
with pytest .raises (ValueError ) as e :
78
90
HandlerService ._user_module_transformer ()
79
91
92
+ import_module .assert_called_once_with (MODULE_NAME )
80
93
assert 'Unsupported model type' in str (e )
0 commit comments