@@ -25,6 +25,7 @@
 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

 #include <stdint.h>
+#include <cstdint>
 #include <exception>

 #include "libtorch_utils.h"
 #include "triton/backend/backend_common.h"
@@ -502,6 +503,11 @@ class ModelInstanceState : public BackendModelInstance {
       triton::common::TritonJson::Value& sequence_batching,
       const std::string& control_kind, bool required, bool* have_control);
   TRITONSERVER_Error* ValidateInputs(const size_t expected_input_cnt);
+  void AddInputToMap(
+      NamingConvention naming_convention,
+      const std::vector<std::string> allowed_inputs,
+      const std::string& io_name,
+      const uint32_t index);
   TRITONSERVER_Error* ValidateOutputs();
   void Execute(
       std::vector<TRITONBACKEND_Response*>* responses,
@@ -538,6 +544,7 @@ class ModelInstanceState : public BackendModelInstance {
   // Map from configuration name for an input to the index of
   // that input in the model.
   std::unordered_map<std::string, int> input_index_map_;
+  uint32_t batch_input_count_ = 0;

   // Map from configuration name for an output to the index of
   // that output in the model.
@@ -607,6 +614,12 @@ ModelInstanceState::ModelInstanceState(
     if (model_state->ModelConfig().Find("input", &inputs)) {
       expected_input_cnt = inputs.ArraySize();
     }
+
+    triton::common::TritonJson::Value config_batch_inputs;
+    if (model_state->ModelConfig().Find("batch_input", &config_batch_inputs)) {
+      batch_input_count_ = config_batch_inputs.ArraySize();
+      expected_input_cnt += batch_input_count_;
+    }
   }

   // If this is a sequence model then make sure that the required
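For context, the `batch_input` entries counted above come from the model configuration. A minimal sketch of such a stanza, assuming the schema from Triton's ragged-batching documentation (the names and kind here are illustrative, not taken from this PR):

```
batch_input [
  {
    kind: BATCH_ELEMENT_COUNT
    target_name: "BATCH_ELEMS"   # illustrative target name
    data_type: TYPE_FP32
    source_input: "INPUT__0"     # illustrative source input
  }
]
```

With one entry like this, `batch_input_count_` becomes 1 and `expected_input_cnt` grows by the same amount, so validation accounts for the extra tensor the backend will synthesize.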
@@ -757,6 +770,38 @@ ModelInstanceState::ValidateTypedSequenceControl(
   return nullptr;  // success
 }

+void ModelInstanceState::AddInputToMap(NamingConvention naming_convention, const std::vector<std::string> allowed_inputs, const std::string& io_name, const uint32_t index) {
+  std::string deliminator = "__";
+
+  if (is_dict_input_) {
+    // If dictionary, index is irrelevant but we use the map to store the
+    // input names since they are the keys for the dictionary
+    input_index_map_[io_name] = index;
+  } else {
+    switch (naming_convention) {
+      case NamingConvention::FORWARD_ARGUMENT: {
+        auto itr =
+            std::find(allowed_inputs.begin(), allowed_inputs.end(), io_name);
+        if (itr != allowed_inputs.end()) {
+          input_index_map_[io_name] =
+              std::distance(allowed_inputs.begin(), itr);
+        }
+        return;
+      }
+      case NamingConvention::NAMED_INDEX: {
+        int start_pos = io_name.find(deliminator);
+        int ip_index = std::atoi(io_name.substr(start_pos + 2).c_str());
+        input_index_map_[io_name] = ip_index;
+        return;
+      }
+      case NamingConvention::STRICT_CONFIG_ORDERING: {
+        input_index_map_[io_name] = index;
+        return;
+      }
+    }
+  }
+}
+
 TRITONSERVER_Error*
 ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
 {
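The `NAMED_INDEX` branch derives the positional index from the `__<n>` suffix of the tensor name. A minimal standalone sketch of that parsing, mirroring the lines above (the `io_name` value is hypothetical):

```cpp
#include <cstdlib>
#include <iostream>
#include <string>

int main() {
  // Under NAMED_INDEX, e.g. "INPUT__2", the digits after "__" give the
  // tensor's position among the model's inputs.
  const std::string io_name = "INPUT__2";
  const size_t start_pos = io_name.find("__");
  const int ip_index = std::atoi(io_name.substr(start_pos + 2).c_str());
  std::cout << io_name << " -> index " << ip_index << std::endl;  // prints 2
  return 0;
}
```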
@@ -822,8 +867,6 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)

   triton::common::TritonJson::Value ios;
   RETURN_IF_ERROR(model_state_->ModelConfig().MemberAsArray("input", &ios));
-  std::string deliminator = "__";
-  int ip_index = 0;

   if (ios.ArraySize() == 0) {
     return TRITONSERVER_ErrorNew(
@@ -842,34 +885,7 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
     // Validate name
     std::string io_name;
     RETURN_IF_ERROR(io.MemberAsString("name", &io_name));
-    if (is_dict_input_) {
-      // If dictionary, index is irrelevant but we use the map to store the
-      // input names since they are the keys for the dictionary
-      input_index_map_[io_name] = i;
-    } else {
-      switch (naming_convention) {
-        case NamingConvention::FORWARD_ARGUMENT: {
-          auto itr =
-              std::find(allowed_inputs.begin(), allowed_inputs.end(), io_name);
-          if (itr != allowed_inputs.end()) {
-            input_index_map_[io_name] =
-                std::distance(allowed_inputs.begin(), itr);
-          }
-          break;
-        }
-        case NamingConvention::NAMED_INDEX: {
-          int start_pos = io_name.find(deliminator);
-          ip_index = std::atoi(io_name.substr(start_pos + 2).c_str());
-          input_index_map_[io_name] = ip_index;
-          break;
-        }
-        case NamingConvention::STRICT_CONFIG_ORDERING: {
-          input_index_map_[io_name] = i;
-          break;
-        }
-      }
-    }
-
+    AddInputToMap(naming_convention, allowed_inputs, io_name, i);
     // Validate data type
     std::string io_dtype;
     RETURN_IF_ERROR(io.MemberAsString("data_type", &io_dtype));
@@ -906,6 +922,16 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
     }
   }

+  triton::common::TritonJson::Value batch_inputs;
+  RETURN_IF_ERROR(model_state_->ModelConfig().MemberAsArray("batch_input", &batch_inputs));
+  size_t i = 0;
+  for (const auto& batch_input : StateForModel()->BatchInputs()) {
+    for (const auto& input_name : batch_input.TargetNames()) {
+      AddInputToMap(naming_convention, allowed_inputs, input_name, i + ios.ArraySize());
+      i++;
+    }
+  }
+
   return nullptr;  // success
 }

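To make the resulting layout concrete: regular inputs keep their indices from the `input` array, and batch-input target names are appended after them via `i + ios.ArraySize()`. A hedged sketch of the arithmetic, with invented tensor names:

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

int main() {
  std::unordered_map<std::string, int> input_index_map;
  const int config_input_count = 2;  // stands in for ios.ArraySize()

  // Two regular inputs occupy indices 0 and 1 ...
  input_index_map["INPUT__0"] = 0;
  input_index_map["INPUT__1"] = 1;

  // ... and the first batch-input target name lands right after them.
  int i = 0;
  input_index_map["BATCH_ELEMS"] = i + config_input_count;

  assert(input_index_map["BATCH_ELEMS"] == 2);
  return 0;
}
```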
@@ -1725,7 +1751,8 @@ ModelInstanceState::SetInputTensors(
   // request as the representative for the input tensors.
   uint32_t input_count;
   RETURN_IF_ERROR(TRITONBACKEND_RequestInputCount(requests[0], &input_count));
-  input_tensors->resize(input_count);
+
+  input_tensors->resize(input_count + batch_input_count_);
   for (uint32_t input_idx = 0; input_idx < input_count; input_idx++) {
     TRITONBACKEND_Input* input;
     RETURN_IF_ERROR(
@@ -1828,6 +1855,36 @@ ModelInstanceState::SetInputTensors(
     }
   }

+  for (const auto& batch_input : StateForModel()->BatchInputs()) {
+    std::vector<int64_t> shape;
+    collector->BatchInputShape(batch_input, &shape);
+
+    for (const auto& input_name : batch_input.TargetNames()) {
+      input_names->emplace_back(input_name.c_str());
+
+      const char* dst_buffer;
+      size_t dst_buffer_byte_size;
+      TRITONSERVER_MemoryType dst_memory_type;
+      int64_t dst_memory_type_id;
+
+      // Batch inputs are always created on CPU
+      RESPOND_ALL_AND_SET_NULL_IF_ERROR(
+          (*responses), responses->size(),
+          collector->ProcessBatchInput(
+              batch_input, nullptr, 0, {{TRITONSERVER_MEMORY_CPU, 0}},
+              &dst_buffer, &dst_buffer_byte_size, &dst_memory_type,
+              &dst_memory_type_id));
+
+      const auto torch_dtype = ConvertDataTypeToTorchType(batch_input.DataType());
+      torch::TensorOptions options{torch_dtype.second};
+      auto updated_options = options.device(torch::kCPU);
+
+      torch::Tensor input_tensor = torch::from_blob(
+          const_cast<char*>(dst_buffer), shape, updated_options);
+      (*input_tensors)[input_index_map_[input_name]] = input_tensor;
+    }
+  }
+
   // Finalize...
   *cuda_copy |= collector->Finalize();

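`torch::from_blob` wraps the CPU buffer produced by `ProcessBatchInput` without copying, so the buffer must outlive the tensor that views it. A minimal standalone sketch of that pattern (illustrative data, not this backend's code):

```cpp
#include <torch/torch.h>

#include <iostream>
#include <vector>

int main() {
  // from_blob() creates a tensor view over existing memory; no copy is made
  // and the tensor does not own the buffer.
  std::vector<float> buffer = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
  std::vector<int64_t> shape = {2, 3};

  torch::TensorOptions options{torch::kFloat32};
  torch::Tensor t =
      torch::from_blob(buffer.data(), shape, options.device(torch::kCPU));

  std::cout << t << std::endl;  // 2x3 view over `buffer`
  return 0;
}
```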