Skip to content

Commit 624c782

Browse files
author
Rui Wang Napieralski
committed
move import of neomxnet into inference functions
1 parent aae58db commit 624c782

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

tests/data/mxnet_mnist/mnist_neo.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919

2020
import mxnet as mx
2121
import numpy as np
22-
import neomxnet # noqa: F401
2322

2423

2524
def load_data(path):
@@ -106,6 +105,8 @@ def train(
106105

107106

108107
def model_fn(path_to_model_files):
108+
import neomxnet # noqa: F401
109+
109110
ctx = mx.cpu()
110111
sym, arg_params, aux_params = mx.model.load_checkpoint(
111112
os.path.join(path_to_model_files, "compiled"), 0
@@ -119,6 +120,8 @@ def model_fn(path_to_model_files):
119120

120121

121122
def transform_fn(mod, payload, input_content_type, requested_output_content_type):
123+
import neomxnet # noqa: F401
124+
122125
if input_content_type != "application/vnd+python.numpy+binary":
123126
raise RuntimeError("Input content type must be application/vnd+python.numpy+binary")
124127

0 commit comments

Comments
 (0)