
Add ragged batching support #95


Merged 1 commit on Mar 20, 2023
77 changes: 53 additions & 24 deletions src/libtorch.cc
@@ -1741,11 +1741,32 @@ ModelInstanceState::SetInputTensors(

input_names->emplace_back(input_name);

// The shape for the entire input batch, [total_batch_size, ...]
std::vector<int64_t> batchn_shape(
input_shape, input_shape + input_dims_count);
if (supports_batching_) {
batchn_shape[0] = total_batch_size;
// The shape for the entire input batch,
// [total_batch_size, ...] for non-ragged input and
// [total_element_count] for ragged input (non-nested tensor)
std::vector<int64_t> batchn_shape;
if (StateForModel()->IsInputRagged(input_name)) {
batchn_shape = std::vector<int64_t>{0};
for (size_t idx = 0; idx < request_count; idx++) {
TRITONBACKEND_Input* input;
RESPOND_AND_SET_NULL_IF_ERROR(
&((*responses)[idx]),
TRITONBACKEND_RequestInput(requests[idx], input_name, &input));
const int64_t* input_shape;
uint32_t input_dims_count;
RESPOND_AND_SET_NULL_IF_ERROR(
&((*responses)[idx]), TRITONBACKEND_InputProperties(
input, nullptr, nullptr, &input_shape,
&input_dims_count, nullptr, nullptr));

batchn_shape[0] += GetElementCount(input_shape, input_dims_count);
}
}
else {
batchn_shape = std::vector<int64_t>(input_shape, input_shape + input_dims_count);
if (supports_batching_) {
batchn_shape[0] = total_batch_size;
}
}

// The input must be in contiguous CPU/GPU memory.
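
The ragged branch above replaces the usual [total_batch_size, ...] shape with a single flattened dimension, summing GetElementCount over every request in the batch. Below is a small standalone sketch of that accumulation, using hypothetical request shapes and a local ElementCount helper standing in for the backend's GetElementCount; it is illustrative only, not backend code.

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Illustrative stand-in for the backend's GetElementCount() helper.
static int64_t ElementCount(const std::vector<int64_t>& shape)
{
  return std::accumulate(
      shape.begin(), shape.end(), int64_t{1}, std::multiplies<int64_t>());
}

int main()
{
  // Hypothetical batch of three requests with variable-length inputs.
  const std::vector<std::vector<int64_t>> request_shapes = {
      {3, 4}, {5, 4}, {2, 4}};

  // For a ragged input the batched shape is [total_element_count] rather
  // than [total_batch_size, ...]: accumulate element counts per request.
  std::vector<int64_t> batchn_shape{0};
  for (const auto& shape : request_shapes) {
    batchn_shape[0] += ElementCount(shape);
  }
  // batchn_shape is now {40} (12 + 20 + 8).
  return 0;
}
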
@@ -1866,28 +1887,36 @@ ModelInstanceState::ReadOutputTensors(

// Output tensors may not reside on the same device as model
torch::Device tensor_device = output_flat.device();
const auto memory_type = (tensor_device.type() == torch::kCPU) ? TRITONSERVER_MEMORY_CPU
: TRITONSERVER_MEMORY_GPU;
const auto memory_id = (tensor_device.type() == torch::kCPU) ? 0 : tensor_device.index();

// Batch output doesn't support string data type yet, as it is not trivial
// to parse string output
[Review comment from a repository member] I think we need to mention this in the ragged batching docs too.

const BatchOutput* batch_output = StateForModel()->FindBatchOutput(name);
if (batch_output == nullptr) {
// Get output shape
std::vector<int64_t> batchn_shape;
auto shape = output_tensors[op_index].toTensor().sizes();
for (auto itr = shape.begin(); itr != shape.end(); itr++) {
batchn_shape.push_back(*itr);
}

// Get output shape
std::vector<int64_t> batchn_shape;
auto shape = output_tensors[op_index].toTensor().sizes();
for (auto itr = shape.begin(); itr != shape.end(); itr++) {
batchn_shape.push_back(*itr);
}
if (batchn_shape.size() == 0) {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG,
(std::string("output '") + name +
"' is a scalar which is not supported.")
.c_str());
}

if (batchn_shape.size() == 0) {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG,
(std::string("output '") + name +
"' is a scalar which is not supported.")
.c_str());
responder.ProcessTensor(
name, output_dtype, batchn_shape, output_buffer,
memory_type, memory_id);
} else {
responder.ProcessBatchOutput(
name, *batch_output, output_buffer, memory_type, memory_id);
}

responder.ProcessTensor(
name, output_dtype, batchn_shape, output_buffer,
(tensor_device.type() == torch::kCPU) ? TRITONSERVER_MEMORY_CPU
: TRITONSERVER_MEMORY_GPU,
(tensor_device.type() == torch::kCPU) ? 0 : tensor_device.index());

} else if (output_tensors[op_index].isList()) {
// Custom handling for string/bytes tensor...
torch::List<torch::jit::IValue> output_list =
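
A note on the output path above: the refactor hoists the memory type and memory id computation out of the ProcessTensor call so that both the plain-output and batch-output paths can reuse them. The following is a minimal standalone sketch of that device-to-memory mapping, assuming libtorch is available; the memory-type enum here is a local illustration, not the Triton definition.

#include <torch/torch.h>

#include <cstdint>
#include <iostream>

// Local stand-ins for the TRITONSERVER_MEMORY_* constants so this sketch
// builds without the Triton headers; the real code uses the Triton enums.
enum IllustrativeMemoryType { MEMORY_CPU = 0, MEMORY_GPU = 1 };

int main()
{
  // In the backend this would be one of the model's output tensors.
  torch::Tensor output_flat = torch::ones({4});

  // Same mapping as the diff: CPU tensors use memory id 0, GPU tensors use
  // their device index.
  const torch::Device tensor_device = output_flat.device();
  const auto memory_type = (tensor_device.type() == torch::kCPU)
                               ? MEMORY_CPU
                               : MEMORY_GPU;
  const int64_t memory_id =
      (tensor_device.type() == torch::kCPU) ? 0 : tensor_device.index();

  std::cout << "memory_type=" << memory_type << " memory_id=" << memory_id
            << std::endl;
  return 0;
}
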