@@ -104,7 +104,10 @@ class ModelState : public BackendModel {
   bool EnabledCacheCleaning() { return enable_cache_cleaning_; }

   bool EnabledWeightSharing() { return enable_weight_sharing_; }
-  const std::map<std::string, std::pair<int64_t, int64_t>>& ModelOutputs() { return model_outputs_; }
+  const std::map<std::string, std::pair<int64_t, int64_t>>& ModelOutputs()
+  {
+    return model_outputs_;
+  }

  private:
   ModelState(TRITONBACKEND_Model* triton_model);
@@ -533,8 +536,7 @@ class ModelInstanceState : public BackendModelInstance {
       const std::string& control_kind, bool required, bool* have_control);
   void AddInputToMap(
       NamingConvention naming_convention,
-      const std::vector<std::string> allowed_inputs,
-      const std::string &io_name,
+      const std::vector<std::string> allowed_inputs, const std::string& io_name,
       const uint32_t index);
   TRITONSERVER_Error* ValidateInputs(const size_t expected_input_cnt);
   TRITONSERVER_Error* ValidateOutputs();
@@ -796,7 +798,12 @@ ModelInstanceState::ValidateTypedSequenceControl(

   return nullptr;  // success
 }
-void ModelInstanceState::AddInputToMap(NamingConvention naming_convention, const std::vector<std::string> allowed_inputs, const std::string &io_name, const uint32_t index) {
+void
+ModelInstanceState::AddInputToMap(
+    NamingConvention naming_convention,
+    const std::vector<std::string> allowed_inputs, const std::string& io_name,
+    const uint32_t index)
+{
   std::string deliminator = "__";

   if (is_dict_input_) {
@@ -950,7 +957,7 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
   }
   triton::common::TritonJson::Value sequence_batching;
   if (model_state_->ModelConfig().Find(
-          "sequence_batching", &sequence_batching)){
+          "sequence_batching", &sequence_batching)) {
     triton::common::TritonJson::Value states;
     if (sequence_batching.Find("state", &states)) {
       for (size_t i = 0; i < states.ArraySize(); i++) {
@@ -967,8 +974,8 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
         if (!pr.first && (state_dtype != "TYPE_STRING")) {
           return TRITONSERVER_ErrorNew(
               TRITONSERVER_ERROR_INTERNAL,
-              ("unsupported datatype " + state_dtype + " for input state '" + state_name +
-               "' for model '" + model_state_->Name() + "'")
+              ("unsupported datatype " + state_dtype + " for input state '" +
+               state_name + "' for model '" + model_state_->Name() + "'")
                   .c_str());
         }

@@ -979,10 +986,11 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
         if ((dims.size() + (supports_batching_ ? 1 : 0)) > 1) {
           return TRITONSERVER_ErrorNew(
               TRITONSERVER_ERROR_INTERNAL,
-              ("Triton only supports 1 dimensional List of String as input for "
+              ("Triton only supports 1 dimensional List of String as input "
+               "for "
                "'" +
-               std::string(state_name) + "' for model '" + model_state_->Name() +
-               "'")
+               std::string(state_name) + "' for model '" +
+               model_state_->Name() + "'")
                   .c_str());
         }
       }
@@ -1094,8 +1102,8 @@ ModelInstanceState::ValidateOutputs()
         if (!pr.first && (state_dtype != "TYPE_STRING")) {
           return TRITONSERVER_ErrorNew(
               TRITONSERVER_ERROR_INTERNAL,
-              ("unsupported datatype " + state_dtype + " for state '" + state_name +
-               "' for model '" + model_state_->Name() + "'")
+              ("unsupported datatype " + state_dtype + " for state '" +
+               state_name + "' for model '" + model_state_->Name() + "'")
                   .c_str());
         }

@@ -1104,10 +1112,11 @@ ModelInstanceState::ValidateOutputs()
         if ((dims.size() + (supports_batching_ ? 1 : 0)) > 1) {
           return TRITONSERVER_ErrorNew(
               TRITONSERVER_ERROR_INTERNAL,
-              ("Triton only supports 1 dimensional List of String as output for "
+              ("Triton only supports 1 dimensional List of String as output "
+               "for "
                "'" +
-               std::string(state_name) + "' for model '" + model_state_->Name() +
-               "'")
+               std::string(state_name) + "' for model '" +
+               model_state_->Name() + "'")
                   .c_str());
         }
       }
@@ -1610,7 +1619,8 @@ ModelInstanceState::GetNamingConvention(
   }

   triton::common::TritonJson::Value sequence_batching;
-  if (model_state_->ModelConfig().Find("sequence_batching", &sequence_batching)) {
+  if (model_state_->ModelConfig().Find(
+          "sequence_batching", &sequence_batching)) {
     // If we need to manage state for the model, then we need to check
     // the naming of the state adheres to both the input and output conventions
     triton::common::TritonJson::Value states;
@@ -1628,16 +1638,17 @@ ModelInstanceState::GetNamingConvention(
     for (size_t i = 0; i < states.ArraySize(); i++) {
       triton::common::TritonJson::Value state;
       RETURN_IF_ERROR(states.IndexAsObject(i, &state));
-      std::string name_entry = io_kind == "input" ? "input_name" : "output_name";
+      std::string name_entry =
+          io_kind == "input" ? "input_name" : "output_name";
       std::string state_name;
-      RETURN_IF_ERROR(
-          state.MemberAsString(name_entry.c_str(), &state_name));
+      RETURN_IF_ERROR(state.MemberAsString(name_entry.c_str(), &state_name));
       int start_pos = state_name.find(deliminator);
       if (start_pos == -1) {
         return TRITONSERVER_ErrorNew(
             TRITONSERVER_ERROR_INVALID_ARG,
             ("PyTorch model '" + model_state_->Name() +
-             "' is using sequence batching with state but state '" + state_name +
+             "' is using sequence batching with state but state '" +
+             state_name +
              "' does not follow the <name>__<index> naming convention. ")
                 .c_str());
       } else {
@@ -1653,7 +1664,8 @@ ModelInstanceState::GetNamingConvention(
           return TRITONSERVER_ErrorNew(
               TRITONSERVER_ERROR_INVALID_ARG,
               ("PyTorch model '" + model_state_->Name() +
-               "' is using sequence batching with state but state '" + state_name +
+               "' is using sequence batching with state but state '" +
+               state_name +
                "' does not follow the <name>__<index> naming convention. ")
                   .c_str());
         }
@@ -1844,8 +1856,9 @@ SetStringInputTensor(
 bool
 SetStringBuffer(
     torch::List<torch::jit::IValue>* tensor, TRITONBACKEND_Response** response,
-    TRITONBACKEND_Output* response_output, TRITONBACKEND_State* response_state, const size_t tensor_element_count,
-    cudaStream_t stream, std::string* serialized, bool state)
+    TRITONBACKEND_Output* response_output, TRITONBACKEND_State* response_state,
+    const size_t tensor_element_count, cudaStream_t stream,
+    std::string* serialized, bool state)
 {
   bool cuda_copy = false;

@@ -1870,7 +1883,7 @@ SetStringBuffer(
   TRITONSERVER_Error* err;
   void* buffer;

-  if (!state){
+  if (!state) {
     auto err = TRITONBACKEND_OutputBuffer(
         response_output, &buffer, serialized->size(), &actual_memory_type,
         &actual_memory_type_id);
@@ -1916,19 +1929,20 @@ SetStringOutputBuffer(
     TRITONBACKEND_Output* response_output, const size_t tensor_element_count,
     cudaStream_t stream, std::string* serialized)
 {
-  return SetStringBuffer(tensor, response, response_output, nullptr /* response_state */, tensor_element_count,
-      stream, serialized, false /* state */);
-
+  return SetStringBuffer(
+      tensor, response, response_output, nullptr /* response_state */,
+      tensor_element_count, stream, serialized, false /* state */);
 }

 bool
 SetStringStateBuffer(
-    torch::List<torch::jit::IValue>* tensor, TRITONBACKEND_Response** response,
+    torch::List<torch::jit::IValue>* tensor, TRITONBACKEND_Response** response,
     TRITONBACKEND_State* response_state, const size_t tensor_element_count,
     cudaStream_t stream, std::string* serialized)
 {
-  return SetStringBuffer(tensor, response, nullptr /* response_output */, response_state, tensor_element_count,
-      stream, serialized, true /* state */);
+  return SetStringBuffer(
+      tensor, response, nullptr /* response_output */, response_state,
+      tensor_element_count, stream, serialized, true /* state */);
 }


@@ -1983,9 +1997,9 @@ ModelInstanceState::SetInputTensors(

         batchn_shape[0] += GetElementCount(input_shape, input_dims_count);
       }
-    }
-    else {
-      batchn_shape = std::vector<int64_t>(input_shape, input_shape + input_dims_count);
+    } else {
+      batchn_shape =
+          std::vector<int64_t>(input_shape, input_shape + input_dims_count);
       if (supports_batching_) {
         batchn_shape[0] = total_batch_size;
       }
@@ -1994,8 +2008,8 @@ ModelInstanceState::SetInputTensors(
     // The input must be in contiguous CPU/GPU memory.
     std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>> alloc_perference;
     if (device_.is_cpu()) {
-      alloc_perference = {{TRITONSERVER_MEMORY_CPU_PINNED, 0},
-                          {TRITONSERVER_MEMORY_CPU, 0}};
+      alloc_perference = {
+          {TRITONSERVER_MEMORY_CPU_PINNED, 0}, {TRITONSERVER_MEMORY_CPU, 0}};
     } else {
       alloc_perference = {{TRITONSERVER_MEMORY_GPU, device_.index()}};
     }
@@ -2073,7 +2087,7 @@ ModelInstanceState::ReadOutputTensors(
   bool cuda_copy = false;
   // The serialized string buffer must be valid until output copies are done
   std::vector<std::unique_ptr<std::string>> string_buffer;
-  for (auto & output : model_state_->ModelOutputs()) {
+  for (auto& output : model_state_->ModelOutputs()) {
     int op_index = output_index_map_[output.first];
     auto name = output.first;
     auto output_tensor_pair = output.second;
@@ -2110,9 +2124,11 @@ ModelInstanceState::ReadOutputTensors(

     // Output tensors may not reside on the same device as model
     torch::Device tensor_device = output_flat.device();
-    const auto memory_type = (tensor_device.type() == torch::kCPU) ? TRITONSERVER_MEMORY_CPU
+    const auto memory_type = (tensor_device.type() == torch::kCPU)
+                                 ? TRITONSERVER_MEMORY_CPU
                                  : TRITONSERVER_MEMORY_GPU;
-    const auto memory_id = (tensor_device.type() == torch::kCPU) ? 0 : tensor_device.index();
+    const auto memory_id =
+        (tensor_device.type() == torch::kCPU) ? 0 : tensor_device.index();

     // Batch output doesn't support string data type yet, as it is not trivial
     // to parse string output
@@ -2134,14 +2150,14 @@ ModelInstanceState::ReadOutputTensors(
     }
     if (output_tensor_pair.first != -1) {
       responder.ProcessTensor(
-          name, output_dtype, batchn_shape, output_buffer,
-          memory_type, memory_id);
+          name, output_dtype, batchn_shape, output_buffer, memory_type,
+          memory_id);
     }
     if (output_tensor_pair.second != -1) {
       std::vector<TRITONBACKEND_State*> states;
       states = responder.ProcessStateTensor(
-          name, output_dtype, batchn_shape, output_buffer,
-          memory_type, memory_id);
+          name, output_dtype, batchn_shape, output_buffer, memory_type,
+          memory_id);
       // Update the states
       for (auto& state : states) {
         RETURN_IF_ERROR(TRITONBACKEND_StateUpdate(state));
@@ -2192,9 +2208,11 @@ ModelInstanceState::ReadOutputTensors(
     }
     if (output_tensor_pair.second != -1) {
       TRITONBACKEND_State* response_state;
-      RESPOND_AND_SET_NULL_IF_ERROR(&response, TRITONBACKEND_StateNew(
-          &response_state, request, name.c_str(), TRITONSERVER_TYPE_BYTES,
-          batchn_shape.data(), batchn_shape.size()));
+      RESPOND_AND_SET_NULL_IF_ERROR(
+          &response, TRITONBACKEND_StateNew(
+                         &response_state, request, name.c_str(),
+                         TRITONSERVER_TYPE_BYTES, batchn_shape.data(),
+                         batchn_shape.size()));

       string_buffer.emplace_back(new std::string());
       cuda_copy |= SetStringStateBuffer(
0 commit comments