Commit da3effb

formatting
1 parent 43c90ae commit da3effb


src/libtorch.cc

Lines changed: 56 additions & 40 deletions
@@ -109,7 +109,10 @@ class ModelState : public BackendModel {
   bool EnabledCacheCleaning() { return enable_cache_cleaning_; }
 
   bool EnabledWeightSharing() { return enable_weight_sharing_; }
-  const std::map<std::string, std::pair<int64_t, int64_t>>& ModelOutputs() { return model_outputs_; }
+  const std::map<std::string, std::pair<int64_t, int64_t>>& ModelOutputs()
+  {
+    return model_outputs_;
+  }
 
  private:
   ModelState(TRITONBACKEND_Model* triton_model);
@@ -538,8 +541,7 @@ class ModelInstanceState : public BackendModelInstance {
     const std::string& control_kind, bool required, bool* have_control);
   void AddInputToMap(
       NamingConvention naming_convention,
-      const std::vector<std::string> allowed_inputs,
-      const std::string &io_name,
+      const std::vector<std::string> allowed_inputs, const std::string& io_name,
       const uint32_t index);
   TRITONSERVER_Error* ValidateInputs(const size_t expected_input_cnt);
   void AddInputToMap(
@@ -812,7 +814,12 @@ ModelInstanceState::ValidateTypedSequenceControl(
 
   return nullptr; // success
 }
-void ModelInstanceState::AddInputToMap(NamingConvention naming_convention, const std::vector<std::string> allowed_inputs, const std::string &io_name, const uint32_t index) {
+void
+ModelInstanceState::AddInputToMap(
+    NamingConvention naming_convention,
+    const std::vector<std::string> allowed_inputs, const std::string& io_name,
+    const uint32_t index)
+{
   std::string deliminator = "__";
 
   if (is_dict_input_) {
@@ -1006,7 +1013,7 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
   }
   triton::common::TritonJson::Value sequence_batching;
   if (model_state_->ModelConfig().Find(
-          "sequence_batching", &sequence_batching)){
+          "sequence_batching", &sequence_batching)) {
     triton::common::TritonJson::Value states;
     if (sequence_batching.Find("state", &states)) {
       for (size_t i = 0; i < states.ArraySize(); i++) {
@@ -1023,8 +1030,8 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
         if (!pr.first && (state_dtype != "TYPE_STRING")) {
           return TRITONSERVER_ErrorNew(
               TRITONSERVER_ERROR_INTERNAL,
-              ("unsupported datatype " + state_dtype + " for input state '" + state_name +
-               "' for model '" + model_state_->Name() + "'")
+              ("unsupported datatype " + state_dtype + " for input state '" +
+               state_name + "' for model '" + model_state_->Name() + "'")
                   .c_str());
         }
 
@@ -1035,10 +1042,11 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
         if ((dims.size() + (supports_batching_ ? 1 : 0)) > 1) {
           return TRITONSERVER_ErrorNew(
               TRITONSERVER_ERROR_INTERNAL,
-              ("Triton only supports 1 dimensional List of String as input for "
+              ("Triton only supports 1 dimensional List of String as input "
+               "for "
                "'" +
-               std::string(state_name) + "' for model '" + model_state_->Name() +
-               "'")
+               std::string(state_name) + "' for model '" +
+               model_state_->Name() + "'")
                   .c_str());
         }
       }
@@ -1162,8 +1170,8 @@ ModelInstanceState::ValidateOutputs()
         if (!pr.first && (state_dtype != "TYPE_STRING")) {
           return TRITONSERVER_ErrorNew(
               TRITONSERVER_ERROR_INTERNAL,
-              ("unsupported datatype " + state_dtype + " for state '" + state_name +
-               "' for model '" + model_state_->Name() + "'")
+              ("unsupported datatype " + state_dtype + " for state '" +
+               state_name + "' for model '" + model_state_->Name() + "'")
                   .c_str());
         }
 
@@ -1172,10 +1180,11 @@ ModelInstanceState::ValidateOutputs()
         if ((dims.size() + (supports_batching_ ? 1 : 0)) > 1) {
           return TRITONSERVER_ErrorNew(
               TRITONSERVER_ERROR_INTERNAL,
-              ("Triton only supports 1 dimensional List of String as output for "
+              ("Triton only supports 1 dimensional List of String as output "
+               "for "
                "'" +
-               std::string(state_name) + "' for model '" + model_state_->Name() +
-               "'")
+               std::string(state_name) + "' for model '" +
+               model_state_->Name() + "'")
                   .c_str());
         }
       }
@@ -1678,7 +1687,8 @@ ModelInstanceState::GetNamingConvention(
   }
 
   triton::common::TritonJson::Value sequence_batching;
-  if (model_state_->ModelConfig().Find("sequence_batching", &sequence_batching)) {
+  if (model_state_->ModelConfig().Find(
+          "sequence_batching", &sequence_batching)) {
     // If we need to manage state for the model, then we need to check
     // the naming of the state adheres to both the input and output conventions
     triton::common::TritonJson::Value states;
@@ -1696,16 +1706,17 @@ ModelInstanceState::GetNamingConvention(
       for (size_t i = 0; i < states.ArraySize(); i++) {
         triton::common::TritonJson::Value state;
         RETURN_IF_ERROR(states.IndexAsObject(i, &state));
-        std::string name_entry = io_kind == "input" ? "input_name" : "output_name";
+        std::string name_entry =
+            io_kind == "input" ? "input_name" : "output_name";
         std::string state_name;
-        RETURN_IF_ERROR(
-            state.MemberAsString(name_entry.c_str(), &state_name));
+        RETURN_IF_ERROR(state.MemberAsString(name_entry.c_str(), &state_name));
         int start_pos = state_name.find(deliminator);
         if (start_pos == -1) {
           return TRITONSERVER_ErrorNew(
               TRITONSERVER_ERROR_INVALID_ARG,
               ("PyTorch model '" + model_state_->Name() +
-               "' is using sequence batching with state but state '" + state_name +
+               "' is using sequence batching with state but state '" +
+               state_name +
                "' does not follow the <name>__<index> naming convention. ")
                   .c_str());
         } else {
@@ -1721,7 +1732,8 @@ ModelInstanceState::GetNamingConvention(
           return TRITONSERVER_ErrorNew(
               TRITONSERVER_ERROR_INVALID_ARG,
               ("PyTorch model '" + model_state_->Name() +
-               "' is using sequence batching with state but state '" + state_name +
+               "' is using sequence batching with state but state '" +
+               state_name +
                "' does not follow the <name>__<index> naming convention. ")
                   .c_str());
         }
@@ -1912,8 +1924,9 @@ SetStringInputTensor(
 bool
 SetStringBuffer(
     torch::List<torch::jit::IValue>* tensor, TRITONBACKEND_Response** response,
-    TRITONBACKEND_Output* response_output, TRITONBACKEND_State* response_state, const size_t tensor_element_count,
-    cudaStream_t stream, std::string* serialized, bool state)
+    TRITONBACKEND_Output* response_output, TRITONBACKEND_State* response_state,
+    const size_t tensor_element_count, cudaStream_t stream,
+    std::string* serialized, bool state)
 {
   bool cuda_copy = false;
 
@@ -1938,7 +1951,7 @@ SetStringBuffer(
   TRITONSERVER_Error* err;
   void* buffer;
 
-  if (!state){
+  if (!state) {
     auto err = TRITONBACKEND_OutputBuffer(
         response_output, &buffer, serialized->size(), &actual_memory_type,
         &actual_memory_type_id);
@@ -1984,19 +1997,20 @@ SetStringOutputBuffer(
     TRITONBACKEND_Output* response_output, const size_t tensor_element_count,
     cudaStream_t stream, std::string* serialized)
 {
-  return SetStringBuffer(tensor, response, response_output, nullptr /* response_state */, tensor_element_count,
-      stream, serialized, false /* state */);
-
+  return SetStringBuffer(
+      tensor, response, response_output, nullptr /* response_state */,
+      tensor_element_count, stream, serialized, false /* state */);
 }
 
 bool
 SetStringStateBuffer(
-    torch::List<torch::jit::IValue>* tensor, TRITONBACKEND_Response** response, 
+    torch::List<torch::jit::IValue>* tensor, TRITONBACKEND_Response** response,
     TRITONBACKEND_State* response_state, const size_t tensor_element_count,
     cudaStream_t stream, std::string* serialized)
 {
-  return SetStringBuffer(tensor, response, nullptr /* response_output */, response_state, tensor_element_count,
-      stream, serialized, true /* state */);
+  return SetStringBuffer(
+      tensor, response, nullptr /* response_output */, response_state,
+      tensor_element_count, stream, serialized, true /* state */);
 }
 
 
@@ -2063,8 +2077,8 @@ ModelInstanceState::SetInputTensors(
     // The input must be in contiguous CPU/GPU memory.
     std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>> alloc_perference;
     if (device_.is_cpu()) {
-      alloc_perference = {{TRITONSERVER_MEMORY_CPU_PINNED, 0},
-                          {TRITONSERVER_MEMORY_CPU, 0}};
+      alloc_perference = {
+          {TRITONSERVER_MEMORY_CPU_PINNED, 0}, {TRITONSERVER_MEMORY_CPU, 0}};
     } else {
       alloc_perference = {{TRITONSERVER_MEMORY_GPU, device_.index()}};
     }
@@ -2176,7 +2190,7 @@ ModelInstanceState::ReadOutputTensors(
   bool cuda_copy = false;
   // The serialized string buffer must be valid until output copies are done
   std::vector<std::unique_ptr<std::string>> string_buffer;
-  for (auto &output : model_state_->ModelOutputs()) {
+  for (auto& output : model_state_->ModelOutputs()) {
     int op_index = output_index_map_[output.first];
     auto name = output.first;
     auto output_tensor_pair = output.second;
@@ -2239,14 +2253,14 @@ ModelInstanceState::ReadOutputTensors(
     }
     if (output_tensor_pair.first != -1) {
       responder.ProcessTensor(
-          name, output_dtype, batchn_shape, output_buffer,
-          memory_type, memory_id);
+          name, output_dtype, batchn_shape, output_buffer, memory_type,
+          memory_id);
     }
     if (output_tensor_pair.second != -1) {
       std::vector<TRITONBACKEND_State*> states;
       states = responder.ProcessStateTensor(
-          name, output_dtype, batchn_shape, output_buffer,
-          memory_type, memory_id);
+          name, output_dtype, batchn_shape, output_buffer, memory_type,
+          memory_id);
       // Update the states
       for (auto& state : states) {
         RETURN_IF_ERROR(TRITONBACKEND_StateUpdate(state));
@@ -2297,9 +2311,11 @@ ModelInstanceState::ReadOutputTensors(
     }
     if (output_tensor_pair.second != -1) {
       TRITONBACKEND_State* response_state;
-      RESPOND_AND_SET_NULL_IF_ERROR(&response, TRITONBACKEND_StateNew(
-          &response_state, request, name.c_str(), TRITONSERVER_TYPE_BYTES,
-          batchn_shape.data(), batchn_shape.size()));
+      RESPOND_AND_SET_NULL_IF_ERROR(
+          &response, TRITONBACKEND_StateNew(
+                         &response_state, request, name.c_str(),
+                         TRITONSERVER_TYPE_BYTES, batchn_shape.data(),
+                         batchn_shape.size()));
 
       string_buffer.emplace_back(new std::string());
       cuda_copy |= SetStringStateBuffer(

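Every hunk above is a pure reflow of existing statements, with no behavior change: lines are wrapped at the 80-column limit, function definitions get their return type on a separate line with the opening brace in the first column, and & / * bind to the type rather than the name. A minimal sketch of that convention on a hypothetical class (the names below are illustrative only, not taken from libtorch.cc, and this assumes the Google-derived clang-format style commonly used across Triton backends):

// Sketch of the formatting conventions this commit applies. All names
// here are hypothetical; only the layout mirrors the diff above.
#include <cstdint>
#include <string>
#include <vector>

class ExampleState {
 public:
  // Short accessors stay on one line, under the column limit.
  bool Enabled() const { return enabled_; }

  // Longer signatures wrap with a four-space continuation indent.
  void AddInput(
      const std::vector<std::string>& allowed_inputs,
      const std::string& io_name, uint32_t index);

 private:
  bool enabled_ = false;
  std::vector<std::string> inputs_;
};

// Out-of-line definitions put the return type on its own line and the
// opening brace in the first column.
void
ExampleState::AddInput(
    const std::vector<std::string>& allowed_inputs,
    const std::string& io_name, uint32_t index)
{
  // The reference binds to the type ("const auto& input"), as in the diff.
  for (const auto& input : allowed_inputs) {
    if (input == io_name) {
      inputs_.push_back(io_name + "__" + std::to_string(index));
    }
  }
}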