Skip to content

Commit 03da652

Browse files
committed
fix the crashing when there are zero-size inputs
1 parent 304c2e8 commit 03da652

File tree

1 file changed

+21
-7
lines changed

1 file changed

+21
-7
lines changed

src/libtorch.cc

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2190,10 +2190,18 @@ ModelInstanceState::SetInputTensors(
21902190

21912191
(*input_tensors)[input_index_map_[input_name]] = input_list;
21922192
} else {
2193-
// Remove constness to align with the signature of torch::from_blob()
2194-
torch::Tensor input_tensor = torch::from_blob(
2195-
const_cast<char*>(input_buffer), batchn_shape, updated_options);
2196-
(*input_tensors)[input_index_map_[input_name]] = input_tensor;
2193+
if (batchn_byte_size) {
2194+
// Remove constness to align with the signature of torch::from_blob()
2195+
torch::Tensor input_tensor = torch::from_blob(
2196+
const_cast<char*>(input_buffer), batchn_shape, updated_options);
2197+
(*input_tensors)[input_index_map_[input_name]] = input_tensor;
2198+
} else {
2199+
// torch:from_blob seems not working when the input size is 0
2200+
// create zero-lenght inputs directly
2201+
torch::Tensor input_tensor =
2202+
torch::zeros(batchn_shape, updated_options);
2203+
(*input_tensors)[input_index_map_[input_name]] = input_tensor;
2204+
}
21972205
}
21982206
}
21992207

@@ -2222,9 +2230,15 @@ ModelInstanceState::SetInputTensors(
22222230
? options.device(torch::kCUDA, device_.index())
22232231
: options.device(torch::kCPU);
22242232

2225-
torch::Tensor input_tensor = torch::from_blob(
2226-
const_cast<char*>(dst_buffer), shape, updated_options);
2227-
(*input_tensors)[input_index_map_[input_name]] = input_tensor;
2233+
if (dst_buffer_byte_size) {
2234+
torch::Tensor input_tensor = torch::from_blob(
2235+
const_cast<char*>(dst_buffer), shape, updated_options);
2236+
(*input_tensors)[input_index_map_[input_name]] = input_tensor;
2237+
} else {
2238+
// special handle when input has zero size
2239+
torch::Tensor input_tensor = torch::zeros(shape, updated_options);
2240+
(*input_tensors)[input_index_map_[input_name]] = input_tensor;
2241+
}
22282242
}
22292243
}
22302244

0 commit comments

Comments
 (0)