@@ -25,6 +25,7 @@
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include <stdint.h>
+ #include <cstdint>
#include <exception>
#include "libtorch_utils.h"
#include "triton/backend/backend_common.h"
@@ -502,6 +503,11 @@ class ModelInstanceState : public BackendModelInstance {
      triton::common::TritonJson::Value& sequence_batching,
      const std::string& control_kind, bool required, bool* have_control);
  TRITONSERVER_Error* ValidateInputs(const size_t expected_input_cnt);
+ void AddInputToMap(
+     NamingConvention naming_convention,
+     const std::vector<std::string> allowed_inputs,
+     const std::string& io_name,
+     const uint32_t index);
  TRITONSERVER_Error* ValidateOutputs();
  void Execute(
      std::vector<TRITONBACKEND_Response*>* responses,
@@ -538,6 +544,7 @@ class ModelInstanceState : public BackendModelInstance {
  // Map from configuration name for an input to the index of
  // that input in the model.
  std::unordered_map<std::string, int> input_index_map_;
+ uint32_t batch_input_count_;

  // Map from configuration name for an output to the index of
  // that output in the model.
@@ -607,6 +614,14 @@ ModelInstanceState::ModelInstanceState(
    if (model_state->ModelConfig().Find("input", &inputs)) {
      expected_input_cnt = inputs.ArraySize();
    }
+
+   triton::common::TritonJson::Value config_batch_inputs;
+   if (model_state->ModelConfig().Find("batch_input", &config_batch_inputs)) {
+     batch_input_count_ = config_batch_inputs.ArraySize();
+     expected_input_cnt += batch_input_count_;
+   } else {
+     batch_input_count_ = 0;
+   }
  }

  // If this is a sequence model then make sure that the required
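For context on the hunk above: the constructor widens `expected_input_cnt` so that `ValidateInputs` later accepts the extra tensors declared under `batch_input` in the model configuration. A minimal standalone sketch of that counting logic, assuming a hand-written JSON config (the `TritonJson` calls mirror the patch; the config contents and names are hypothetical):

```cpp
#include <cstdint>
#include <iostream>
#include <string>

#include "triton/common/triton_json.h"

int main()
{
  // Hypothetical config: two regular inputs plus one batch input.
  const std::string config_json = R"({
    "input": [{"name": "INPUT__0"}, {"name": "INPUT__1"}],
    "batch_input": [{"kind": "BATCH_ELEMENT_COUNT", "target_name": ["BATCH__0"]}]
  })";

  triton::common::TritonJson::Value config;
  config.Parse(config_json);  // error handling elided for brevity

  size_t expected_input_cnt = 0;
  uint32_t batch_input_count = 0;

  triton::common::TritonJson::Value inputs, batch_inputs;
  if (config.Find("input", &inputs)) {
    expected_input_cnt = inputs.ArraySize();  // 2
  }
  if (config.Find("batch_input", &batch_inputs)) {
    batch_input_count = batch_inputs.ArraySize();  // 1
    expected_input_cnt += batch_input_count;
  }

  std::cout << expected_input_cnt << std::endl;  // prints 3
}
```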
@@ -757,6 +772,38 @@ ModelInstanceState::ValidateTypedSequenceControl(
  return nullptr;  // success
}

+ void ModelInstanceState::AddInputToMap(NamingConvention naming_convention, const std::vector<std::string> allowed_inputs, const std::string& io_name, const uint32_t index) {
+   std::string deliminator = "__";
+
+   if (is_dict_input_) {
+     // If dictionary, index is irrelevant but we use the map to store the
+     // input names since they are the keys for the dictionary
+     input_index_map_[io_name] = index;
+   } else {
+     switch (naming_convention) {
+       case NamingConvention::FORWARD_ARGUMENT: {
+         auto itr =
+             std::find(allowed_inputs.begin(), allowed_inputs.end(), io_name);
+         if (itr != allowed_inputs.end()) {
+           input_index_map_[io_name] =
+               std::distance(allowed_inputs.begin(), itr);
+         }
+         return;
+       }
+       case NamingConvention::NAMED_INDEX: {
+         int start_pos = io_name.find(deliminator);
+         int ip_index = std::atoi(io_name.substr(start_pos + 2).c_str());
+         input_index_map_[io_name] = ip_index;
+         return;
+       }
+       case NamingConvention::STRICT_CONFIG_ORDERING: {
+         input_index_map_[io_name] = index;
+         return;
+       }
+     }
+   }
+ }
+
TRITONSERVER_Error*
ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
{
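The three `NamingConvention` branches in the new `AddInputToMap` reduce to three ways of deriving an argument position. A standalone sketch of each rule, using hypothetical names that are not part of the patch:

```cpp
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <string>
#include <vector>

int main()
{
  // FORWARD_ARGUMENT: the index is the name's position among the
  // parameters of the TorchScript model's forward() signature.
  const std::vector<std::string> forward_args = {"tokens", "mask"};
  auto itr = std::find(forward_args.begin(), forward_args.end(), "mask");
  std::cout << std::distance(forward_args.begin(), itr) << std::endl;  // 1

  // NAMED_INDEX: the index is parsed from the digits after the "__"
  // delimiter in the configured input name.
  const std::string io_name = "INPUT__7";
  const size_t start_pos = io_name.find("__");
  std::cout << std::atoi(io_name.substr(start_pos + 2).c_str()) << std::endl;  // 7

  // STRICT_CONFIG_ORDERING: the index is simply the input's position in
  // the model configuration, i.e. the `index` argument passed in.
}
```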
@@ -822,8 +869,6 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)

  triton::common::TritonJson::Value ios;
  RETURN_IF_ERROR(model_state_->ModelConfig().MemberAsArray("input", &ios));
- std::string deliminator = "__";
- int ip_index = 0;

  if (ios.ArraySize() == 0) {
    return TRITONSERVER_ErrorNew(
@@ -842,34 +887,7 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
    // Validate name
    std::string io_name;
    RETURN_IF_ERROR(io.MemberAsString("name", &io_name));
-   if (is_dict_input_) {
-     // If dictionary, index is irrelevant but we use the map to store the
-     // input names since they are the keys for the dictionary
-     input_index_map_[io_name] = i;
-   } else {
-     switch (naming_convention) {
-       case NamingConvention::FORWARD_ARGUMENT: {
-         auto itr =
-             std::find(allowed_inputs.begin(), allowed_inputs.end(), io_name);
-         if (itr != allowed_inputs.end()) {
-           input_index_map_[io_name] =
-               std::distance(allowed_inputs.begin(), itr);
-         }
-         break;
-       }
-       case NamingConvention::NAMED_INDEX: {
-         int start_pos = io_name.find(deliminator);
-         ip_index = std::atoi(io_name.substr(start_pos + 2).c_str());
-         input_index_map_[io_name] = ip_index;
-         break;
-       }
-       case NamingConvention::STRICT_CONFIG_ORDERING: {
-         input_index_map_[io_name] = i;
-         break;
-       }
-     }
-   }
-
+   AddInputToMap(naming_convention, allowed_inputs, io_name, i);

    // Validate data type
    std::string io_dtype;
    RETURN_IF_ERROR(io.MemberAsString("data_type", &io_dtype));
@@ -906,6 +924,16 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
    }
  }

+ triton::common::TritonJson::Value batch_inputs;
+ RETURN_IF_ERROR(model_state_->ModelConfig().MemberAsArray("batch_input", &batch_inputs));
+ size_t i = 0;
+ for (const auto& batch_input : StateForModel()->BatchInputs()) {
+   for (const auto& input_name : batch_input.TargetNames()) {
+     AddInputToMap(naming_convention, allowed_inputs, input_name, i + ios.ArraySize());
+     i++;
+   }
+ }
+
  return nullptr;  // success
}

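Because this loop passes `i + ios.ArraySize()` as the index, batch inputs land after the regular inputs in the positional mapping. A worked example under the `STRICT_CONFIG_ORDERING` convention, where `AddInputToMap` uses the index directly (names are hypothetical):

```cpp
// Config declares two regular inputs and one batch_input target:
//
//   input:       INPUT__0, INPUT__1   -> indices 0, 1 (i in the ios loop)
//   batch_input: BATCH__0             -> index 0 + ios.ArraySize() = 2
//
// Resulting map:
//
//   input_index_map_["INPUT__0"] == 0
//   input_index_map_["INPUT__1"] == 1
//   input_index_map_["BATCH__0"] == 2
//
// Under NAMED_INDEX or FORWARD_ARGUMENT the passed index is ignored and
// the position is derived from the name instead, as in AddInputToMap above.
```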
@@ -1725,7 +1753,8 @@ ModelInstanceState::SetInputTensors(
  // request as the representative for the input tensors.
  uint32_t input_count;
  RETURN_IF_ERROR(TRITONBACKEND_RequestInputCount(requests[0], &input_count));
- input_tensors->resize(input_count);
+
+ input_tensors->resize(input_count + batch_input_count_);
  for (uint32_t input_idx = 0; input_idx < input_count; input_idx++) {
    TRITONBACKEND_Input* input;
    RETURN_IF_ERROR(
@@ -1828,6 +1857,36 @@ ModelInstanceState::SetInputTensors(
    }
  }

+ for (const auto& batch_input : StateForModel()->BatchInputs()) {
+   std::vector<int64_t> shape;
+   collector->BatchInputShape(batch_input, &shape);
+
+   for (const auto& input_name : batch_input.TargetNames()) {
+     input_names->emplace_back(input_name.c_str());
+
+     const char* dst_buffer;
+     size_t dst_buffer_byte_size;
+     TRITONSERVER_MemoryType dst_memory_type;
+     int64_t dst_memory_type_id;
+
+     // Batch inputs are always created on CPU
+     RESPOND_ALL_AND_SET_NULL_IF_ERROR(
+         (*responses), responses->size(),
+         collector->ProcessBatchInput(
+             batch_input, nullptr, 0, {{TRITONSERVER_MEMORY_CPU, 0}},
+             &dst_buffer, &dst_buffer_byte_size, &dst_memory_type,
+             &dst_memory_type_id));
+
+     const auto torch_dtype = ConvertDataTypeToTorchType(batch_input.DataType());
+     torch::TensorOptions options{torch_dtype.second};
+     auto updated_options = options.device(torch::kCPU);
+
+     torch::Tensor input_tensor = torch::from_blob(
+         const_cast<char*>(dst_buffer), shape, updated_options);
+     (*input_tensors)[input_index_map_[input_name]] = input_tensor;
+   }
+ }
+
  // Finalize...
  *cuda_copy |= collector->Finalize();
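The `torch::from_blob` call above wraps the collector-owned buffer without copying. A minimal self-contained sketch of that zero-copy pattern (synthetic data; in the patch the buffer comes from `ProcessBatchInput` and must outlive the tensor):

```cpp
#include <torch/torch.h>

#include <iostream>
#include <vector>

int main()
{
  // Stand-in for the CPU buffer handed back by the backend collector.
  std::vector<float> buffer = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
  const std::vector<int64_t> shape = {2, 3};

  // from_blob creates a tensor view over existing memory: no copy is made,
  // so `buffer` must stay alive for as long as `tensor` is used.
  torch::TensorOptions options{torch::kFloat32};
  torch::Tensor tensor =
      torch::from_blob(buffer.data(), shape, options.device(torch::kCPU));

  std::cout << tensor.sizes() << std::endl;             // [2, 3]
  buffer[0] = 42.f;
  std::cout << tensor[0][0].item<float>() << std::endl;  // 42 (shared memory)
}
```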