Skip to content

Commit 2ee203f

Browse files
authored
change: lazy import of tensorflow module (#1062)
1 parent 810e96b commit 2ee203f

File tree

1 file changed

+30
-12
lines changed

1 file changed

+30
-12
lines changed

src/sagemaker/tensorflow/predictor.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,31 @@
1818
import google.protobuf.json_format as json_format
1919
from google.protobuf.message import DecodeError
2020
from protobuf_to_dict import protobuf_to_dict
21-
from tensorflow.core.framework import tensor_pb2 # pylint: disable=no-name-in-module
22-
from tensorflow.python.framework import tensor_util # pylint: disable=no-name-in-module
2321

2422
from sagemaker.content_types import CONTENT_TYPE_JSON, CONTENT_TYPE_OCTET_STREAM, CONTENT_TYPE_CSV
2523
from sagemaker.predictor import json_serializer, csv_serializer
26-
from tensorflow_serving.apis import predict_pb2, classification_pb2, inference_pb2, regression_pb2
2724

28-
_POSSIBLE_RESPONSES = [
29-
predict_pb2.PredictResponse,
30-
classification_pb2.ClassificationResponse,
31-
inference_pb2.MultiInferenceResponse,
32-
regression_pb2.RegressionResponse,
33-
tensor_pb2.TensorProto,
34-
]
25+
26+
def _possible_responses():
27+
"""
28+
Returns: Possible available request types.
29+
"""
30+
from tensorflow.core.framework import tensor_pb2 # pylint: disable=no-name-in-module
31+
from tensorflow_serving.apis import (
32+
predict_pb2,
33+
classification_pb2,
34+
inference_pb2,
35+
regression_pb2,
36+
)
37+
38+
return [
39+
predict_pb2.PredictResponse,
40+
classification_pb2.ClassificationResponse,
41+
inference_pb2.MultiInferenceResponse,
42+
regression_pb2.RegressionResponse,
43+
tensor_pb2.TensorProto,
44+
]
45+
3546

3647
REGRESSION_REQUEST = "RegressionRequest"
3748
MULTI_INFERENCE_REQUEST = "MultiInferenceRequest"
@@ -88,7 +99,7 @@ def __call__(self, stream, content_type):
8899
finally:
89100
stream.close()
90101

91-
for possible_response in _POSSIBLE_RESPONSES:
102+
for possible_response in _possible_responses():
92103
try:
93104
response = possible_response()
94105
response.ParseFromString(data)
@@ -114,6 +125,9 @@ def __call__(self, data):
114125
Args:
115126
data:
116127
"""
128+
129+
from tensorflow.core.framework import tensor_pb2 # pylint: disable=no-name-in-module
130+
117131
if isinstance(data, tensor_pb2.TensorProto):
118132
return json_format.MessageToJson(data)
119133
return json_serializer(data)
@@ -139,7 +153,7 @@ def __call__(self, stream, content_type):
139153
finally:
140154
stream.close()
141155

142-
for possible_response in _POSSIBLE_RESPONSES:
156+
for possible_response in _possible_responses():
143157
try:
144158
return protobuf_to_dict(json_format.Parse(data, possible_response()))
145159
except (UnicodeDecodeError, DecodeError, json_format.ParseError):
@@ -164,6 +178,10 @@ def __call__(self, data):
164178
data:
165179
"""
166180
to_serialize = data
181+
182+
from tensorflow.core.framework import tensor_pb2 # pylint: disable=no-name-in-module
183+
from tensorflow.python.framework import tensor_util # pylint: disable=no-name-in-module
184+
167185
if isinstance(data, tensor_pb2.TensorProto):
168186
to_serialize = tensor_util.MakeNdarray(data)
169187
return csv_serializer(to_serialize)

0 commit comments

Comments
 (0)