Commit af2d11a

Add ragged batching support (#95)
1 parent 4a8a870 commit af2d11a

File tree: 1 file changed (+53 -24 lines)

src/libtorch.cc

Lines changed: 53 additions & 24 deletions
@@ -1741,11 +1741,32 @@ ModelInstanceState::SetInputTensors(
 
     input_names->emplace_back(input_name);
 
-    // The shape for the entire input patch, [total_batch_size, ...]
-    std::vector<int64_t> batchn_shape(
-        input_shape, input_shape + input_dims_count);
-    if (supports_batching_) {
-      batchn_shape[0] = total_batch_size;
+    // The shape for the entire input patch,
+    // [total_batch_size, ...] for non-ragged input and
+    // [total_element_count] for ragged input (non-nested tensor)
+    std::vector<int64_t> batchn_shape;
+    if (StateForModel()->IsInputRagged(input_name)) {
+      batchn_shape = std::vector<int64_t>{0};
+      for (size_t idx = 0; idx < request_count; idx++) {
+        TRITONBACKEND_Input* input;
+        RESPOND_AND_SET_NULL_IF_ERROR(
+            &((*responses)[idx]),
+            TRITONBACKEND_RequestInput(requests[idx], input_name, &input));
+        const int64_t* input_shape;
+        uint32_t input_dims_count;
+        RESPOND_AND_SET_NULL_IF_ERROR(
+            &((*responses)[idx]), TRITONBACKEND_InputProperties(
+                                      input, nullptr, nullptr, &input_shape,
+                                      &input_dims_count, nullptr, nullptr));
+
+        batchn_shape[0] += GetElementCount(input_shape, input_dims_count);
+      }
+    }
+    else {
+      batchn_shape = std::vector<int64_t>(input_shape, input_shape + input_dims_count);
+      if (supports_batching_) {
+        batchn_shape[0] = total_batch_size;
+      }
     }
 
     // The input must be in contiguous CPU/GPU memory.
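
For illustration, the ragged branch above reduces each request's input to its element count and sums those counts into a single flattened batch dimension, instead of stacking requests along a leading batch axis. Below is a minimal standalone sketch of that arithmetic, using hypothetical helpers ElementCount and RaggedBatchShape rather than the backend's own code (the backend calls GetElementCount from the Triton backend common utilities):

    // Standalone sketch; names are illustrative, not from libtorch.cc.
    #include <cstdint>
    #include <functional>
    #include <numeric>
    #include <vector>

    // Product of all dimensions, mirroring what GetElementCount returns.
    int64_t ElementCount(const std::vector<int64_t>& shape) {
      return std::accumulate(
          shape.begin(), shape.end(), int64_t{1}, std::multiplies<int64_t>());
    }

    // Collapse per-request shapes into the single flattened dimension used
    // for a ragged input: [total_element_count].
    std::vector<int64_t> RaggedBatchShape(
        const std::vector<std::vector<int64_t>>& request_shapes) {
      std::vector<int64_t> batchn_shape{0};
      for (const auto& shape : request_shapes) {
        batchn_shape[0] += ElementCount(shape);
      }
      return batchn_shape;
    }

    // Example: shapes {4, 2} and {3, 2} batch to {14} (8 + 6 elements),
    // where non-ragged batching would have required matching shapes.

Because the requests are simply concatenated, the model typically also receives a companion batch input (for example, accumulated element counts) from which it can recover the per-request boundaries.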
@@ -1866,28 +1887,36 @@ ModelInstanceState::ReadOutputTensors(
 
       // Output tensors may not reside on the same device as model
       torch::Device tensor_device = output_flat.device();
+      const auto memory_type = (tensor_device.type() == torch::kCPU) ? TRITONSERVER_MEMORY_CPU
+                                                                     : TRITONSERVER_MEMORY_GPU;
+      const auto memory_id = (tensor_device.type() == torch::kCPU) ? 0 : tensor_device.index();
+
+      // Batch output doesn't support string data type yet, as it is not trivial
+      // to parse string output
+      const BatchOutput* batch_output = StateForModel()->FindBatchOutput(name);
+      if (batch_output == nullptr) {
+        // Get output shape
+        std::vector<int64_t> batchn_shape;
+        auto shape = output_tensors[op_index].toTensor().sizes();
+        for (auto itr = shape.begin(); itr != shape.end(); itr++) {
+          batchn_shape.push_back(*itr);
+        }
 
-      // Get output shape
-      std::vector<int64_t> batchn_shape;
-      auto shape = output_tensors[op_index].toTensor().sizes();
-      for (auto itr = shape.begin(); itr != shape.end(); itr++) {
-        batchn_shape.push_back(*itr);
-      }
+        if (batchn_shape.size() == 0) {
+          return TRITONSERVER_ErrorNew(
+              TRITONSERVER_ERROR_INVALID_ARG,
+              (std::string("output '") + name +
+               "' is a scalar which is not supported.")
+                  .c_str());
+        }
 
-      if (batchn_shape.size() == 0) {
-        return TRITONSERVER_ErrorNew(
-            TRITONSERVER_ERROR_INVALID_ARG,
-            (std::string("output '") + name +
-             "' is a scalar which is not supported.")
-                .c_str());
+        responder.ProcessTensor(
+            name, output_dtype, batchn_shape, output_buffer,
+            memory_type, memory_id);
+      } else {
+        responder.ProcessBatchOutput(
+            name, *batch_output, output_buffer, memory_type, memory_id);
       }
-
-      responder.ProcessTensor(
-          name, output_dtype, batchn_shape, output_buffer,
-          (tensor_device.type() == torch::kCPU) ? TRITONSERVER_MEMORY_CPU
-                                                : TRITONSERVER_MEMORY_GPU,
-          (tensor_device.type() == torch::kCPU) ? 0 : tensor_device.index());
-
     } else if (output_tensors[op_index].isList()) {
       // Custom handling for string/bytes tensor...
       torch::List<torch::jit::IValue> output_list =
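
Both hunks hinge on model configuration rather than on the tensors themselves: IsInputRagged reflects an input's allow_ragged_batch setting, and FindBatchOutput looks up a batch_output entry in the config. A minimal config.pbtxt sketch in the spirit of Triton's ragged-batching docs (names, dims, and max_batch_size are illustrative):

    max_batch_size: 8
    input [
      {
        name: "INPUT0"
        data_type: TYPE_FP32
        dims: [ -1 ]
        allow_ragged_batch: true
      }
    ]
    output [
      {
        name: "OUTPUT0"
        data_type: TYPE_FP32
        dims: [ -1 ]
      }
    ]
    batch_output [
      {
        target_name: "OUTPUT0"
        kind: BATCH_SCATTER_WITH_INPUT_SHAPE
        source_input: "INPUT0"
      }
    ]

With a config along these lines, the first hunk concatenates INPUT0 across requests into one flattened tensor on the way in, and the second hunk routes OUTPUT0 through ProcessBatchOutput, which splits it back across responses according to each request's INPUT0 shape.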
