Skip to content

Commit cb5be4d

Browse files
kthui authored and rmccorm4 committed
fix: Models should filter outputs based on requested outputs (#366)
* Prune non requested outputs from non-decoupled models
* Prune non requested outputs from decoupled models
* [chore] Remove redundant copy
1 parent ebc8c6c commit cb5be4d

File tree

3 files changed

+17
-8
lines changed

3 files changed

+17
-8
lines changed

src/infer_request.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,14 +68,13 @@ InferRequest::InferRequest(
6868
}
6969
}
7070

71-
inputs_ = inputs;
72-
requested_output_names_ = requested_output_names;
7371
#ifdef TRITON_PB_STUB
7472
pb_cancel_ =
7573
std::make_shared<PbCancel>(response_factory_address_, request_address_);
7674
response_sender_ = std::make_shared<ResponseSender>(
7775
request_address_, response_factory_address_, nullptr /* is_decoupled */,
78-
Stub::GetOrCreateInstance()->SharedMemory(), pb_cancel_);
76+
RequestedOutputNames(), Stub::GetOrCreateInstance()->SharedMemory(),
77+
pb_cancel_);
7978
#endif
8079
}
8180

@@ -390,7 +389,8 @@ InferRequest::InferRequest(
390389
std::make_shared<PbCancel>(response_factory_address_, request_address_);
391390
response_sender_ = std::make_shared<ResponseSender>(
392391
request_address_, response_factory_address_, is_model_decoupled,
393-
Stub::GetOrCreateInstance()->SharedMemory(), pb_cancel_);
392+
RequestedOutputNames(), Stub::GetOrCreateInstance()->SharedMemory(),
393+
pb_cancel_);
394394
#endif
395395
}
396396

src/response_sender.cc

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,15 @@ CheckResponseSenderArguments(
5454

5555
ResponseSender::ResponseSender(
5656
intptr_t request_address, intptr_t response_factory_address,
57-
bool const* is_decoupled, std::unique_ptr<SharedMemoryManager>& shm_pool,
57+
bool const* is_decoupled,
58+
const std::set<std::string>& requested_output_names,
59+
std::unique_ptr<SharedMemoryManager>& shm_pool,
5860
const std::shared_ptr<PbCancel>& pb_cancel)
5961
: request_address_(request_address),
6062
response_factory_address_(response_factory_address),
61-
is_decoupled_(is_decoupled), shm_pool_(shm_pool), pb_cancel_(pb_cancel),
62-
closed_(false), number_of_response_sent_(0)
63+
is_decoupled_(is_decoupled),
64+
requested_output_names_(requested_output_names), shm_pool_(shm_pool),
65+
pb_cancel_(pb_cancel), closed_(false), number_of_response_sent_(0)
6366
{
6467
}
6568

@@ -123,6 +126,9 @@ ResponseSender::Send(
123126

124127
CheckResponseSenderArguments(infer_response, flags);
125128
UpdateStateAndCounters(infer_response, flags);
129+
if (infer_response) {
130+
infer_response->PruneOutputTensors(requested_output_names_);
131+
}
126132

127133
std::unique_ptr<Stub>& stub = Stub::GetOrCreateInstance();
128134

src/response_sender.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@ class ResponseSender {
3838
public:
3939
ResponseSender(
4040
intptr_t request_address, intptr_t response_factory_address,
41-
bool const* is_decoupled, std::unique_ptr<SharedMemoryManager>& shm_pool,
41+
bool const* is_decoupled,
42+
const std::set<std::string>& requested_output_names,
43+
std::unique_ptr<SharedMemoryManager>& shm_pool,
4244
const std::shared_ptr<PbCancel>& pb_cancel);
4345
~ResponseSender();
4446
void Send(std::shared_ptr<InferResponse> response, const uint32_t flags);
@@ -54,6 +56,7 @@ class ResponseSender {
5456
intptr_t request_address_;
5557
intptr_t response_factory_address_;
5658
bool const* is_decoupled_;
59+
std::set<std::string> requested_output_names_;
5760
std::unique_ptr<SharedMemoryManager>& shm_pool_;
5861
std::shared_ptr<PbCancel> pb_cancel_;
5962

0 commit comments

Comments
 (0)