18
18
import google .protobuf .json_format as json_format
19
19
from google .protobuf .message import DecodeError
20
20
from protobuf_to_dict import protobuf_to_dict
21
- from tensorflow .core .framework import tensor_pb2 # pylint: disable=no-name-in-module
22
- from tensorflow .python .framework import tensor_util # pylint: disable=no-name-in-module
23
21
24
22
from sagemaker .content_types import CONTENT_TYPE_JSON , CONTENT_TYPE_OCTET_STREAM , CONTENT_TYPE_CSV
25
23
from sagemaker .predictor import json_serializer , csv_serializer
26
- from tensorflow_serving .apis import predict_pb2 , classification_pb2 , inference_pb2 , regression_pb2
27
24
28
- _POSSIBLE_RESPONSES = [
29
- predict_pb2 .PredictResponse ,
30
- classification_pb2 .ClassificationResponse ,
31
- inference_pb2 .MultiInferenceResponse ,
32
- regression_pb2 .RegressionResponse ,
33
- tensor_pb2 .TensorProto ,
34
- ]
25
+
26
+ def _possible_responses ():
27
+ """
28
+ Returns: Possible available request types.
29
+ """
30
+ from tensorflow .core .framework import tensor_pb2 # pylint: disable=no-name-in-module
31
+ from tensorflow_serving .apis import (
32
+ predict_pb2 ,
33
+ classification_pb2 ,
34
+ inference_pb2 ,
35
+ regression_pb2 ,
36
+ )
37
+
38
+ return [
39
+ predict_pb2 .PredictResponse ,
40
+ classification_pb2 .ClassificationResponse ,
41
+ inference_pb2 .MultiInferenceResponse ,
42
+ regression_pb2 .RegressionResponse ,
43
+ tensor_pb2 .TensorProto ,
44
+ ]
45
+
35
46
36
47
REGRESSION_REQUEST = "RegressionRequest"
37
48
MULTI_INFERENCE_REQUEST = "MultiInferenceRequest"
@@ -88,7 +99,7 @@ def __call__(self, stream, content_type):
88
99
finally :
89
100
stream .close ()
90
101
91
- for possible_response in _POSSIBLE_RESPONSES :
102
+ for possible_response in _possible_responses () :
92
103
try :
93
104
response = possible_response ()
94
105
response .ParseFromString (data )
@@ -114,6 +125,9 @@ def __call__(self, data):
114
125
Args:
115
126
data:
116
127
"""
128
+
129
+ from tensorflow .core .framework import tensor_pb2 # pylint: disable=no-name-in-module
130
+
117
131
if isinstance (data , tensor_pb2 .TensorProto ):
118
132
return json_format .MessageToJson (data )
119
133
return json_serializer (data )
@@ -139,7 +153,7 @@ def __call__(self, stream, content_type):
139
153
finally :
140
154
stream .close ()
141
155
142
- for possible_response in _POSSIBLE_RESPONSES :
156
+ for possible_response in _possible_responses () :
143
157
try :
144
158
return protobuf_to_dict (json_format .Parse (data , possible_response ()))
145
159
except (UnicodeDecodeError , DecodeError , json_format .ParseError ):
@@ -164,6 +178,10 @@ def __call__(self, data):
164
178
data:
165
179
"""
166
180
to_serialize = data
181
+
182
+ from tensorflow .core .framework import tensor_pb2 # pylint: disable=no-name-in-module
183
+ from tensorflow .python .framework import tensor_util # pylint: disable=no-name-in-module
184
+
167
185
if isinstance (data , tensor_pb2 .TensorProto ):
168
186
to_serialize = tensor_util .MakeNdarray (data )
169
187
return csv_serializer (to_serialize )
0 commit comments