Commit 8c2bcc8

Add support for batch_input
1 parent 588c6ac commit 8c2bcc8

File tree

2 files changed: +111 -44 lines changed


src/libtorch.cc

Lines changed: 110 additions & 43 deletions
@@ -25,7 +25,10 @@
 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

 #include <stdint.h>
+
+#include <cstdint>
 #include <exception>
+
 #include "libtorch_utils.h"
 #include "triton/backend/backend_common.h"
 #include "triton/backend/backend_input_collector.h"
@@ -502,6 +505,10 @@ class ModelInstanceState : public BackendModelInstance {
       triton::common::TritonJson::Value& sequence_batching,
       const std::string& control_kind, bool required, bool* have_control);
   TRITONSERVER_Error* ValidateInputs(const size_t expected_input_cnt);
+  void AddInputToMap(
+      NamingConvention naming_convention,
+      const std::vector<std::string> allowed_inputs, const std::string& io_name,
+      const uint32_t index);
   TRITONSERVER_Error* ValidateOutputs();
   void Execute(
       std::vector<TRITONBACKEND_Response*>* responses,
@@ -538,6 +545,7 @@ class ModelInstanceState : public BackendModelInstance {
   // Map from configuration name for an input to the index of
   // that input in the model.
   std::unordered_map<std::string, int> input_index_map_;
+  uint32_t batch_input_count_ = 0;

   // Map from configuration name for an output to the index of
   // that output in the model.
@@ -607,6 +615,12 @@ ModelInstanceState::ModelInstanceState(
     if (model_state->ModelConfig().Find("input", &inputs)) {
       expected_input_cnt = inputs.ArraySize();
     }
+
+    triton::common::TritonJson::Value config_batch_inputs;
+    if (model_state->ModelConfig().Find("batch_input", &config_batch_inputs)) {
+      batch_input_count_ = config_batch_inputs.ArraySize();
+      expected_input_cnt += batch_input_count_;
+    }
   }

   // If this is a sequence model then make sure that the required
@@ -757,6 +771,43 @@ ModelInstanceState::ValidateTypedSequenceControl(
   return nullptr; // success
 }

+void
+ModelInstanceState::AddInputToMap(
+    NamingConvention naming_convention,
+    const std::vector<std::string> allowed_inputs, const std::string& io_name,
+    const uint32_t index)
+{
+  std::string deliminator = "__";
+
+  if (is_dict_input_) {
+    // If dictionary, index is irrelevant but we use the map to store the
+    // input names since they are the keys for the dictionary
+    input_index_map_[io_name] = index;
+  } else {
+    switch (naming_convention) {
+      case NamingConvention::FORWARD_ARGUMENT: {
+        auto itr =
+            std::find(allowed_inputs.begin(), allowed_inputs.end(), io_name);
+        if (itr != allowed_inputs.end()) {
+          input_index_map_[io_name] =
+              std::distance(allowed_inputs.begin(), itr);
+        }
+        return;
+      }
+      case NamingConvention::NAMED_INDEX: {
+        int start_pos = io_name.find(deliminator);
+        int ip_index = std::atoi(io_name.substr(start_pos + 2).c_str());
+        input_index_map_[io_name] = ip_index;
+        return;
+      }
+      case NamingConvention::STRICT_CONFIG_ORDERING: {
+        input_index_map_[io_name] = index;
+        return;
+      }
+    }
+  }
+}
+
 TRITONSERVER_Error*
 ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
 {
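
Aside: a minimal, standalone sketch of the NAMED_INDEX case handled by AddInputToMap above. The input name "INPUT__2" is an assumed example, not taken from the commit; the point is only that the text after the "__" delimiter becomes the tensor index.

// Hypothetical example of the NAMED_INDEX parsing: "INPUT__2" -> index 2.
// Uses only the C++ standard library.
#include <cstdlib>
#include <iostream>
#include <string>

int main()
{
  const std::string deliminator = "__";
  const std::string io_name = "INPUT__2";  // assumed example name
  const size_t start_pos = io_name.find(deliminator);
  const int ip_index = std::atoi(io_name.substr(start_pos + 2).c_str());
  std::cout << io_name << " -> index " << ip_index << std::endl;  // prints 2
  return 0;
}
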
@@ -822,8 +873,6 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)

   triton::common::TritonJson::Value ios;
   RETURN_IF_ERROR(model_state_->ModelConfig().MemberAsArray("input", &ios));
-  std::string deliminator = "__";
-  int ip_index = 0;

   if (ios.ArraySize() == 0) {
     return TRITONSERVER_ErrorNew(
@@ -842,34 +891,7 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
     // Validate name
     std::string io_name;
     RETURN_IF_ERROR(io.MemberAsString("name", &io_name));
-    if (is_dict_input_) {
-      // If dictionary, index is irrelevant but we use the map to store the
-      // input names since they are the keys for the dictionary
-      input_index_map_[io_name] = i;
-    } else {
-      switch (naming_convention) {
-        case NamingConvention::FORWARD_ARGUMENT: {
-          auto itr =
-              std::find(allowed_inputs.begin(), allowed_inputs.end(), io_name);
-          if (itr != allowed_inputs.end()) {
-            input_index_map_[io_name] =
-                std::distance(allowed_inputs.begin(), itr);
-          }
-          break;
-        }
-        case NamingConvention::NAMED_INDEX: {
-          int start_pos = io_name.find(deliminator);
-          ip_index = std::atoi(io_name.substr(start_pos + 2).c_str());
-          input_index_map_[io_name] = ip_index;
-          break;
-        }
-        case NamingConvention::STRICT_CONFIG_ORDERING: {
-          input_index_map_[io_name] = i;
-          break;
-        }
-      }
-    }
-
+    AddInputToMap(naming_convention, allowed_inputs, io_name, i);
     // Validate data type
     std::string io_dtype;
     RETURN_IF_ERROR(io.MemberAsString("data_type", &io_dtype));
@@ -906,6 +928,18 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
     }
   }

+  triton::common::TritonJson::Value batch_inputs;
+  RETURN_IF_ERROR(
+      model_state_->ModelConfig().MemberAsArray("batch_input", &batch_inputs));
+  size_t i = 0;
+  for (const auto& batch_input : StateForModel()->BatchInputs()) {
+    for (const auto& input_name : batch_input.TargetNames()) {
+      AddInputToMap(
+          naming_convention, allowed_inputs, input_name, i + ios.ArraySize());
+      i++;
+    }
+  }
+
   return nullptr; // success
 }
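
Aside: a minimal sketch of how the index map ends up laid out after the hunk above. The input and target names are assumed examples, not taken from the commit; the point is that batch-input target names are appended after the regular config inputs, mirroring the i + ios.ArraySize() computation.

// Hypothetical illustration of the resulting input_index_map_ layout.
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main()
{
  std::unordered_map<std::string, int> input_index_map;
  const std::vector<std::string> config_inputs = {"INPUT0", "INPUT1"};
  const std::vector<std::string> batch_input_targets = {"BATCH_ELEMENT_COUNT"};

  // Regular inputs keep their position in the "input" config section
  // (the STRICT_CONFIG_ORDERING case).
  for (size_t idx = 0; idx < config_inputs.size(); idx++) {
    input_index_map[config_inputs[idx]] = static_cast<int>(idx);
  }

  // Batch-input target names are placed after the regular inputs.
  size_t i = 0;
  for (const auto& name : batch_input_targets) {
    input_index_map[name] = static_cast<int>(i + config_inputs.size());
    i++;
  }

  for (const auto& kv : input_index_map) {
    std::cout << kv.first << " -> " << kv.second << std::endl;
  }
  return 0;
}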

@@ -1312,12 +1346,12 @@ ModelInstanceState::Execute(
       torch::jit::overrideCanFuseOnCPU(false);
       torch::jit::overrideCanFuseOnGPU(false);
       torch::jit::setTensorExprFuserEnabled(false);
-    torch::jit::fuser::cuda::setEnabled(true);
+      torch::jit::fuser::cuda::setEnabled(true);
     } else {
       torch::jit::overrideCanFuseOnCPU(true);
       torch::jit::overrideCanFuseOnGPU(true);
       torch::jit::setTensorExprFuserEnabled(true);
-    torch::jit::fuser::cuda::setEnabled(false);
+      torch::jit::fuser::cuda::setEnabled(false);
     }
   }

@@ -1725,7 +1759,8 @@ ModelInstanceState::SetInputTensors(
   // request as the representative for the input tensors.
   uint32_t input_count;
   RETURN_IF_ERROR(TRITONBACKEND_RequestInputCount(requests[0], &input_count));
-  input_tensors->resize(input_count);
+
+  input_tensors->resize(input_count + batch_input_count_);
   for (uint32_t input_idx = 0; input_idx < input_count; input_idx++) {
     TRITONBACKEND_Input* input;
     RETURN_IF_ERROR(
@@ -1761,9 +1796,9 @@

         batchn_shape[0] += GetElementCount(input_shape, input_dims_count);
       }
-    }
-    else {
-      batchn_shape = std::vector<int64_t>(input_shape, input_shape + input_dims_count);
+    } else {
+      batchn_shape =
+          std::vector<int64_t>(input_shape, input_shape + input_dims_count);
       if (supports_batching_) {
         batchn_shape[0] = total_batch_size;
       }
@@ -1828,6 +1863,36 @@ ModelInstanceState::SetInputTensors(
     }
   }

+  for (const auto& batch_input : StateForModel()->BatchInputs()) {
+    std::vector<int64_t> shape;
+    collector->BatchInputShape(batch_input, &shape);
+
+    for (const auto& input_name : batch_input.TargetNames()) {
+      input_names->emplace_back(input_name.c_str());
+
+      const char* dst_buffer;
+      size_t dst_buffer_byte_size;
+      TRITONSERVER_MemoryType dst_memory_type;
+      int64_t dst_memory_type_id;
+
+      // Batch inputs are always created on CPU
+      RESPOND_ALL_AND_SET_NULL_IF_ERROR(
+          (*responses), responses->size(),
+          collector->ProcessBatchInput(
+              batch_input, nullptr, 0, {{TRITONSERVER_MEMORY_CPU, 0}},
+              &dst_buffer, &dst_buffer_byte_size, &dst_memory_type,
+              &dst_memory_type_id));
+
+      const auto torch_dtype =
+          ConvertDataTypeToTorchType(batch_input.DataType());
+
+      torch::Tensor input_tensor = torch::from_blob(
+          const_cast<char*>(dst_buffer), shape,
+          updated_options.dtype(torch_dtype.second));
+      (*input_tensors)[input_index_map_[input_name]] = input_tensor;
+    }
+  }
+
   // Finalize...
   *cuda_copy |= collector->Finalize();
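
Aside: a minimal sketch of the torch::from_blob pattern used for the batch-input buffers above. It wraps an existing CPU buffer in a tensor without copying, so the buffer must stay alive while the tensor is in use. The buffer contents here are assumed example data; building this requires libtorch.

// Hypothetical, self-contained from_blob example (link against libtorch).
#include <torch/torch.h>

#include <iostream>
#include <vector>

int main()
{
  std::vector<float> buffer = {1.f, 2.f, 3.f, 4.f};  // stands in for dst_buffer
  auto options = torch::TensorOptions().dtype(torch::kFloat32);

  // The tensor aliases 'buffer'; no copy is made.
  torch::Tensor t = torch::from_blob(buffer.data(), {2, 2}, options);
  std::cout << t << std::endl;
  return 0;
}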

@@ -1887,9 +1952,11 @@ ModelInstanceState::ReadOutputTensors(

       // Output tensors may not reside on the same device as model
       torch::Device tensor_device = output_flat.device();
-      const auto memory_type = (tensor_device.type() == torch::kCPU) ? TRITONSERVER_MEMORY_CPU
-                                                                     : TRITONSERVER_MEMORY_GPU;
-      const auto memory_id = (tensor_device.type() == torch::kCPU) ? 0 : tensor_device.index();
+      const auto memory_type = (tensor_device.type() == torch::kCPU)
+                                   ? TRITONSERVER_MEMORY_CPU
+                                   : TRITONSERVER_MEMORY_GPU;
+      const auto memory_id =
+          (tensor_device.type() == torch::kCPU) ? 0 : tensor_device.index();

       // Batch output doesn't support string data type yet, as it is not trivial
       // to parse string output
@@ -1906,16 +1973,16 @@
         return TRITONSERVER_ErrorNew(
             TRITONSERVER_ERROR_INVALID_ARG,
             (std::string("output '") + name +
-             "' is a scalar which is not supported.")
+                "' is a scalar which is not supported.")
                 .c_str());
       }

       responder.ProcessTensor(
-          name, output_dtype, batchn_shape, output_buffer,
-          memory_type, memory_id);
+          name, output_dtype, batchn_shape, output_buffer, memory_type,
+          memory_id);
     } else {
       responder.ProcessBatchOutput(
-          name, *batch_output, output_buffer, memory_type, memory_id);
+          name, *batch_output, output_buffer, memory_type, memory_id);
     }
   } else if (output_tensors[op_index].isList()) {
     // Custom handling for string/bytes tensor...

src/libtorch_utils.cc

Lines changed: 1 addition & 1 deletion
@@ -152,7 +152,7 @@ ParseParameter(
 #ifdef TRITON_ENABLE_GPU
 TRITONSERVER_Error*
 ConvertCUDAStatusToTritonError(
-    cudaError_t cuda_error,TRITONSERVER_Error_Code code, const char* msg)
+    cudaError_t cuda_error, TRITONSERVER_Error_Code code, const char* msg)
 {
   if (cuda_error != cudaSuccess) {
     return TRITONSERVER_ErrorNew(
