 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

 #include <stdint.h>
+
+#include <cstdint>
 #include <exception>
+
 #include "libtorch_utils.h"
 #include "triton/backend/backend_common.h"
 #include "triton/backend/backend_input_collector.h"
@@ -502,6 +505,10 @@ class ModelInstanceState : public BackendModelInstance {
       triton::common::TritonJson::Value& sequence_batching,
       const std::string& control_kind, bool required, bool* have_control);
   TRITONSERVER_Error* ValidateInputs(const size_t expected_input_cnt);
+  void AddInputToMap(
+      NamingConvention naming_convention,
+      const std::vector<std::string> allowed_inputs, const std::string& io_name,
+      const uint32_t index);
   TRITONSERVER_Error* ValidateOutputs();
   void Execute(
       std::vector<TRITONBACKEND_Response*>* responses,
@@ -538,6 +545,7 @@ class ModelInstanceState : public BackendModelInstance {
   // Map from configuration name for an input to the index of
   // that input in the model.
   std::unordered_map<std::string, int> input_index_map_;
+  uint32_t batch_input_count_ = 0;

   // Map from configuration name for an output to the index of
   // that output in the model.
@@ -607,6 +615,12 @@ ModelInstanceState::ModelInstanceState(
     if (model_state->ModelConfig().Find("input", &inputs)) {
       expected_input_cnt = inputs.ArraySize();
     }
+
+    triton::common::TritonJson::Value config_batch_inputs;
+    if (model_state->ModelConfig().Find("batch_input", &config_batch_inputs)) {
+      batch_input_count_ = config_batch_inputs.ArraySize();
+      expected_input_cnt += batch_input_count_;
+    }
   }

   // If this is a sequence model then make sure that the required
@@ -757,6 +771,43 @@ ModelInstanceState::ValidateTypedSequenceControl(
   return nullptr;  // success
 }

+void
+ModelInstanceState::AddInputToMap(
+    NamingConvention naming_convention,
+    const std::vector<std::string> allowed_inputs, const std::string& io_name,
+    const uint32_t index)
+{
+  std::string deliminator = "__";
+
+  if (is_dict_input_) {
+    // If dictionary, index is irrelevant but we use the map to store the
+    // input names since they are the keys for the dictionary
+    input_index_map_[io_name] = index;
+  } else {
+    switch (naming_convention) {
+      case NamingConvention::FORWARD_ARGUMENT: {
+        auto itr =
+            std::find(allowed_inputs.begin(), allowed_inputs.end(), io_name);
+        if (itr != allowed_inputs.end()) {
+          input_index_map_[io_name] =
+              std::distance(allowed_inputs.begin(), itr);
+        }
+        return;
+      }
+      case NamingConvention::NAMED_INDEX: {
+        int start_pos = io_name.find(deliminator);
+        int ip_index = std::atoi(io_name.substr(start_pos + 2).c_str());
+        input_index_map_[io_name] = ip_index;
+        return;
+      }
+      case NamingConvention::STRICT_CONFIG_ORDERING: {
+        input_index_map_[io_name] = index;
+        return;
+      }
+    }
+  }
+}
+
 TRITONSERVER_Error*
 ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
 {
@@ -822,8 +873,6 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)

   triton::common::TritonJson::Value ios;
   RETURN_IF_ERROR(model_state_->ModelConfig().MemberAsArray("input", &ios));
-  std::string deliminator = "__";
-  int ip_index = 0;

   if (ios.ArraySize() == 0) {
     return TRITONSERVER_ErrorNew(
@@ -842,34 +891,7 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
     // Validate name
     std::string io_name;
     RETURN_IF_ERROR(io.MemberAsString("name", &io_name));
-    if (is_dict_input_) {
-      // If dictionary, index is irrelevant but we use the map to store the
-      // input names since they are the keys for the dictionary
-      input_index_map_[io_name] = i;
-    } else {
-      switch (naming_convention) {
-        case NamingConvention::FORWARD_ARGUMENT: {
-          auto itr =
-              std::find(allowed_inputs.begin(), allowed_inputs.end(), io_name);
-          if (itr != allowed_inputs.end()) {
-            input_index_map_[io_name] =
-                std::distance(allowed_inputs.begin(), itr);
-          }
-          break;
-        }
-        case NamingConvention::NAMED_INDEX: {
-          int start_pos = io_name.find(deliminator);
-          ip_index = std::atoi(io_name.substr(start_pos + 2).c_str());
-          input_index_map_[io_name] = ip_index;
-          break;
-        }
-        case NamingConvention::STRICT_CONFIG_ORDERING: {
-          input_index_map_[io_name] = i;
-          break;
-        }
-      }
-    }
-
+    AddInputToMap(naming_convention, allowed_inputs, io_name, i);
     // Validate data type
     std::string io_dtype;
     RETURN_IF_ERROR(io.MemberAsString("data_type", &io_dtype));
@@ -906,6 +928,18 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
     }
   }

+  triton::common::TritonJson::Value batch_inputs;
+  RETURN_IF_ERROR(
+      model_state_->ModelConfig().MemberAsArray("batch_input", &batch_inputs));
+  size_t i = 0;
+  for (const auto& batch_input : StateForModel()->BatchInputs()) {
+    for (const auto& input_name : batch_input.TargetNames()) {
+      AddInputToMap(
+          naming_convention, allowed_inputs, input_name, i + ios.ArraySize());
+      i++;
+    }
+  }
+
   return nullptr;  // success
 }

@@ -1312,12 +1346,12 @@ ModelInstanceState::Execute(
       torch::jit::overrideCanFuseOnCPU(false);
       torch::jit::overrideCanFuseOnGPU(false);
       torch::jit::setTensorExprFuserEnabled(false);
-      torch::jit::fuser::cuda::setEnabled(true);
+      torch::jit::fuser::cuda::setEnabled(true);
     } else {
       torch::jit::overrideCanFuseOnCPU(true);
       torch::jit::overrideCanFuseOnGPU(true);
       torch::jit::setTensorExprFuserEnabled(true);
-      torch::jit::fuser::cuda::setEnabled(false);
+      torch::jit::fuser::cuda::setEnabled(false);
     }
   }

@@ -1725,7 +1759,8 @@ ModelInstanceState::SetInputTensors(
   // request as the representative for the input tensors.
   uint32_t input_count;
   RETURN_IF_ERROR(TRITONBACKEND_RequestInputCount(requests[0], &input_count));
-  input_tensors->resize(input_count);
+
+  input_tensors->resize(input_count + batch_input_count_);
   for (uint32_t input_idx = 0; input_idx < input_count; input_idx++) {
     TRITONBACKEND_Input* input;
     RETURN_IF_ERROR(
@@ -1761,9 +1796,9 @@ ModelInstanceState::SetInputTensors(

         batchn_shape[0] += GetElementCount(input_shape, input_dims_count);
       }
-    }
-    else {
-      batchn_shape = std::vector<int64_t>(input_shape, input_shape + input_dims_count);
+    } else {
+      batchn_shape =
+          std::vector<int64_t>(input_shape, input_shape + input_dims_count);
       if (supports_batching_) {
         batchn_shape[0] = total_batch_size;
       }
@@ -1772,8 +1807,8 @@ ModelInstanceState::SetInputTensors(
     // The input must be in contiguous CPU/GPU memory.
     std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>> alloc_perference;
     if (device_.is_cpu()) {
-      alloc_perference = {{TRITONSERVER_MEMORY_CPU_PINNED, 0},
-                          {TRITONSERVER_MEMORY_CPU, 0}};
+      alloc_perference = {
+          {TRITONSERVER_MEMORY_CPU_PINNED, 0}, {TRITONSERVER_MEMORY_CPU, 0}};
     } else {
       alloc_perference = {{TRITONSERVER_MEMORY_GPU, device_.index()}};
     }
@@ -1828,6 +1863,36 @@ ModelInstanceState::SetInputTensors(
     }
   }

+  for (const auto& batch_input : StateForModel()->BatchInputs()) {
+    std::vector<int64_t> shape;
+    collector->BatchInputShape(batch_input, &shape);
+
+    for (const auto& input_name : batch_input.TargetNames()) {
+      input_names->emplace_back(input_name.c_str());
+
+      const char* dst_buffer;
+      size_t dst_buffer_byte_size;
+      TRITONSERVER_MemoryType dst_memory_type;
+      int64_t dst_memory_type_id;
+
+      // Batch inputs are always created on CPU
+      RESPOND_ALL_AND_SET_NULL_IF_ERROR(
+          (*responses), responses->size(),
+          collector->ProcessBatchInput(
+              batch_input, nullptr, 0, {{TRITONSERVER_MEMORY_CPU, 0}},
+              &dst_buffer, &dst_buffer_byte_size, &dst_memory_type,
+              &dst_memory_type_id));
+
+      const auto torch_dtype =
+          ConvertDataTypeToTorchType(batch_input.DataType());
+
+      torch::Tensor input_tensor = torch::from_blob(
+          const_cast<char*>(dst_buffer), shape,
+          updated_options.dtype(torch_dtype.second));
+      (*input_tensors)[input_index_map_[input_name]] = input_tensor;
+    }
+  }
+
   // Finalize...
   *cuda_copy |= collector->Finalize();

@@ -1887,9 +1952,11 @@ ModelInstanceState::ReadOutputTensors(

       // Output tensors may not reside on the same device as model
       torch::Device tensor_device = output_flat.device();
-      const auto memory_type = (tensor_device.type() == torch::kCPU) ? TRITONSERVER_MEMORY_CPU
-                                                                     : TRITONSERVER_MEMORY_GPU;
-      const auto memory_id = (tensor_device.type() == torch::kCPU) ? 0 : tensor_device.index();
+      const auto memory_type = (tensor_device.type() == torch::kCPU)
+                                   ? TRITONSERVER_MEMORY_CPU
+                                   : TRITONSERVER_MEMORY_GPU;
+      const auto memory_id =
+          (tensor_device.type() == torch::kCPU) ? 0 : tensor_device.index();

       // Batch output doesn't support string data type yet, as it is not trivial
       // to parse string output
@@ -1906,16 +1973,16 @@ ModelInstanceState::ReadOutputTensors(
          return TRITONSERVER_ErrorNew(
              TRITONSERVER_ERROR_INVALID_ARG,
              (std::string("output '") + name +
-                 "' is a scalar which is not supported.")
+              "' is a scalar which is not supported.")
                  .c_str());
        }

        responder.ProcessTensor(
-           name, output_dtype, batchn_shape, output_buffer,
-           memory_type, memory_id);
+           name, output_dtype, batchn_shape, output_buffer, memory_type,
+           memory_id);
      } else {
        responder.ProcessBatchOutput(
-           name, *batch_output, output_buffer, memory_type, memory_id);
+           name, *batch_output, output_buffer, memory_type, memory_id);
      }
    } else if (output_tensors[op_index].isList()) {
      // Custom handling for string/bytes tensor...
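For reference, a minimal standalone sketch (not part of the patch; the input names are hypothetical) of the index layout this change produces: inputs declared in the config keep their position in the "input" array, and batch-input target names are appended after them, which is why expected_input_cnt grows by batch_input_count_ and input_tensors is resized to input_count + batch_input_count_.

// index_layout_sketch.cc -- illustrative only, hypothetical input names.
#include <cstddef>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main()
{
  // Two regular config inputs and one batch input with a single target name.
  std::vector<std::string> config_inputs = {"INPUT0", "INPUT1"};
  std::vector<std::string> batch_input_targets = {"BATCH_INPUT0"};

  std::unordered_map<std::string, int> input_index_map;

  // Regular inputs keep their position in the "input" config array
  // (the STRICT_CONFIG_ORDERING case of AddInputToMap).
  for (size_t i = 0; i < config_inputs.size(); i++) {
    input_index_map[config_inputs[i]] = static_cast<int>(i);
  }

  // Batch-input target names are appended after the regular inputs,
  // mirroring AddInputToMap(..., input_name, i + ios.ArraySize()).
  for (size_t i = 0; i < batch_input_targets.size(); i++) {
    input_index_map[batch_input_targets[i]] =
        static_cast<int>(i + config_inputs.size());
  }

  size_t expected_input_cnt =
      config_inputs.size() + batch_input_targets.size();
  std::cout << "expected_input_cnt = " << expected_input_cnt << "\n";
  for (const auto& kv : input_index_map) {
    std::cout << kv.first << " -> " << kv.second << "\n";
  }
  return 0;
}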