 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

 #include <stdint.h>
+
+#include <cstdint>
 #include <exception>
+
 #include "libtorch_utils.h"
 #include "triton/backend/backend_common.h"
 #include "triton/backend/backend_input_collector.h"
@@ -502,6 +505,10 @@ class ModelInstanceState : public BackendModelInstance {
       triton::common::TritonJson::Value& sequence_batching,
       const std::string& control_kind, bool required, bool* have_control);
   TRITONSERVER_Error* ValidateInputs(const size_t expected_input_cnt);
+  void AddInputToMap(
+      NamingConvention naming_convention,
+      const std::vector<std::string> allowed_inputs, const std::string& io_name,
+      const uint32_t index);
   TRITONSERVER_Error* ValidateOutputs();
   void Execute(
       std::vector<TRITONBACKEND_Response*>* responses,
@@ -538,6 +545,7 @@ class ModelInstanceState : public BackendModelInstance {
   // Map from configuration name for an input to the index of
   // that input in the model.
   std::unordered_map<std::string, int> input_index_map_;
+  uint32_t batch_input_count_ = 0;

   // Map from configuration name for an output to the index of
   // that output in the model.
@@ -607,6 +615,12 @@ ModelInstanceState::ModelInstanceState(
     if (model_state->ModelConfig().Find("input", &inputs)) {
       expected_input_cnt = inputs.ArraySize();
     }
+
+    triton::common::TritonJson::Value config_batch_inputs;
+    if (model_state->ModelConfig().Find("batch_input", &config_batch_inputs)) {
+      batch_input_count_ = config_batch_inputs.ArraySize();
+      expected_input_cnt += batch_input_count_;
+    }
   }

   // If this is a sequence model then make sure that the required
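batch_input here refers to Triton's batch-input feature, where the server synthesizes extra tensors (element counts, item shapes, and so on) describing the dynamically formed batch. A minimal config.pbtxt entry might look like the following sketch (names and field values are illustrative, not taken from this commit):

    batch_input [
      {
        kind: BATCH_ELEMENT_COUNT
        target_name: "INPUT__1"
        data_type: TYPE_FP32
        source_input: "INPUT__0"
      }
    ]

Each target_name counts toward expected_input_cnt, which is why the constructor adds batch_input_count_ on top of the regular input entries.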
@@ -757,6 +771,43 @@ ModelInstanceState::ValidateTypedSequenceControl(
   return nullptr;  // success
 }

+void
+ModelInstanceState::AddInputToMap(
+    NamingConvention naming_convention,
+    const std::vector<std::string> allowed_inputs, const std::string& io_name,
+    const uint32_t index)
+{
+  std::string deliminator = "__";
+
+  if (is_dict_input_) {
+    // If dictionary, index is irrelevant but we use the map to store the
+    // input names since they are the keys for the dictionary
+    input_index_map_[io_name] = index;
+  } else {
+    switch (naming_convention) {
+      case NamingConvention::FORWARD_ARGUMENT: {
+        auto itr =
+            std::find(allowed_inputs.begin(), allowed_inputs.end(), io_name);
+        if (itr != allowed_inputs.end()) {
+          input_index_map_[io_name] =
+              std::distance(allowed_inputs.begin(), itr);
+        }
+        return;
+      }
+      case NamingConvention::NAMED_INDEX: {
+        int start_pos = io_name.find(deliminator);
+        int ip_index = std::atoi(io_name.substr(start_pos + 2).c_str());
+        input_index_map_[io_name] = ip_index;
+        return;
+      }
+      case NamingConvention::STRICT_CONFIG_ORDERING: {
+        input_index_map_[io_name] = index;
+        return;
+      }
+    }
+  }
+}
+
 TRITONSERVER_Error*
 ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
 {
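The NAMED_INDEX branch above derives a tensor's position in the TorchScript forward() signature from the name suffix after the "__" delimiter. A standalone sketch of that parsing, using a hypothetical tensor name (not part of the backend):

    #include <cstdlib>
    #include <iostream>
    #include <string>

    int main() {
      std::string io_name = "INPUT__2";       // hypothetical tensor name
      size_t start_pos = io_name.find("__");  // locate the delimiter
      // Everything after "__" is read as the positional index.
      int ip_index = std::atoi(io_name.substr(start_pos + 2).c_str());
      std::cout << io_name << " -> index " << ip_index << "\n";  // prints 2
    }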
@@ -822,8 +873,6 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)

   triton::common::TritonJson::Value ios;
   RETURN_IF_ERROR(model_state_->ModelConfig().MemberAsArray("input", &ios));
-  std::string deliminator = "__";
-  int ip_index = 0;

   if (ios.ArraySize() == 0) {
     return TRITONSERVER_ErrorNew(
@@ -842,34 +891,7 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
     // Validate name
     std::string io_name;
     RETURN_IF_ERROR(io.MemberAsString("name", &io_name));
-    if (is_dict_input_) {
-      // If dictionary, index is irrelevant but we use the map to store the
-      // input names since they are the keys for the dictionary
-      input_index_map_[io_name] = i;
-    } else {
-      switch (naming_convention) {
-        case NamingConvention::FORWARD_ARGUMENT: {
-          auto itr =
-              std::find(allowed_inputs.begin(), allowed_inputs.end(), io_name);
-          if (itr != allowed_inputs.end()) {
-            input_index_map_[io_name] =
-                std::distance(allowed_inputs.begin(), itr);
-          }
-          break;
-        }
-        case NamingConvention::NAMED_INDEX: {
-          int start_pos = io_name.find(deliminator);
-          ip_index = std::atoi(io_name.substr(start_pos + 2).c_str());
-          input_index_map_[io_name] = ip_index;
-          break;
-        }
-        case NamingConvention::STRICT_CONFIG_ORDERING: {
-          input_index_map_[io_name] = i;
-          break;
-        }
-      }
-    }
-
+    AddInputToMap(naming_convention, allowed_inputs, io_name, i);
     // Validate data type
     std::string io_dtype;
     RETURN_IF_ERROR(io.MemberAsString("data_type", &io_dtype));
@@ -906,6 +928,18 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
     }
   }

+  triton::common::TritonJson::Value batch_inputs;
+  RETURN_IF_ERROR(
+      model_state_->ModelConfig().MemberAsArray("batch_input", &batch_inputs));
+  size_t i = 0;
+  for (const auto& batch_input : StateForModel()->BatchInputs()) {
+    for (const auto& input_name : batch_input.TargetNames()) {
+      AddInputToMap(
+          naming_convention, allowed_inputs, input_name, i + ios.ArraySize());
+      i++;
+    }
+  }
+
   return nullptr;  // success
 }

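With this change, regular inputs keep the indices assigned from the input array, while each batch-input target is appended after them via i + ios.ArraySize(). Under STRICT_CONFIG_ORDERING (or dictionary inputs), a hypothetical model with two regular inputs and one batch-input target would end up with:

    input_index_map_["INPUT__0"] == 0
    input_index_map_["INPUT__1"] == 1
    input_index_map_["SAMPLE_COUNT"] == 2  // 0 + ios.ArraySize()

(names assumed for illustration). The matching resize in SetInputTensors below reserves the extra tensor slots.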
@@ -1725,7 +1759,8 @@ ModelInstanceState::SetInputTensors(
   // request as the representative for the input tensors.
   uint32_t input_count;
   RETURN_IF_ERROR(TRITONBACKEND_RequestInputCount(requests[0], &input_count));
-  input_tensors->resize(input_count);
+
+  input_tensors->resize(input_count + batch_input_count_);
   for (uint32_t input_idx = 0; input_idx < input_count; input_idx++) {
     TRITONBACKEND_Input* input;
     RETURN_IF_ERROR(
@@ -1828,6 +1863,36 @@ ModelInstanceState::SetInputTensors(
     }
   }

+  for (const auto& batch_input : StateForModel()->BatchInputs()) {
+    std::vector<int64_t> shape;
+    collector->BatchInputShape(batch_input, &shape);
+
+    for (const auto& input_name : batch_input.TargetNames()) {
+      input_names->emplace_back(input_name.c_str());
+
+      const char* dst_buffer;
+      size_t dst_buffer_byte_size;
+      TRITONSERVER_MemoryType dst_memory_type;
+      int64_t dst_memory_type_id;
+
+      // Batch inputs are always created on CPU
+      RESPOND_ALL_AND_SET_NULL_IF_ERROR(
+          (*responses), responses->size(),
+          collector->ProcessBatchInput(
+              batch_input, nullptr, 0, {{TRITONSERVER_MEMORY_CPU, 0}},
+              &dst_buffer, &dst_buffer_byte_size, &dst_memory_type,
+              &dst_memory_type_id));
+
+      const auto torch_dtype =
+          ConvertDataTypeToTorchType(batch_input.DataType());
+
+      torch::Tensor input_tensor = torch::from_blob(
+          const_cast<char*>(dst_buffer), shape,
+          updated_options.dtype(torch_dtype.second));
+      (*input_tensors)[input_index_map_[input_name]] = input_tensor;
+    }
+  }
+
   // Finalize...
   *cuda_copy |= collector->Finalize();

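The torch::from_blob() call above wraps the collector's CPU buffer without copying or taking ownership, so the buffer must stay alive until inference completes (here the collector owns it). A standalone sketch of the pattern, with assumed values independent of the backend:

    #include <torch/torch.h>

    #include <iostream>
    #include <vector>

    int main() {
      // e.g. per-request element counts produced for a batch input
      std::vector<int32_t> buffer = {4, 2, 7};
      torch::Tensor t = torch::from_blob(
          buffer.data(), {3}, torch::TensorOptions().dtype(torch::kInt32));
      std::cout << t << "\n";  // no copy: mutating buffer also mutates t
    }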