Skip to content

Commit d900538

Browse files
fix the crashing when there are zero-size inputs (#120)
* fix the crashing when there are zero-size inputs * Typo --------- Co-authored-by: Iman Tabrizian <[email protected]>
1 parent 48e2e29 commit d900538

File tree

1 file changed

+21
-7
lines changed

1 file changed

+21
-7
lines changed

src/libtorch.cc

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2136,10 +2136,18 @@ ModelInstanceState::SetInputTensors(
21362136

21372137
(*input_tensors)[input_index_map_[input_name]] = input_list;
21382138
} else {
2139-
// Remove constness to align with the signature of torch::from_blob()
2140-
torch::Tensor input_tensor = torch::from_blob(
2141-
const_cast<char*>(input_buffer), batchn_shape, updated_options);
2142-
(*input_tensors)[input_index_map_[input_name]] = input_tensor;
2139+
if (batchn_byte_size) {
2140+
// Remove constness to align with the signature of torch::from_blob()
2141+
torch::Tensor input_tensor = torch::from_blob(
2142+
const_cast<char*>(input_buffer), batchn_shape, updated_options);
2143+
(*input_tensors)[input_index_map_[input_name]] = input_tensor;
2144+
} else {
2145+
// torch:from_blob seems not working when the input size is 0
2146+
// create zero-length inputs directly
2147+
torch::Tensor input_tensor =
2148+
torch::zeros(batchn_shape, updated_options);
2149+
(*input_tensors)[input_index_map_[input_name]] = input_tensor;
2150+
}
21432151
}
21442152
}
21452153

@@ -2168,9 +2176,15 @@ ModelInstanceState::SetInputTensors(
21682176
? options.device(torch::kCUDA, device_.index())
21692177
: options.device(torch::kCPU);
21702178

2171-
torch::Tensor input_tensor = torch::from_blob(
2172-
const_cast<char*>(dst_buffer), shape, updated_options);
2173-
(*input_tensors)[input_index_map_[input_name]] = input_tensor;
2179+
if (dst_buffer_byte_size) {
2180+
torch::Tensor input_tensor = torch::from_blob(
2181+
const_cast<char*>(dst_buffer), shape, updated_options);
2182+
(*input_tensors)[input_index_map_[input_name]] = input_tensor;
2183+
} else {
2184+
// special handle when input has zero size
2185+
torch::Tensor input_tensor = torch::zeros(shape, updated_options);
2186+
(*input_tensors)[input_index_map_[input_name]] = input_tensor;
2187+
}
21742188
}
21752189
}
21762190

0 commit comments

Comments
 (0)