
Commit b16ff74

Fix deserialization of dicts for json predict requests (#76)
* Allow dict of list/ndarray for predict
* Add unit test
* Add integ test on dict of lists prediction
* Update docstrings
1 parent d442ac7 commit b16ff74

4 files changed: +53, -23 lines
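
The commit's effect, in short: JSON predict requests may now carry a dict mapping tensor names to plain lists or ndarrays, instead of only a bare array or an already-serialized TensorProto. A minimal illustration of such a payload (the tensor name 'inputs' and the 784-element input are taken from the integration test added below):

import json

# Illustrative payload only: a dict of tensor name -> list of values.
# Before this commit, dict values had to already be tensor_pb2.TensorProto objects.
payload = {'inputs': [0.0] * 784}
body = json.dumps(payload)  # POSTed with Content-type: application/json to /invocations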

src/tf_container/proxy_client.py

Lines changed: 22 additions & 18 deletions
@@ -151,7 +151,8 @@ def _create_classification_request(self, data):
     def _create_feature_dict_list(self, data):
         """
         Parses the input data and returns a [dict<string, iterable>] which will be used to create the tf examples.
-        If the input data is not a dict, a dictionary will be created with the default predict key PREDICT_INPUTS
+        If the input data is not a dict, a dictionary will be created with the default key PREDICT_INPUTS.
+        Used on the code path for creating ClassificationRequests.
 
         Examples:
             input => output
@@ -184,43 +185,46 @@ def _raise_not_implemented_exception(self, data):
 
     def _create_input_map(self, data):
         """
-        Parses the input data and returns a dict<string, TensorProto> which will be used to create the predict request.
+        Parses the input data and returns a dict<string, TensorProto> which will be used to create the PredictRequest.
         If the input data is not a dict, a dictionary will be created with the default predict key PREDICT_INPUTS
 
         input.
 
         Examples:
             input => output
-            {'inputs': tensor_proto} => {'inputs': tensor_proto}
+            -------------------------------------------------
             tensor_proto => {PREDICT_INPUTS: tensor_proto}
-            [1,2,3] => {PREDICT_INPUTS: tensor_proto(1,2,3)}
+            {'custom_tensor_name': tensor_proto} => {'custom_tensor_name': TensorProto}
+            [1,2,3] => {PREDICT_INPUTS: TensorProto(1,2,3)}
+            {'custom_tensor_name': [1, 2, 3]} => {'custom_tensor_name': TensorProto(1,2,3)}
         Args:
-            data: request data. Can be any instance of dict<string, tensor_proto>, tensor_proto or any array like data.
+            data: request data. Can be any of: ndarray-like, TensorProto, dict<str, TensorProto>, dict<str, ndarray-like>
 
         Returns:
             dict<string, tensor_proto>
 
 
         """
-        msg = """Unsupported request data format: {}.
-                 Valid formats: tensor_pb2.TensorProto, dict<string, tensor_pb2.TensorProto> and predict_pb2.PredictRequest"""
-
         if isinstance(data, dict):
-            if all(isinstance(v, tensor_pb2.TensorProto) for k, v in data.items()):
-                return data
-            raise ValueError(msg.format(data))
+            return {k: self._value_to_tensor(v) for k, v in data.items()}
+
+        # When input data is not a dict, no tensor names are given, so use default
+        return {self.input_tensor_name: self._value_to_tensor(data)}
 
-        if isinstance(data, tensor_pb2.TensorProto):
-            return {self.input_tensor_name: data}
+    def _value_to_tensor(self, value):
+        """Converts the given value to a tensor_pb2.TensorProto. Used on code path for creating PredictRequests."""
+        if isinstance(value, tensor_pb2.TensorProto):
+            return value
 
+        msg = """Unable to convert value to TensorProto: {}.
+                 Valid formats: tensor_pb2.TensorProto, list, numpy.ndarray"""
         try:
             # TODO: tensorflow container supports prediction requests with ONLY one tensor as input
             input_type = self.input_type_map.values()[0]
-            ndarray = np.asarray(data)
-            tensor_proto = make_tensor_proto(values=ndarray, dtype=input_type, shape=ndarray.shape)
-            return {self.input_tensor_name: tensor_proto}
-        except:
-            raise ValueError(msg.format(data))
+            ndarray = np.asarray(value)
+            return make_tensor_proto(values=ndarray, dtype=input_type, shape=ndarray.shape)
+        except Exception:
+            raise ValueError(msg.format(value))
 
 
 def _create_tf_example(feature_dict):
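
Taken together, the new code path routes every value of a dict payload through _value_to_tensor, so dict values may be TensorProtos, lists, or ndarrays, while non-dict payloads are still wrapped under the default input tensor name. A standalone sketch of that conversion logic, simplified from the diff above (the 'float32' dtype and the 'inputs' default name are hard-coded here for illustration, whereas the container reads them from the loaded model's signature):

import numpy as np
from tensorflow.core.framework import tensor_pb2
from tensorflow.python.framework.tensor_util import make_tensor_proto

PREDICT_INPUTS = 'inputs'  # assumed default tensor name for non-dict payloads


def value_to_tensor(value, dtype='float32'):
    # Sketch of _value_to_tensor: TensorProtos pass through, anything else goes via numpy.
    if isinstance(value, tensor_pb2.TensorProto):
        return value
    ndarray = np.asarray(value)
    return make_tensor_proto(values=ndarray, dtype=dtype, shape=ndarray.shape)


def create_input_map(data):
    # Sketch of the updated _create_input_map: dicts keep their tensor names,
    # any other payload is wrapped under the default PREDICT_INPUTS key.
    if isinstance(data, dict):
        return {k: value_to_tensor(v) for k, v in data.items()}
    return {PREDICT_INPUTS: value_to_tensor(data)}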

src/tf_container/serve.py

Lines changed: 2 additions & 2 deletions
@@ -160,7 +160,7 @@ def __init__(self, grpc_proxy_client, transform_fn=None, input_fn=None, output_f
 
     @staticmethod
     def _parse_json_request(serialized_data):
-        '''
+        """
         json deserialization works in the following order:
             1 - tries to deserialize the payload as a tensor using google.protobuf.json_format.Parse(
                 payload, tensor_pb2.TensorProto())
@@ -170,7 +170,7 @@ def _parse_json_request(serialized_data):
 
         Returns:
             deserialized object
-        '''
+        """
         try:
             return json_format.Parse(serialized_data, tensor_pb2.TensorProto())
         except json_format.ParseError:
test/integ/container_tests/layers_prediction.py

Lines changed: 13 additions & 0 deletions
@@ -60,3 +60,16 @@ def test_json_request():
     prediction_result = json.loads(serialized_output)
 
     assert len(prediction_result['outputs']['probabilities']['floatVal']) == 10
+
+
+def test_json_dict_of_lists():
+    data = {'inputs': [x for x in xrange(784)]}
+
+    url = "http://localhost:8080/invocations"
+    serialized_output = requests.post(url,
+                                      json.dumps(data),
+                                      headers={'Content-type': 'application/json'}).content
+
+    prediction_result = json.loads(serialized_output)
+
+    assert len(prediction_result['outputs']['probabilities']['floatVal']) == 10

test/unit/test_proxy_client.py

Lines changed: 16 additions & 3 deletions
@@ -45,6 +45,7 @@ def set_up():
     patcher.start()
     from tf_container.proxy_client import GRPCProxyClient
     proxy_client = GRPCProxyClient(9000, input_tensor_name='inputs', signature_name='serving_default')
+    proxy_client.input_type_map['sometype'] = 'somedtype'
 
     yield mock, proxy_client
     patcher.stop()
@@ -253,15 +254,27 @@ def test_predict_with_predict_request(set_up, set_up_requests):
     assert prediction == predict_fn.return_value
 
 
-def test_predict_with_invalid_payload(set_up, set_up_requests):
-    mock, proxy_client = set_up
+@patch('tf_container.proxy_client.make_tensor_proto', side_effect=Exception('tensor proto failed!'))
+def test_predict_with_invalid_payload(make_tensor_proto, set_up, set_up_requests):
+    _, proxy_client = set_up
 
     data = complex('1+2j')
 
     with pytest.raises(ValueError) as error:
         proxy_client.predict(data)
 
-    assert 'Unsupported request data format' in str(error)
+    assert 'Unable to convert value to TensorProto' in str(error)
+
+
+@patch('tf_container.proxy_client.make_tensor_proto', return_value='MyTensorProto')
+def test_predict_create_input_map_with_dict_of_lists(make_tensor_proto, set_up, set_up_requests):
+    _, proxy_client = set_up
+
+    data = {'mytensor': [1, 2, 3]}
+
+    result = proxy_client._create_input_map(data)
+    assert result == {'mytensor': 'MyTensorProto'}
+    make_tensor_proto.assert_called_once()
 
 
 def test_classification_with_classification_request(set_up, set_up_requests):
