@@ -505,9 +505,8 @@ class ModelInstanceState : public BackendModelInstance {
       const std::string& control_kind, bool required, bool* have_control);
   TRITONSERVER_Error* ValidateInputs(const size_t expected_input_cnt);
   void AddInputToMap(
-      NamingConvention naming_convention,
-      const std::vector<std::string> allowed_inputs,
-      const std::string &io_name,
+      NamingConvention naming_convention,
+      const std::vector<std::string> allowed_inputs, const std::string& io_name,
       const uint32_t index);
   TRITONSERVER_Error* ValidateOutputs();
   void Execute(
@@ -771,7 +770,12 @@ ModelInstanceState::ValidateTypedSequenceControl(
   return nullptr;  // success
 }
 
-void ModelInstanceState::AddInputToMap(NamingConvention naming_convention, const std::vector<std::string> allowed_inputs, const std::string &io_name, const uint32_t index) {
+void
+ModelInstanceState::AddInputToMap(
+    NamingConvention naming_convention,
+    const std::vector<std::string> allowed_inputs, const std::string& io_name,
+    const uint32_t index)
+{
   std::string deliminator = "__";
 
   if (is_dict_input_) {
@@ -924,11 +928,13 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
   }
 
   triton::common::TritonJson::Value batch_inputs;
-  RETURN_IF_ERROR(model_state_->ModelConfig().MemberAsArray("batch_input", &batch_inputs));
+  RETURN_IF_ERROR(
+      model_state_->ModelConfig().MemberAsArray("batch_input", &batch_inputs));
   size_t i = 0;
   for (const auto& batch_input : StateForModel()->BatchInputs()) {
     for (const auto& input_name : batch_input.TargetNames()) {
-      AddInputToMap(naming_convention, allowed_inputs, input_name, i + ios.ArraySize());
+      AddInputToMap(
+          naming_convention, allowed_inputs, input_name, i + ios.ArraySize());
       i++;
     }
   }
@@ -1754,6 +1760,16 @@ ModelInstanceState::SetInputTensors(
   RETURN_IF_ERROR(TRITONBACKEND_RequestInputCount(requests[0], &input_count));
 
   input_tensors->resize(input_count + batch_input_count_);
+
+  // The inputs must be in contiguous CPU/GPU memory.
+  std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>> alloc_perference;
+  if (device_.is_cpu()) {
+    alloc_perference = {{TRITONSERVER_MEMORY_CPU_PINNED, 0},
+                        {TRITONSERVER_MEMORY_CPU, 0}};
+  } else {
+    alloc_perference = {{TRITONSERVER_MEMORY_GPU, device_.index()}};
+  }
+
   for (uint32_t input_idx = 0; input_idx < input_count; input_idx++) {
     TRITONBACKEND_Input* input;
     RETURN_IF_ERROR(
@@ -1797,15 +1813,6 @@ ModelInstanceState::SetInputTensors(
       }
     }
 
-    // The input must be in contiguous CPU/GPU memory.
-    std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>> alloc_perference;
-    if (device_.is_cpu()) {
-      alloc_perference = {{TRITONSERVER_MEMORY_CPU_PINNED, 0},
-                          {TRITONSERVER_MEMORY_CPU, 0}};
-    } else {
-      alloc_perference = {{TRITONSERVER_MEMORY_GPU, device_.index()}};
-    }
-
     const char* input_buffer;
     size_t batchn_byte_size;
     TRITONSERVER_MemoryType memory_type;
@@ -1868,15 +1875,14 @@ ModelInstanceState::SetInputTensors(
       TRITONSERVER_MemoryType dst_memory_type;
       int64_t dst_memory_type_id;
 
-      // Batch inputs are always created on CPU
       RESPOND_ALL_AND_SET_NULL_IF_ERROR(
           (*responses), responses->size(),
           collector->ProcessBatchInput(
-              batch_input, nullptr, 0, {{TRITONSERVER_MEMORY_CPU, 0}},
-              &dst_buffer, &dst_buffer_byte_size, &dst_memory_type,
-              &dst_memory_type_id));
+              batch_input, nullptr, 0, alloc_perference, &dst_buffer,
+              &dst_buffer_byte_size, &dst_memory_type, &dst_memory_type_id));
 
-      const auto torch_dtype = ConvertDataTypeToTorchType(batch_input.DataType());
+      const auto torch_dtype =
+          ConvertDataTypeToTorchType(batch_input.DataType());
       torch::TensorOptions options{torch_dtype.second};
       auto updated_options = options.device(torch::kCPU);
 
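Taken together, the SetInputTensors hunks hoist the allocation-preference computation out of the per-input loop and reuse it for batch inputs: ProcessBatchInput now receives alloc_perference instead of a hard-coded {{TRITONSERVER_MEMORY_CPU, 0}}, so batch inputs are gathered into pinned CPU memory on CPU instances or device-local memory on GPU instances. Below is a minimal sketch of that selection logic, factored into a hypothetical free function (the commit itself inlines it in SetInputTensors); it assumes the libtorch and Triton core headers shown.

// Sketch only: AllocPreferenceForDevice() is a hypothetical helper, but the
// preference values mirror the diff above.
#include <cstdint>
#include <utility>
#include <vector>

#include <torch/torch.h>

#include "triton/core/tritonserver.h"

std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>>
AllocPreferenceForDevice(const torch::Device& device)
{
  if (device.is_cpu()) {
    // Prefer pinned CPU memory (cheaper later host/device copies), then
    // fall back to pageable CPU memory.
    return {{TRITONSERVER_MEMORY_CPU_PINNED, 0}, {TRITONSERVER_MEMORY_CPU, 0}};
  }
  // On a GPU instance, gather inputs directly into device-local memory.
  return {{TRITONSERVER_MEMORY_GPU, device.index()}};
}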