Commit 4a7753d

Add support for batch_input
1 parent 588c6ac commit 4a7753d
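
For context: `batch_input` is the Triton model-configuration feature (used together with ragged batching) that asks the server to synthesize extra input tensors describing the batch itself, such as per-request element counts. This commit teaches the PyTorch backend to count those tensors, map them to input slots, and materialize them at inference time. A minimal `config.pbtxt` sketch of the feature — tensor names, dims, and the chosen `kind` are illustrative, not taken from this commit:

```
input [
  {
    name: "INPUT__0"
    data_type: TYPE_FP32
    dims: [ -1 ]
    allow_ragged_batch: true
  }
]
batch_input [
  {
    kind: BATCH_ELEMENT_COUNT
    target_name: "INPUT__1"
    data_type: TYPE_INT32
    source_input: [ "INPUT__0" ]
  }
]
```

With a configuration like this, the model's `forward()` receives `INPUT__0` plus a synthesized `INPUT__1` holding, for each request in the batch, the element count of that request's `INPUT__0`.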

File tree

1 file changed: +90 -31 lines changed

src/libtorch.cc

Lines changed: 90 additions & 31 deletions
@@ -25,6 +25,7 @@
 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 #include <stdint.h>
+#include <cstdint>
 #include <exception>
 #include "libtorch_utils.h"
 #include "triton/backend/backend_common.h"
@@ -502,6 +503,11 @@ class ModelInstanceState : public BackendModelInstance {
       triton::common::TritonJson::Value& sequence_batching,
       const std::string& control_kind, bool required, bool* have_control);
   TRITONSERVER_Error* ValidateInputs(const size_t expected_input_cnt);
+  void AddInputToMap(
+      NamingConvention naming_convention,
+      const std::vector<std::string> allowed_inputs,
+      const std::string &io_name,
+      const uint32_t index);
   TRITONSERVER_Error* ValidateOutputs();
   void Execute(
       std::vector<TRITONBACKEND_Response*>* responses,
@@ -538,6 +544,7 @@ class ModelInstanceState : public BackendModelInstance {
   // Map from configuration name for an input to the index of
   // that input in the model.
   std::unordered_map<std::string, int> input_index_map_;
+  uint32_t batch_input_count_;
 
   // Map from configuration name for an output to the index of
   // that output in the model.
@@ -607,6 +614,14 @@ ModelInstanceState::ModelInstanceState(
   if (model_state->ModelConfig().Find("input", &inputs)) {
     expected_input_cnt = inputs.ArraySize();
   }
+
+  triton::common::TritonJson::Value config_batch_inputs;
+  if (model_state->ModelConfig().Find("batch_input", &config_batch_inputs)) {
+    batch_input_count_ = config_batch_inputs.ArraySize();
+    expected_input_cnt += batch_input_count_;
+  } else {
+    batch_input_count_ = 0;
+  }
 }
 
 // If this is a sequence model then make sure that the required
@@ -757,6 +772,38 @@ ModelInstanceState::ValidateTypedSequenceControl(
   return nullptr;  // success
 }
 
+void ModelInstanceState::AddInputToMap(NamingConvention naming_convention, const std::vector<std::string> allowed_inputs, const std::string &io_name, const uint32_t index) {
+  std::string deliminator = "__";
+
+  if (is_dict_input_) {
+    // If dictionary, index is irrelevant but we use the map to store the
+    // input names since they are the keys for the dictionary
+    input_index_map_[io_name] = index;
+  } else {
+    switch (naming_convention) {
+      case NamingConvention::FORWARD_ARGUMENT: {
+        auto itr =
+            std::find(allowed_inputs.begin(), allowed_inputs.end(), io_name);
+        if (itr != allowed_inputs.end()) {
+          input_index_map_[io_name] =
+              std::distance(allowed_inputs.begin(), itr);
+        }
+        return;
+      }
+      case NamingConvention::NAMED_INDEX: {
+        int start_pos = io_name.find(deliminator);
+        int ip_index = std::atoi(io_name.substr(start_pos + 2).c_str());
+        input_index_map_[io_name] = ip_index;
+        return;
+      }
+      case NamingConvention::STRICT_CONFIG_ORDERING: {
+        input_index_map_[io_name] = index;
+        return;
+      }
+    }
+  }
+}
+
 TRITONSERVER_Error*
 ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
 {
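
The new `AddInputToMap` helper centralizes the slot bookkeeping that previously lived inline in `ValidateInputs`, so the batch-input handling added further down can reuse it. Under the `NAMED_INDEX` convention the slot comes from the tensor name itself: everything after the `__` delimiter is parsed as an integer. A standalone sketch of that parse — `ParseNamedIndex` is a hypothetical helper, not part of the commit, and it assumes the caller has already verified the name contains `__`:

```cpp
#include <cstdlib>
#include <iostream>
#include <string>

// Mirrors the NAMED_INDEX branch above: "INPUT__2" -> slot 2.
// Assumes io_name was validated to contain the "__" delimiter.
int ParseNamedIndex(const std::string& io_name) {
  const std::string deliminator = "__";
  const size_t start_pos = io_name.find(deliminator);
  return std::atoi(io_name.substr(start_pos + 2).c_str());
}

int main() {
  std::cout << ParseNamedIndex("INPUT__0") << "\n";   // prints 0
  std::cout << ParseNamedIndex("OUTPUT__2") << "\n";  // prints 2
}
```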
@@ -822,8 +869,6 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
 
   triton::common::TritonJson::Value ios;
   RETURN_IF_ERROR(model_state_->ModelConfig().MemberAsArray("input", &ios));
-  std::string deliminator = "__";
-  int ip_index = 0;
 
   if (ios.ArraySize() == 0) {
     return TRITONSERVER_ErrorNew(
@@ -842,34 +887,7 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
     // Validate name
     std::string io_name;
     RETURN_IF_ERROR(io.MemberAsString("name", &io_name));
-    if (is_dict_input_) {
-      // If dictionary, index is irrelevant but we use the map to store the
-      // input names since they are the keys for the dictionary
-      input_index_map_[io_name] = i;
-    } else {
-      switch (naming_convention) {
-        case NamingConvention::FORWARD_ARGUMENT: {
-          auto itr =
-              std::find(allowed_inputs.begin(), allowed_inputs.end(), io_name);
-          if (itr != allowed_inputs.end()) {
-            input_index_map_[io_name] =
-                std::distance(allowed_inputs.begin(), itr);
-          }
-          break;
-        }
-        case NamingConvention::NAMED_INDEX: {
-          int start_pos = io_name.find(deliminator);
-          ip_index = std::atoi(io_name.substr(start_pos + 2).c_str());
-          input_index_map_[io_name] = ip_index;
-          break;
-        }
-        case NamingConvention::STRICT_CONFIG_ORDERING: {
-          input_index_map_[io_name] = i;
-          break;
-        }
-      }
-    }
-
+    AddInputToMap(naming_convention, allowed_inputs, io_name, i);
     // Validate data type
     std::string io_dtype;
     RETURN_IF_ERROR(io.MemberAsString("data_type", &io_dtype));
@@ -906,6 +924,16 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
     }
   }
 
+  triton::common::TritonJson::Value batch_inputs;
+  RETURN_IF_ERROR(model_state_->ModelConfig().MemberAsArray("batch_input", &batch_inputs));
+  size_t i = 0;
+  for (const auto& batch_input : StateForModel()->BatchInputs()) {
+    for (const auto& input_name : batch_input.TargetNames()) {
+      AddInputToMap(naming_convention, allowed_inputs, input_name, i + ios.ArraySize());
+      i++;
+    }
+  }
+
   return nullptr;  // success
 }
 
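Because `i` is offset by `ios.ArraySize()`, batch-input target tensors land in the slots immediately after the config-declared inputs. That passed-in index is what reaches the map for dictionary inputs and `STRICT_CONFIG_ORDERING`; under `NAMED_INDEX` the slot is still derived from the name. A sketch of the resulting layout for two regular inputs and one batch-input target under `STRICT_CONFIG_ORDERING` (all tensor names hypothetical):

```cpp
#include <iostream>
#include <string>
#include <unordered_map>

int main() {
  // Slots 0 .. ios.ArraySize()-1 hold the inputs declared in config.pbtxt;
  // batch-input targets follow at i + ios.ArraySize().
  std::unordered_map<std::string, int> input_index_map = {
      {"token_ids", 0},            // config-declared input
      {"attention_mask", 1},       // config-declared input
      {"batch_element_count", 2},  // batch-input target: 0 + ArraySize() of 2
  };
  for (const auto& entry : input_index_map) {
    std::cout << entry.first << " -> slot " << entry.second << "\n";
  }
}
```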
@@ -1725,7 +1753,8 @@ ModelInstanceState::SetInputTensors(
   // request as the representative for the input tensors.
   uint32_t input_count;
   RETURN_IF_ERROR(TRITONBACKEND_RequestInputCount(requests[0], &input_count));
-  input_tensors->resize(input_count);
+
+  input_tensors->resize(input_count + batch_input_count_);
   for (uint32_t input_idx = 0; input_idx < input_count; input_idx++) {
     TRITONBACKEND_Input* input;
     RETURN_IF_ERROR(
@@ -1828,6 +1857,36 @@ ModelInstanceState::SetInputTensors(
     }
   }
 
+  for (const auto& batch_input : StateForModel()->BatchInputs()) {
+    std::vector<int64_t> shape;
+    collector->BatchInputShape(batch_input, &shape);
+
+    for (const auto& input_name : batch_input.TargetNames()) {
+      input_names->emplace_back(input_name.c_str());
+
+      const char* dst_buffer;
+      size_t dst_buffer_byte_size;
+      TRITONSERVER_MemoryType dst_memory_type;
+      int64_t dst_memory_type_id;
+
+      // Batch inputs are always created on CPU
+      RESPOND_ALL_AND_SET_NULL_IF_ERROR(
+          (*responses), responses->size(),
+          collector->ProcessBatchInput(
+              batch_input, nullptr, 0, {{TRITONSERVER_MEMORY_CPU, 0}},
+              &dst_buffer, &dst_buffer_byte_size, &dst_memory_type,
+              &dst_memory_type_id));
+
+      const auto torch_dtype = ConvertDataTypeToTorchType(batch_input.DataType());
+      torch::TensorOptions options{torch_dtype.second};
+      auto updated_options = options.device(torch::kCPU);
+
+      torch::Tensor input_tensor = torch::from_blob(
+          const_cast<char*>(dst_buffer), shape, updated_options);
+      (*input_tensors)[input_index_map_[input_name]] = input_tensor;
+    }
+  }
+
   // Finalize...
   *cuda_copy |= collector->Finalize();
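`ProcessBatchInput` writes the synthesized values into a CPU buffer owned by the collector, and `torch::from_blob` then wraps that buffer without copying: the resulting tensor does not own its memory, so the buffer must remain valid while the model executes. A minimal standalone sketch of the same pattern using the LibTorch C++ API — the buffer contents and shape here are illustrative:

```cpp
#include <torch/torch.h>

#include <iostream>
#include <vector>

int main() {
  // Stand-in for the collector-owned CPU buffer of batch-input values.
  std::vector<float> buffer = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
  std::vector<int64_t> shape = {2, 3};

  // from_blob wraps the existing memory without copying or taking
  // ownership; `buffer` must outlive `t`.
  torch::TensorOptions options{torch::kFloat32};
  torch::Tensor t =
      torch::from_blob(buffer.data(), shape, options.device(torch::kCPU));

  std::cout << t << "\n";  // 2x3 tensor viewing buffer's data
}
```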
