
Commit f4533fb

formatting

1 parent d40c3ab

1 file changed: 63 additions, 45 deletions

src/libtorch.cc

@@ -104,7 +104,10 @@ class ModelState : public BackendModel {
   bool EnabledCacheCleaning() { return enable_cache_cleaning_; }
 
   bool EnabledWeightSharing() { return enable_weight_sharing_; }
-  const std::map<std::string, std::pair<int64_t, int64_t>>& ModelOutputs() { return model_outputs_; }
+  const std::map<std::string, std::pair<int64_t, int64_t>>& ModelOutputs()
+  {
+    return model_outputs_;
+  }
 
  private:
   ModelState(TRITONBACKEND_Model* triton_model);
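Note: the ModelOutputs() accessor reformatted above maps each output name to a pair of indices. Judging from the `!= -1` checks in the ReadOutputTensors hunks further down, pair.first appears to be the response-output slot and pair.second the state slot, with -1 marking "not used". A minimal, self-contained sketch of that consumption pattern (ConsumeOutputs is a hypothetical name, not in src/libtorch.cc):

    #include <cstdint>
    #include <map>
    #include <string>
    #include <utility>

    // Hypothetical consumer of ModelOutputs(); assumes -1 marks an unused
    // slot, matching the `!= -1` checks in the ReadOutputTensors hunks below.
    void
    ConsumeOutputs(
        const std::map<std::string, std::pair<int64_t, int64_t>>& model_outputs)
    {
      for (const auto& output : model_outputs) {
        if (output.second.first != -1) {
          // ... copy the tensor into a TRITONBACKEND_Output named output.first
        }
        if (output.second.second != -1) {
          // ... copy the tensor into a TRITONBACKEND_State named output.first
        }
      }
    }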
@@ -533,8 +536,7 @@ class ModelInstanceState : public BackendModelInstance {
       const std::string& control_kind, bool required, bool* have_control);
   void AddInputToMap(
       NamingConvention naming_convention,
-      const std::vector<std::string> allowed_inputs,
-      const std::string &io_name,
+      const std::vector<std::string> allowed_inputs, const std::string& io_name,
       const uint32_t index);
   TRITONSERVER_Error* ValidateInputs(const size_t expected_input_cnt);
   TRITONSERVER_Error* ValidateOutputs();
@@ -796,7 +798,12 @@ ModelInstanceState::ValidateTypedSequenceControl(
 
   return nullptr;  // success
 }
-void ModelInstanceState::AddInputToMap(NamingConvention naming_convention, const std::vector<std::string> allowed_inputs, const std::string &io_name, const uint32_t index) {
+void
+ModelInstanceState::AddInputToMap(
+    NamingConvention naming_convention,
+    const std::vector<std::string> allowed_inputs, const std::string& io_name,
+    const uint32_t index)
+{
   std::string deliminator = "__";
 
   if (is_dict_input_) {
@@ -950,7 +957,7 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
   }
   triton::common::TritonJson::Value sequence_batching;
   if (model_state_->ModelConfig().Find(
-          "sequence_batching", &sequence_batching)){
+          "sequence_batching", &sequence_batching)) {
     triton::common::TritonJson::Value states;
     if (sequence_batching.Find("state", &states)) {
       for (size_t i = 0; i < states.ArraySize(); i++) {
@@ -967,8 +974,8 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
         if (!pr.first && (state_dtype != "TYPE_STRING")) {
           return TRITONSERVER_ErrorNew(
               TRITONSERVER_ERROR_INTERNAL,
-              ("unsupported datatype " + state_dtype + " for input state '" + state_name +
-               "' for model '" + model_state_->Name() + "'")
+              ("unsupported datatype " + state_dtype + " for input state '" +
+               state_name + "' for model '" + model_state_->Name() + "'")
                   .c_str());
         }
 
@@ -979,10 +986,11 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
        if ((dims.size() + (supports_batching_ ? 1 : 0)) > 1) {
          return TRITONSERVER_ErrorNew(
              TRITONSERVER_ERROR_INTERNAL,
-              ("Triton only supports 1 dimensional List of String as input for "
+              ("Triton only supports 1 dimensional List of String as input "
+               "for "
               "'" +
-               std::string(state_name) + "' for model '" + model_state_->Name() +
-               "'")
+               std::string(state_name) + "' for model '" +
+               model_state_->Name() + "'")
                  .c_str());
        }
      }
@@ -1094,8 +1102,8 @@ ModelInstanceState::ValidateOutputs()
       if (!pr.first && (state_dtype != "TYPE_STRING")) {
         return TRITONSERVER_ErrorNew(
             TRITONSERVER_ERROR_INTERNAL,
-            ("unsupported datatype " + state_dtype + " for state '" + state_name +
-             "' for model '" + model_state_->Name() + "'")
+            ("unsupported datatype " + state_dtype + " for state '" +
+             state_name + "' for model '" + model_state_->Name() + "'")
                 .c_str());
       }
 
@@ -1104,10 +1112,11 @@ ModelInstanceState::ValidateOutputs()
       if ((dims.size() + (supports_batching_ ? 1 : 0)) > 1) {
         return TRITONSERVER_ErrorNew(
             TRITONSERVER_ERROR_INTERNAL,
-            ("Triton only supports 1 dimensional List of String as output for "
+            ("Triton only supports 1 dimensional List of String as output "
+             "for "
             "'" +
-             std::string(state_name) + "' for model '" + model_state_->Name() +
-             "'")
+             std::string(state_name) + "' for model '" +
+             model_state_->Name() + "'")
                 .c_str());
       }
     }
@@ -1610,7 +1619,8 @@ ModelInstanceState::GetNamingConvention(
   }
 
   triton::common::TritonJson::Value sequence_batching;
-  if (model_state_->ModelConfig().Find("sequence_batching", &sequence_batching)) {
+  if (model_state_->ModelConfig().Find(
+          "sequence_batching", &sequence_batching)) {
     // If we need to manage state for the model, then we need to check
     // the naming of the state adheres to both the input and output conventions
     triton::common::TritonJson::Value states;
@@ -1628,16 +1638,17 @@ ModelInstanceState::GetNamingConvention(
     for (size_t i = 0; i < states.ArraySize(); i++) {
       triton::common::TritonJson::Value state;
       RETURN_IF_ERROR(states.IndexAsObject(i, &state));
-      std::string name_entry = io_kind == "input" ? "input_name" : "output_name";
+      std::string name_entry =
+          io_kind == "input" ? "input_name" : "output_name";
       std::string state_name;
-      RETURN_IF_ERROR(
-          state.MemberAsString(name_entry.c_str(), &state_name));
+      RETURN_IF_ERROR(state.MemberAsString(name_entry.c_str(), &state_name));
       int start_pos = state_name.find(deliminator);
       if (start_pos == -1) {
         return TRITONSERVER_ErrorNew(
             TRITONSERVER_ERROR_INVALID_ARG,
             ("PyTorch model '" + model_state_->Name() +
-             "' is using sequence batching with state but state '" + state_name +
+             "' is using sequence batching with state but state '" +
+             state_name +
              "' does not follow the <name>__<index> naming convention. ")
                 .c_str());
       } else {
@@ -1653,7 +1664,8 @@ ModelInstanceState::GetNamingConvention(
         return TRITONSERVER_ErrorNew(
             TRITONSERVER_ERROR_INVALID_ARG,
             ("PyTorch model '" + model_state_->Name() +
-             "' is using sequence batching with state but state '" + state_name +
+             "' is using sequence batching with state but state '" +
+             state_name +
              "' does not follow the <name>__<index> naming convention. ")
                 .c_str());
       }
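Note: the two error paths above both enforce the <name>__<index> convention for state tensors, splitting on the "__" deliminator. A standalone sketch of the split they expect (the state name "INPUT_STATE__0" is an invented example, not from the diff):

    #include <iostream>
    #include <string>

    int
    main()
    {
      const std::string deliminator = "__";             // same separator as the backend
      const std::string state_name = "INPUT_STATE__0";  // invented example
      const size_t start_pos = state_name.find(deliminator);
      if (start_pos == std::string::npos) {
        std::cerr << "state '" << state_name
                  << "' does not follow the <name>__<index> naming convention\n";
        return 1;
      }
      const std::string name = state_name.substr(0, start_pos);
      const int index =
          std::stoi(state_name.substr(start_pos + deliminator.size()));
      std::cout << name << " has index " << index << "\n";  // INPUT_STATE has index 0
      return 0;
    }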
@@ -1844,8 +1856,9 @@ SetStringInputTensor(
 bool
 SetStringBuffer(
     torch::List<torch::jit::IValue>* tensor, TRITONBACKEND_Response** response,
-    TRITONBACKEND_Output* response_output, TRITONBACKEND_State* response_state, const size_t tensor_element_count,
-    cudaStream_t stream, std::string* serialized, bool state)
+    TRITONBACKEND_Output* response_output, TRITONBACKEND_State* response_state,
+    const size_t tensor_element_count, cudaStream_t stream,
+    std::string* serialized, bool state)
 {
   bool cuda_copy = false;
 
@@ -1870,7 +1883,7 @@ SetStringBuffer(
   TRITONSERVER_Error* err;
   void* buffer;
 
-  if (!state){
+  if (!state) {
     auto err = TRITONBACKEND_OutputBuffer(
         response_output, &buffer, serialized->size(), &actual_memory_type,
         &actual_memory_type_id);
@@ -1916,19 +1929,20 @@ SetStringOutputBuffer(
     TRITONBACKEND_Output* response_output, const size_t tensor_element_count,
     cudaStream_t stream, std::string* serialized)
 {
-  return SetStringBuffer(tensor, response, response_output, nullptr /* response_state */, tensor_element_count,
-                stream, serialized, false /* state */);
-
+  return SetStringBuffer(
+      tensor, response, response_output, nullptr /* response_state */,
+      tensor_element_count, stream, serialized, false /* state */);
 }
 
 bool
 SetStringStateBuffer(
-    torch::List<torch::jit::IValue>* tensor, TRITONBACKEND_Response** response,
+    torch::List<torch::jit::IValue>* tensor, TRITONBACKEND_Response** response,
     TRITONBACKEND_State* response_state, const size_t tensor_element_count,
     cudaStream_t stream, std::string* serialized)
 {
-  return SetStringBuffer(tensor, response, nullptr /* response_output */, response_state, tensor_element_count,
-               stream, serialized, true /* state */);
+  return SetStringBuffer(
+      tensor, response, nullptr /* response_output */, response_state,
+      tensor_element_count, stream, serialized, true /* state */);
 }
 
 
@@ -1983,9 +1997,9 @@ ModelInstanceState::SetInputTensors(
 
       batchn_shape[0] += GetElementCount(input_shape, input_dims_count);
     }
-  }
-  else {
-    batchn_shape = std::vector<int64_t>(input_shape, input_shape + input_dims_count);
+  } else {
+    batchn_shape =
+        std::vector<int64_t>(input_shape, input_shape + input_dims_count);
     if (supports_batching_) {
       batchn_shape[0] = total_batch_size;
     }
@@ -1994,8 +2008,8 @@ ModelInstanceState::SetInputTensors(
     // The input must be in contiguous CPU/GPU memory.
     std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>> alloc_perference;
     if (device_.is_cpu()) {
-      alloc_perference = {{TRITONSERVER_MEMORY_CPU_PINNED, 0},
-                  {TRITONSERVER_MEMORY_CPU, 0}};
+      alloc_perference = {
+          {TRITONSERVER_MEMORY_CPU_PINNED, 0}, {TRITONSERVER_MEMORY_CPU, 0}};
     } else {
       alloc_perference = {{TRITONSERVER_MEMORY_GPU, device_.index()}};
     }
@@ -2073,7 +2087,7 @@ ModelInstanceState::ReadOutputTensors(
   bool cuda_copy = false;
   // The serialized string buffer must be valid until output copies are done
   std::vector<std::unique_ptr<std::string>> string_buffer;
-  for (auto &output : model_state_->ModelOutputs()) {
+  for (auto& output : model_state_->ModelOutputs()) {
     int op_index = output_index_map_[output.first];
     auto name = output.first;
     auto output_tensor_pair = output.second;
@@ -2110,9 +2124,11 @@ ModelInstanceState::ReadOutputTensors(
 
     // Output tensors may not reside on the same device as model
     torch::Device tensor_device = output_flat.device();
-    const auto memory_type = (tensor_device.type() == torch::kCPU) ? TRITONSERVER_MEMORY_CPU
+    const auto memory_type = (tensor_device.type() == torch::kCPU)
+                                 ? TRITONSERVER_MEMORY_CPU
                                  : TRITONSERVER_MEMORY_GPU;
-    const auto memory_id = (tensor_device.type() == torch::kCPU) ? 0 : tensor_device.index();
+    const auto memory_id =
+        (tensor_device.type() == torch::kCPU) ? 0 : tensor_device.index();
 
     // Batch output doesn't support string data type yet, as it is not trivial
     // to parse string output
@@ -2134,14 +2150,14 @@ ModelInstanceState::ReadOutputTensors(
     }
     if (output_tensor_pair.first != -1) {
       responder.ProcessTensor(
-          name, output_dtype, batchn_shape, output_buffer,
-          memory_type, memory_id);
+          name, output_dtype, batchn_shape, output_buffer, memory_type,
+          memory_id);
     }
     if (output_tensor_pair.second != -1) {
       std::vector<TRITONBACKEND_State*> states;
       states = responder.ProcessStateTensor(
-          name, output_dtype, batchn_shape, output_buffer,
-          memory_type, memory_id);
+          name, output_dtype, batchn_shape, output_buffer, memory_type,
+          memory_id);
       // Update the states
       for (auto& state : states) {
         RETURN_IF_ERROR(TRITONBACKEND_StateUpdate(state));
@@ -2192,9 +2208,11 @@ ModelInstanceState::ReadOutputTensors(
     }
     if (output_tensor_pair.second != -1) {
       TRITONBACKEND_State* response_state;
-      RESPOND_AND_SET_NULL_IF_ERROR(&response, TRITONBACKEND_StateNew(
-          &response_state, request, name.c_str(), TRITONSERVER_TYPE_BYTES,
-          batchn_shape.data(), batchn_shape.size()));
+      RESPOND_AND_SET_NULL_IF_ERROR(
+          &response, TRITONBACKEND_StateNew(
+                         &response_state, request, name.c_str(),
+                         TRITONSERVER_TYPE_BYTES, batchn_shape.data(),
+                         batchn_shape.size()));
 
       string_buffer.emplace_back(new std::string());
       cuda_copy |= SetStringStateBuffer(
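Note: taken together, the hunks apply one consistent definition style, which looks like the output of a clang-format pass: the return type of a function definition moves onto its own line, the opening brace of the body gets its own line, wrapped arguments are packed with a four-space continuation indent, and long string concatenations break before the operand rather than after it. A compilable before/after illustration (Example::Run is an invented function, not from src/libtorch.cc):

    #include <cstdint>
    #include <string>

    struct Example {
      void Run(const std::string& name, uint32_t index);
    };

    // Before: void Example::Run(const std::string& name, uint32_t index) { ... }
    // After, in the style this commit converges on:
    void
    Example::Run(
        const std::string& name, uint32_t index)
    {
      (void)name;  // placeholder body
      (void)index;
    }

    int
    main()
    {
      Example e;
      e.Run("OUTPUT__0", 0);
      return 0;
    }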
