@@ -110,6 +110,10 @@ class ModelState : public BackendModel {
110
110
{
111
111
return model_outputs_;
112
112
}
113
+ const std::map<std::string, std::pair<int64_t , int64_t >>& ModelOutputs ()
114
+ {
115
+ return model_outputs_;
116
+ }
113
117
114
118
private:
115
119
ModelState (TRITONBACKEND_Model* triton_model);
@@ -544,6 +548,11 @@ class ModelInstanceState : public BackendModelInstance {
544
548
TRITONSERVER_Error* ValidateTypedSequenceControl (
545
549
triton::common::TritonJson::Value& sequence_batching,
546
550
const std::string& control_kind, bool required, bool * have_control);
551
+ void AddInputToMap (
552
+ NamingConvention naming_convention,
553
+ const std::vector<std::string> allowed_inputs,
554
+ const std::string &io_name,
555
+ const uint32_t index);
547
556
TRITONSERVER_Error* ValidateInputs (const size_t expected_input_cnt);
548
557
void AddInputToMap (
549
558
NamingConvention naming_convention,
@@ -873,6 +882,8 @@ ModelInstanceState::ValidateTypedSequenceControl(
873
882
874
883
return nullptr ; // success
875
884
}
885
+ void ModelInstanceState::AddInputToMap (NamingConvention naming_convention, const std::vector<std::string> allowed_inputs, const std::string &io_name, const uint32_t index) {
886
+ std::string deliminator = " __" ;
876
887
877
888
void
878
889
ModelInstanceState::AddInputToMap (
@@ -1201,7 +1212,7 @@ ModelInstanceState::ValidateOutputs()
1201
1212
TRITONSERVER_ERROR_INTERNAL,
1202
1213
(" Triton only supports 1 dimensional List of String as output "
1203
1214
" for "
1204
- " '" +
1215
+ " '" +
1205
1216
std::string (state_name) + " ' for model '" +
1206
1217
model_state_->Name () + " '" )
1207
1218
.c_str ());
@@ -1768,7 +1779,7 @@ ModelInstanceState::GetNamingConvention(
1768
1779
(" PyTorch model '" + model_state_->Name () +
1769
1780
" ' is using sequence batching with state but state '" +
1770
1781
state_name +
1771
- " ' does not follow the <name>__<index> naming convention. " )
1782
+ " ' does not follow the <name>__<index> naming convention. " )
1772
1783
.c_str ());
1773
1784
} else {
1774
1785
// check if the index part of the name is not an integer
@@ -2316,17 +2327,17 @@ ModelInstanceState::ReadOutputTensors(
2316
2327
responder.ProcessTensor (
2317
2328
name, output_dtype, batchn_shape, output_buffer, memory_type,
2318
2329
memory_id);
2319
- }
2320
- if (output_tensor_pair.second != -1 ) {
2321
- std::vector<TRITONBACKEND_State*> states;
2322
- states = responder.ProcessStateTensor (
2330
+ }
2331
+ if (output_tensor_pair.second != -1 ) {
2332
+ std::vector<TRITONBACKEND_State*> states;
2333
+ states = responder.ProcessStateTensor (
2323
2334
name, output_dtype, batchn_shape, output_buffer, memory_type,
2324
2335
memory_id);
2325
- // Update the states
2326
- for (auto & state : states) {
2327
- RETURN_IF_ERROR (TRITONBACKEND_StateUpdate (state));
2328
- }
2336
+ // Update the states
2337
+ for (auto & state : states) {
2338
+ RETURN_IF_ERROR (TRITONBACKEND_StateUpdate (state));
2329
2339
}
2340
+ }
2330
2341
2331
2342
} else {
2332
2343
responder.ProcessBatchOutput (
@@ -2377,6 +2388,11 @@ ModelInstanceState::ReadOutputTensors(
2377
2388
&response_state, request, name.c_str (),
2378
2389
TRITONSERVER_TYPE_BYTES, batchn_shape.data (),
2379
2390
batchn_shape.size ()));
2391
+ RESPOND_AND_SET_NULL_IF_ERROR (
2392
+ &response, TRITONBACKEND_StateNew (
2393
+ &response_state, request, name.c_str (),
2394
+ TRITONSERVER_TYPE_BYTES, batchn_shape.data (),
2395
+ batchn_shape.size ()));
2380
2396
2381
2397
string_buffer.emplace_back (new std::string ());
2382
2398
cuda_copy |= SetStringStateBuffer (
0 commit comments