@@ -504,9 +504,8 @@ class ModelInstanceState : public BackendModelInstance {
       const std::string& control_kind, bool required, bool* have_control);
   TRITONSERVER_Error* ValidateInputs(const size_t expected_input_cnt);
   void AddInputToMap(
-      NamingConvention naming_convention,
-      const std::vector<std::string> allowed_inputs,
-      const std::string &io_name,
+      NamingConvention naming_convention,
+      const std::vector<std::string> allowed_inputs, const std::string& io_name,
       const uint32_t index);
   TRITONSERVER_Error* ValidateOutputs();
   void Execute(
@@ -770,7 +769,12 @@ ModelInstanceState::ValidateTypedSequenceControl(
   return nullptr;  // success
 }
 
-void ModelInstanceState::AddInputToMap(NamingConvention naming_convention, const std::vector<std::string> allowed_inputs, const std::string &io_name, const uint32_t index) {
+void
+ModelInstanceState::AddInputToMap(
+    NamingConvention naming_convention,
+    const std::vector<std::string> allowed_inputs, const std::string& io_name,
+    const uint32_t index)
+{
   std::string deliminator = "__";
 
   if (is_dict_input_) {
@@ -923,11 +927,13 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
   }
 
   triton::common::TritonJson::Value batch_inputs;
-  RETURN_IF_ERROR(model_state_->ModelConfig().MemberAsArray("batch_input", &batch_inputs));
+  RETURN_IF_ERROR(
+      model_state_->ModelConfig().MemberAsArray("batch_input", &batch_inputs));
   size_t i = 0;
   for (const auto& batch_input : StateForModel()->BatchInputs()) {
     for (const auto& input_name : batch_input.TargetNames()) {
-      AddInputToMap(naming_convention, allowed_inputs, input_name, i + ios.ArraySize());
+      AddInputToMap(
+          naming_convention, allowed_inputs, input_name, i + ios.ArraySize());
       i++;
     }
   }
@@ -1338,12 +1344,12 @@ ModelInstanceState::Execute(
       torch::jit::overrideCanFuseOnCPU(false);
       torch::jit::overrideCanFuseOnGPU(false);
       torch::jit::setTensorExprFuserEnabled(false);
-      torch::jit::fuser::cuda::setEnabled(true);
+      torch::jit::fuser::cuda::setEnabled(true);
     } else {
       torch::jit::overrideCanFuseOnCPU(true);
       torch::jit::overrideCanFuseOnGPU(true);
       torch::jit::setTensorExprFuserEnabled(true);
-      torch::jit::fuser::cuda::setEnabled(false);
+      torch::jit::fuser::cuda::setEnabled(false);
     }
   }
 
@@ -1753,6 +1759,16 @@ ModelInstanceState::SetInputTensors(
   RETURN_IF_ERROR(TRITONBACKEND_RequestInputCount(requests[0], &input_count));
 
   input_tensors->resize(input_count + batch_input_count_);
+
+  // The inputs must be in contiguous CPU/GPU memory.
+  std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>> alloc_perference;
+  if (device_.is_cpu()) {
+    alloc_perference = {{TRITONSERVER_MEMORY_CPU_PINNED, 0},
+                        {TRITONSERVER_MEMORY_CPU, 0}};
+  } else {
+    alloc_perference = {{TRITONSERVER_MEMORY_GPU, device_.index()}};
+  }
+
   for (uint32_t input_idx = 0; input_idx < input_count; input_idx++) {
     TRITONBACKEND_Input* input;
     RETURN_IF_ERROR(
@@ -1788,23 +1804,14 @@ ModelInstanceState::SetInputTensors(
 
         batchn_shape[0] += GetElementCount(input_shape, input_dims_count);
       }
-    }
-    else {
-      batchn_shape = std::vector<int64_t>(input_shape, input_shape + input_dims_count);
+    } else {
+      batchn_shape =
+          std::vector<int64_t>(input_shape, input_shape + input_dims_count);
       if (supports_batching_) {
         batchn_shape[0] = total_batch_size;
       }
     }
 
-    // The input must be in contiguous CPU/GPU memory.
-    std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>> alloc_perference;
-    if (device_.is_cpu()) {
-      alloc_perference = {{TRITONSERVER_MEMORY_CPU_PINNED, 0},
-                          {TRITONSERVER_MEMORY_CPU, 0}};
-    } else {
-      alloc_perference = {{TRITONSERVER_MEMORY_GPU, device_.index()}};
-    }
-
     const char* input_buffer;
     size_t batchn_byte_size;
     TRITONSERVER_MemoryType memory_type;
@@ -1867,15 +1874,14 @@ ModelInstanceState::SetInputTensors(
     TRITONSERVER_MemoryType dst_memory_type;
     int64_t dst_memory_type_id;
 
-    // Batch inputs are always created on CPU
     RESPOND_ALL_AND_SET_NULL_IF_ERROR(
         (*responses), responses->size(),
         collector->ProcessBatchInput(
-            batch_input, nullptr, 0, {{TRITONSERVER_MEMORY_CPU, 0}},
-            &dst_buffer, &dst_buffer_byte_size, &dst_memory_type,
-            &dst_memory_type_id));
+            batch_input, nullptr, 0, alloc_perference, &dst_buffer,
+            &dst_buffer_byte_size, &dst_memory_type, &dst_memory_type_id));
 
-    const auto torch_dtype = ConvertDataTypeToTorchType(batch_input.DataType());
+    const auto torch_dtype =
+        ConvertDataTypeToTorchType(batch_input.DataType());
     torch::TensorOptions options{torch_dtype.second};
     auto updated_options = options.device(torch::kCPU);
 
@@ -1944,9 +1950,11 @@ ModelInstanceState::ReadOutputTensors(
 
       // Output tensors may not reside on the same device as model
       torch::Device tensor_device = output_flat.device();
-      const auto memory_type = (tensor_device.type() == torch::kCPU) ? TRITONSERVER_MEMORY_CPU
-                                                                     : TRITONSERVER_MEMORY_GPU;
-      const auto memory_id = (tensor_device.type() == torch::kCPU) ? 0 : tensor_device.index();
+      const auto memory_type = (tensor_device.type() == torch::kCPU)
+                                   ? TRITONSERVER_MEMORY_CPU
+                                   : TRITONSERVER_MEMORY_GPU;
+      const auto memory_id =
+          (tensor_device.type() == torch::kCPU) ? 0 : tensor_device.index();
 
       // Batch output doesn't support string data type yet, as it is not trivial
       // to parse string output
@@ -1963,16 +1971,16 @@ ModelInstanceState::ReadOutputTensors(
           return TRITONSERVER_ErrorNew(
               TRITONSERVER_ERROR_INVALID_ARG,
               (std::string("output '") + name +
-               "' is a scalar which is not supported.")
+               "' is a scalar which is not supported.")
                   .c_str());
         }
 
         responder.ProcessTensor(
-            name, output_dtype, batchn_shape, output_buffer,
-            memory_type, memory_id);
+            name, output_dtype, batchn_shape, output_buffer, memory_type,
+            memory_id);
       } else {
         responder.ProcessBatchOutput(
-            name, *batch_output, output_buffer, memory_type, memory_id);
+            name, *batch_output, output_buffer, memory_type, memory_id);
       }
     } else if (output_tensors[op_index].isList()) {
       // Custom handling for string/bytes tensor...
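
Aside on the change above (not part of the patch): the allocation preference that this diff hoists out of the per-input loop tells the backend's input collector which kind of contiguous memory to hand back, preferring pinned and then pageable CPU memory for CPU models and device-local memory for GPU models, and the same preference is now passed to ProcessBatchInput so batch inputs are no longer forced onto CPU. Below is a minimal sketch of that selection; the helper name BuildAllocPreference is hypothetical, while TRITONSERVER_MemoryType, its enum values, and torch::Device are the real Triton and LibTorch types used in the diff.

// Illustrative sketch only (assumes the Triton core and LibTorch headers are
// available); BuildAllocPreference is a hypothetical helper, not backend code.
#include <cstdint>
#include <utility>
#include <vector>

#include <torch/script.h>              // torch::Device
#include "triton/core/tritonserver.h"  // TRITONSERVER_MemoryType

// Build the (memory type, device id) preference list that collector calls
// such as ProcessTensor()/ProcessBatchInput() accept.
std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>>
BuildAllocPreference(const torch::Device& device)
{
  if (device.is_cpu()) {
    // CPU model: prefer pinned host memory, fall back to pageable memory.
    return {{TRITONSERVER_MEMORY_CPU_PINNED, 0}, {TRITONSERVER_MEMORY_CPU, 0}};
  }
  // GPU model: request buffers on the model's own device.
  return {{TRITONSERVER_MEMORY_GPU, device.index()}};
}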