Commit 5506203

Add support for batch_input
1 parent 588c6ac commit 5506203

File tree

1 file changed: +88 -31 lines


src/libtorch.cc

Lines changed: 88 additions & 31 deletions
@@ -25,6 +25,7 @@
 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 #include <stdint.h>
+#include <cstdint>
 #include <exception>
 #include "libtorch_utils.h"
 #include "triton/backend/backend_common.h"
@@ -502,6 +503,11 @@ class ModelInstanceState : public BackendModelInstance {
       triton::common::TritonJson::Value& sequence_batching,
       const std::string& control_kind, bool required, bool* have_control);
   TRITONSERVER_Error* ValidateInputs(const size_t expected_input_cnt);
+  void AddInputToMap(
+      NamingConvention naming_convention,
+      const std::vector<std::string> allowed_inputs,
+      const std::string &io_name,
+      const uint32_t index);
   TRITONSERVER_Error* ValidateOutputs();
   void Execute(
       std::vector<TRITONBACKEND_Response*>* responses,
@@ -538,6 +544,7 @@ class ModelInstanceState : public BackendModelInstance {
   // Map from configuration name for an input to the index of
   // that input in the model.
   std::unordered_map<std::string, int> input_index_map_;
+  uint32_t batch_input_count_ = 0;
 
   // Map from configuration name for an output to the index of
   // that output in the model.
@@ -607,6 +614,12 @@ ModelInstanceState::ModelInstanceState(
     if (model_state->ModelConfig().Find("input", &inputs)) {
       expected_input_cnt = inputs.ArraySize();
     }
+
+    triton::common::TritonJson::Value config_batch_inputs;
+    if (model_state->ModelConfig().Find("batch_input", &config_batch_inputs)) {
+      batch_input_count_ = config_batch_inputs.ArraySize();
+      expected_input_cnt += batch_input_count_;
+    }
   }
 
   // If this is a sequence model then make sure that the required
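Note: batch_input entries in the model configuration are tensors that Triton generates from properties of the batch itself (for example, per-request element counts used for ragged batching), so they arrive as extra arguments to the model's forward() call. The constructor change above therefore counts them toward expected_input_cnt. A minimal standalone sketch of that arithmetic, with hypothetical counts:

#include <cstddef>
#include <cstdint>
#include <iostream>

int main()
{
  // Hypothetical counts; in the backend these come from the "input" and
  // "batch_input" arrays of the parsed model configuration.
  size_t expected_input_cnt = 2;         // regular inputs declared in config
  const uint32_t batch_input_count = 1;  // batch_input entries in config

  // Mirrors the constructor change above: batch inputs are additional
  // forward() arguments, so they raise the expected argument count.
  expected_input_cnt += batch_input_count;
  std::cout << "forward() should accept " << expected_input_cnt
            << " inputs\n";  // prints 3
  return 0;
}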
@@ -757,6 +770,38 @@ ModelInstanceState::ValidateTypedSequenceControl(
   return nullptr;  // success
 }
 
+void ModelInstanceState::AddInputToMap(NamingConvention naming_convention, const std::vector<std::string> allowed_inputs, const std::string &io_name, const uint32_t index) {
+  std::string deliminator = "__";
+
+  if (is_dict_input_) {
+    // If dictionary, index is irrelevant but we use the map to store the
+    // input names since they are the keys for the dictionary
+    input_index_map_[io_name] = index;
+  } else {
+    switch (naming_convention) {
+      case NamingConvention::FORWARD_ARGUMENT: {
+        auto itr =
+            std::find(allowed_inputs.begin(), allowed_inputs.end(), io_name);
+        if (itr != allowed_inputs.end()) {
+          input_index_map_[io_name] =
+              std::distance(allowed_inputs.begin(), itr);
+        }
+        return;
+      }
+      case NamingConvention::NAMED_INDEX: {
+        int start_pos = io_name.find(deliminator);
+        int ip_index = std::atoi(io_name.substr(start_pos + 2).c_str());
+        input_index_map_[io_name] = ip_index;
+        return;
+      }
+      case NamingConvention::STRICT_CONFIG_ORDERING: {
+        input_index_map_[io_name] = index;
+        return;
+      }
+    }
+  }
+}
+
 TRITONSERVER_Error*
 ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
 {
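The new helper is the per-input bookkeeping lifted out of ValidateInputs (with each break becoming a return) so the batch-input loop added further down can reuse it. Under the NAMED_INDEX convention, the characters after the "__" delimiter are parsed as the tensor's positional index in the model's forward() signature. A standalone sketch of that parsing, with a hypothetical input name:

#include <cstdlib>
#include <iostream>
#include <string>

int main()
{
  const std::string io_name = "INPUT__2";  // hypothetical config name
  const std::string deliminator = "__";    // spelling matches the backend

  // Mirrors the NAMED_INDEX case above: take everything after "__" and
  // parse it as a base-10 index.
  const size_t start_pos = io_name.find(deliminator);
  const int ip_index = std::atoi(io_name.substr(start_pos + 2).c_str());
  std::cout << io_name << " -> index " << ip_index << "\n";  // prints 2
  return 0;
}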
@@ -822,8 +867,6 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
 
   triton::common::TritonJson::Value ios;
   RETURN_IF_ERROR(model_state_->ModelConfig().MemberAsArray("input", &ios));
-  std::string deliminator = "__";
-  int ip_index = 0;
 
   if (ios.ArraySize() == 0) {
     return TRITONSERVER_ErrorNew(
@@ -842,34 +885,7 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
     // Validate name
     std::string io_name;
     RETURN_IF_ERROR(io.MemberAsString("name", &io_name));
-    if (is_dict_input_) {
-      // If dictionary, index is irrelevant but we use the map to store the
-      // input names since they are the keys for the dictionary
-      input_index_map_[io_name] = i;
-    } else {
-      switch (naming_convention) {
-        case NamingConvention::FORWARD_ARGUMENT: {
-          auto itr =
-              std::find(allowed_inputs.begin(), allowed_inputs.end(), io_name);
-          if (itr != allowed_inputs.end()) {
-            input_index_map_[io_name] =
-                std::distance(allowed_inputs.begin(), itr);
-          }
-          break;
-        }
-        case NamingConvention::NAMED_INDEX: {
-          int start_pos = io_name.find(deliminator);
-          ip_index = std::atoi(io_name.substr(start_pos + 2).c_str());
-          input_index_map_[io_name] = ip_index;
-          break;
-        }
-        case NamingConvention::STRICT_CONFIG_ORDERING: {
-          input_index_map_[io_name] = i;
-          break;
-        }
-      }
-    }
-
+    AddInputToMap(naming_convention, allowed_inputs, io_name, i);
     // Validate data type
     std::string io_dtype;
     RETURN_IF_ERROR(io.MemberAsString("data_type", &io_dtype));
@@ -906,6 +922,16 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
     }
   }
 
+  triton::common::TritonJson::Value batch_inputs;
+  RETURN_IF_ERROR(model_state_->ModelConfig().MemberAsArray("batch_input", &batch_inputs));
+  size_t i = 0;
+  for (const auto& batch_input : StateForModel()->BatchInputs()) {
+    for (const auto& input_name : batch_input.TargetNames()) {
+      AddInputToMap(naming_convention, allowed_inputs, input_name, i + ios.ArraySize());
+      i++;
+    }
+  }
+
   return nullptr;  // success
 }
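The net effect of this hunk: each batch-input target name is mapped to a slot after all regular inputs, matching the i + ios.ArraySize() offset passed to AddInputToMap. A standalone sketch of the resulting index layout, using hypothetical names and the STRICT_CONFIG_ORDERING convention:

#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main()
{
  const std::vector<std::string> config_inputs = {"INPUT0", "INPUT1"};
  const std::vector<std::string> batch_targets = {"BATCH_ELEMENT_COUNT"};

  std::unordered_map<std::string, int> input_index_map;
  for (size_t i = 0; i < config_inputs.size(); i++) {
    input_index_map[config_inputs[i]] = i;  // regular inputs: 0..N-1
  }
  size_t i = 0;
  for (const auto& name : batch_targets) {
    // Batch inputs are appended after the regular inputs, as above.
    input_index_map[name] = i + config_inputs.size();
    i++;
  }
  for (const auto& kv : input_index_map) {
    std::cout << kv.first << " -> " << kv.second << "\n";
  }
  return 0;
}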

@@ -1725,7 +1751,8 @@ ModelInstanceState::SetInputTensors(
   // request as the representative for the input tensors.
   uint32_t input_count;
   RETURN_IF_ERROR(TRITONBACKEND_RequestInputCount(requests[0], &input_count));
-  input_tensors->resize(input_count);
+
+  input_tensors->resize(input_count + batch_input_count_);
   for (uint32_t input_idx = 0; input_idx < input_count; input_idx++) {
     TRITONBACKEND_Input* input;
     RETURN_IF_ERROR(
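The enlarged resize matters because the batch-input loop added below stores tensors by indexed assignment, and std::vector::operator[] past size() is undefined behavior. A minimal sketch of the pattern (not backend code, counts hypothetical):

#include <cstdint>
#include <vector>
#include <torch/torch.h>

// Sketch only: regular inputs fill [0, input_count); batch inputs land at
// indices >= input_count, which exist only because the vector was resized
// to input_count + batch_input_count up front.
void FillSlots(
    std::vector<torch::Tensor>* input_tensors, uint32_t input_count,
    uint32_t batch_input_count)
{
  input_tensors->resize(input_count + batch_input_count);
  (*input_tensors)[input_count] = torch::zeros({2});  // first batch input
}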
@@ -1828,6 +1855,36 @@ ModelInstanceState::SetInputTensors(
     }
   }
 
+  for (const auto& batch_input : StateForModel()->BatchInputs()) {
+    std::vector<int64_t> shape;
+    collector->BatchInputShape(batch_input, &shape);
+
+    for (const auto& input_name : batch_input.TargetNames()) {
+      input_names->emplace_back(input_name.c_str());
+
+      const char* dst_buffer;
+      size_t dst_buffer_byte_size;
+      TRITONSERVER_MemoryType dst_memory_type;
+      int64_t dst_memory_type_id;
+
+      // Batch inputs are always created on CPU
+      RESPOND_ALL_AND_SET_NULL_IF_ERROR(
+          (*responses), responses->size(),
+          collector->ProcessBatchInput(
+              batch_input, nullptr, 0, {{TRITONSERVER_MEMORY_CPU, 0}},
+              &dst_buffer, &dst_buffer_byte_size, &dst_memory_type,
+              &dst_memory_type_id));
+
+      const auto torch_dtype = ConvertDataTypeToTorchType(batch_input.DataType());
+      torch::TensorOptions options{torch_dtype.second};
+      auto updated_options = options.device(torch::kCPU);
+
+      torch::Tensor input_tensor = torch::from_blob(
+          const_cast<char*>(dst_buffer), shape, updated_options);
+      (*input_tensors)[input_index_map_[input_name]] = input_tensor;
+    }
+  }
+
   // Finalize...
   *cuda_copy |= collector->Finalize();
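Here torch::from_blob wraps the CPU buffer produced by ProcessBatchInput without copying it, so the resulting tensor aliases backend-owned memory and must not outlive that buffer. A standalone sketch of the same zero-copy pattern (buffer contents made up):

#include <cstdint>
#include <iostream>
#include <vector>
#include <torch/torch.h>

int main()
{
  // Stand-in for the buffer ProcessBatchInput would fill, e.g. the
  // per-request element counts of a ragged batch of two requests.
  std::vector<float> buffer = {4.0f, 3.0f};
  const std::vector<int64_t> shape = {2};

  auto options = torch::TensorOptions{torch::kFloat32}.device(torch::kCPU);
  // No copy is made: the tensor aliases buffer and becomes invalid once
  // buffer is destroyed.
  torch::Tensor t = torch::from_blob(buffer.data(), shape, options);
  std::cout << t << "\n";  // 1-D float tensor [4, 3]
  return 0;
}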
