Skip to content

Commit 70cb673

Browse files
jamied157Tabrizian
authored andcommitted
formatting
1 parent 647704b commit 70cb673

File tree

1 file changed

+26
-10
lines changed

1 file changed

+26
-10
lines changed

src/libtorch.cc

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,10 @@ class ModelState : public BackendModel {
110110
{
111111
return model_outputs_;
112112
}
113+
const std::map<std::string, std::pair<int64_t, int64_t>>& ModelOutputs()
114+
{
115+
return model_outputs_;
116+
}
113117

114118
private:
115119
ModelState(TRITONBACKEND_Model* triton_model);
@@ -544,6 +548,11 @@ class ModelInstanceState : public BackendModelInstance {
544548
TRITONSERVER_Error* ValidateTypedSequenceControl(
545549
triton::common::TritonJson::Value& sequence_batching,
546550
const std::string& control_kind, bool required, bool* have_control);
551+
void AddInputToMap(
552+
NamingConvention naming_convention,
553+
const std::vector<std::string> allowed_inputs,
554+
const std::string &io_name,
555+
const uint32_t index);
547556
TRITONSERVER_Error* ValidateInputs(const size_t expected_input_cnt);
548557
void AddInputToMap(
549558
NamingConvention naming_convention,
@@ -873,6 +882,8 @@ ModelInstanceState::ValidateTypedSequenceControl(
873882

874883
return nullptr; // success
875884
}
885+
void ModelInstanceState::AddInputToMap(NamingConvention naming_convention, const std::vector<std::string> allowed_inputs, const std::string &io_name, const uint32_t index) {
886+
std::string deliminator = "__";
876887

877888
void
878889
ModelInstanceState::AddInputToMap(
@@ -1201,7 +1212,7 @@ ModelInstanceState::ValidateOutputs()
12011212
TRITONSERVER_ERROR_INTERNAL,
12021213
("Triton only supports 1 dimensional List of String as output "
12031214
"for "
1204-
"'" +
1215+
"'" +
12051216
std::string(state_name) + "' for model '" +
12061217
model_state_->Name() + "'")
12071218
.c_str());
@@ -1768,7 +1779,7 @@ ModelInstanceState::GetNamingConvention(
17681779
("PyTorch model '" + model_state_->Name() +
17691780
"' is using sequence batching with state but state '" +
17701781
state_name +
1771-
"' does not follow the <name>__<index> naming convention. ")
1782+
"' does not follow the <name>__<index> naming convention. ")
17721783
.c_str());
17731784
} else {
17741785
// check if the index part of the name is not an integer
@@ -2316,17 +2327,17 @@ ModelInstanceState::ReadOutputTensors(
23162327
responder.ProcessTensor(
23172328
name, output_dtype, batchn_shape, output_buffer, memory_type,
23182329
memory_id);
2319-
}
2320-
if (output_tensor_pair.second != -1) {
2321-
std::vector<TRITONBACKEND_State*> states;
2322-
states = responder.ProcessStateTensor(
2330+
}
2331+
if (output_tensor_pair.second != -1) {
2332+
std::vector<TRITONBACKEND_State*> states;
2333+
states = responder.ProcessStateTensor(
23232334
name, output_dtype, batchn_shape, output_buffer, memory_type,
23242335
memory_id);
2325-
// Update the states
2326-
for (auto& state : states) {
2327-
RETURN_IF_ERROR(TRITONBACKEND_StateUpdate(state));
2328-
}
2336+
// Update the states
2337+
for (auto& state : states) {
2338+
RETURN_IF_ERROR(TRITONBACKEND_StateUpdate(state));
23292339
}
2340+
}
23302341

23312342
} else {
23322343
responder.ProcessBatchOutput(
@@ -2377,6 +2388,11 @@ ModelInstanceState::ReadOutputTensors(
23772388
&response_state, request, name.c_str(),
23782389
TRITONSERVER_TYPE_BYTES, batchn_shape.data(),
23792390
batchn_shape.size()));
2391+
RESPOND_AND_SET_NULL_IF_ERROR(
2392+
&response, TRITONBACKEND_StateNew(
2393+
&response_state, request, name.c_str(),
2394+
TRITONSERVER_TYPE_BYTES, batchn_shape.data(),
2395+
batchn_shape.size()));
23802396

23812397
string_buffer.emplace_back(new std::string());
23822398
cuda_copy |= SetStringStateBuffer(

0 commit comments

Comments
 (0)