Commit b5e445e: Review comments

1 parent 3e9dcc5 commit b5e445e

File tree: 2 files changed, +59 -41 lines changed


src/pb_stub.cc

Lines changed: 47 additions & 32 deletions
@@ -654,7 +654,7 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr)
   py::list py_request_list =
       LoadRequestsFromSharedMemory(request_batch_shm_ptr);
   std::unique_ptr<IPCMessage> execute_response;
-  // IPCMessage::Create(shm_pool_, false /* Inline response */);
+  // IPCMessage::Create(shm_pool_, false /* Inline response */);

   std::optional<AllocatedSharedMemory<char>> response_batch;
   bool has_exception = false;
@@ -675,8 +675,7 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr)
     {
       NVTX_RANGE(nvtx_, "PyExecute " + name_);

-      execute_return =
-          model_instance_.attr("execute")(py_request_list);
+      execute_return = model_instance_.attr("execute")(py_request_list);

       bool is_coroutine = py::module::import("asyncio")
                               .attr("iscoroutine")(execute_return)
@@ -688,10 +687,12 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr)
       } else {
         py::object coroutine_return =
             RunCoroutine(execute_return, false /* in_background */);
-        ProcessReturnedResponses(py_request_list, coroutine_return, response_batch);
+        ProcessReturnedResponses(
+            py_request_list, coroutine_return, response_batch);
       }
     } else {
-      ProcessReturnedResponses(py_request_list, execute_return, response_batch);
+      ProcessReturnedResponses(
+          py_request_list, execute_return, response_batch);
     }
   }
 }
@@ -712,11 +713,14 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr)
         error_string;
     LOG_ERROR << err_message.c_str();
     if (!response_batch) {
-      response_batch = shm_pool_->Construct<char>(sizeof(ResponseBatch) + sizeof(IPCMessageShm));
-    }
-    ResponseBatch* response_batch_shm_ptr = reinterpret_cast<ResponseBatch*>(response_batch.value().data_.get() + sizeof(IPCMessageShm));
+      response_batch = shm_pool_->Construct<char>(
+          sizeof(ResponseBatch) + sizeof(IPCMessageShm));
+    }
+    ResponseBatch* response_batch_shm_ptr = reinterpret_cast<ResponseBatch*>(
+        response_batch.value().data_.get() + sizeof(IPCMessageShm));

-    response_batch_shm_ptr = reinterpret_cast<ResponseBatch*>(response_batch.value().data_.get());
+    response_batch_shm_ptr =
+        reinterpret_cast<ResponseBatch*>(response_batch.value().data_.get());
     response_batch_shm_ptr->has_error = true;
     error_string_shm = PbString::Create(shm_pool_, err_message);
     response_batch_shm_ptr->error = error_string_shm->ShmHandle();
@@ -732,14 +736,19 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr)
   }

   if (!response_batch) {
-    response_batch = shm_pool_->Construct<char>(sizeof(ResponseBatch) + sizeof(IPCMessageShm));
-    ResponseBatch* response_batch_shm_ptr =reinterpret_cast<ResponseBatch*>(response_batch.value().data_.get() + sizeof(IPCMessageShm));
-    response_batch_shm_ptr->batch_size = 0;
-  }
-  ResponseBatch* response_batch_shm_ptr = reinterpret_cast<ResponseBatch*>(response_batch.value().data_.get() + sizeof(IPCMessageShm));
+    response_batch = shm_pool_->Construct<char>(
+        sizeof(ResponseBatch) + sizeof(IPCMessageShm));
+    ResponseBatch* response_batch_shm_ptr = reinterpret_cast<ResponseBatch*>(
+        response_batch.value().data_.get() + sizeof(IPCMessageShm));
+    response_batch_shm_ptr->batch_size = 0;
+  }
+  ResponseBatch* response_batch_shm_ptr = reinterpret_cast<ResponseBatch*>(
+      response_batch.value().data_.get() + sizeof(IPCMessageShm));
   response_batch_shm_ptr->has_error = false;
   response_batch_shm_ptr->is_error_set = false;
-  execute_response = IPCMessage::Create(reinterpret_cast<IPCMessageShm*>(response_batch.value().data_.get()), response_batch.value().handle_);
+  execute_response = IPCMessage::Create(
+      reinterpret_cast<IPCMessageShm*>(response_batch.value().data_.get()),
+      response_batch.value().handle_);
   execute_response->Args() = response_batch.value().handle_;
   execute_response->InlineResponse() = false;
   execute_response->Command() = PYTHONSTUB_ExecuteResponse;
@@ -761,7 +770,8 @@ Stub::ProcessResponse(InferResponse* response)

 void
 Stub::ProcessReturnedResponses(
-    py::list py_requests, py::object py_responses_obj, std::optional<AllocatedSharedMemory<char>>& response_batch)
+    py::list py_requests, py::object py_responses_obj,
+    std::optional<AllocatedSharedMemory<char>>& response_batch)
 {
   // Return if there is nothing to process.
   if (py::isinstance<py::none>(py_responses_obj)) {
@@ -812,29 +822,34 @@ Stub::ProcessReturnedResponses(

       std::shared_ptr<InferResponse> response =
           py_responses[i].cast<std::shared_ptr<InferResponse>>();
-      request->GetResponseSender()->UpdateStateAndCounters(response, TRITONSERVER_RESPONSE_COMPLETE_FINAL);
+      request->GetResponseSender()->UpdateStateAndCounters(
+          response, TRITONSERVER_RESPONSE_COMPLETE_FINAL);
     }
   }
-  response_batch = std::move(shm_pool_->Construct<char>(sizeof(IPCMessageShm) +
+  // Return all the created responses using response_batch. The reason
+  // that both of the paths are available is that sending the responses
+  // using response_batch is faster than using `response_sender`.
+  response_batch = std::move(shm_pool_->Construct<char>(
+      sizeof(IPCMessageShm) +
       requests_size * sizeof(bi::managed_external_buffer::handle_t) +
       sizeof(ResponseBatch)));
-  ResponseBatch* response_batch_shm_ptr =
-      reinterpret_cast<ResponseBatch*>(response_batch.value().data_.get() + sizeof(IPCMessageShm));
+  ResponseBatch* response_batch_shm_ptr = reinterpret_cast<ResponseBatch*>(
+      response_batch.value().data_.get() + sizeof(IPCMessageShm));

   bi::managed_external_buffer::handle_t* responses_shm_handle =
       reinterpret_cast<bi::managed_external_buffer::handle_t*>(
-          response_batch.value().data_.get() + sizeof(ResponseBatch) + sizeof(IPCMessageShm));
-
-  for (size_t i = 0; i < responses_size; i++) {
-    // Check the return type of execute function.
-    InferRequest* infer_request = py_requests[i].cast<InferRequest*>();
-    InferResponse* infer_response = py_responses[i].cast<InferResponse*>();
-    infer_response->PruneOutputTensors(
-        infer_request->RequestedOutputNames());
-    ProcessResponse(infer_response);
-    responses_shm_handle[i] = infer_response->ShmHandle();
-  }
-  response_batch_shm_ptr->batch_size = requests_size;
+          response_batch.value().data_.get() + sizeof(ResponseBatch) +
+          sizeof(IPCMessageShm));
+
+  for (size_t i = 0; i < responses_size; i++) {
+    // Check the return type of execute function.
+    InferRequest* infer_request = py_requests[i].cast<InferRequest*>();
+    InferResponse* infer_response = py_responses[i].cast<InferResponse*>();
+    infer_response->PruneOutputTensors(infer_request->RequestedOutputNames());
+    ProcessResponse(infer_response);
+    responses_shm_handle[i] = infer_response->ShmHandle();
+  }
+  response_batch_shm_ptr->batch_size = requests_size;
 }

 py::object
src/python_be.cc

Lines changed: 12 additions & 9 deletions
@@ -1023,7 +1023,7 @@ ModelInstanceState::SendMessageAndReceiveResponse(
     std::shared_ptr<std::vector<TRITONBACKEND_Response*>>& responses,
     TRITONBACKEND_Request** requests, const uint32_t request_count)
 {
-  SendMessageToStub(message);
+  SendMessageToStub(message);

   bi::managed_external_buffer::handle_t response_message;
   auto error = Stub()->ReceiveMessageFromStub(response_message);
@@ -1224,7 +1224,8 @@ ModelInstanceState::ResponseSendDecoupled(
   if (send_message_payload->flags == TRITONSERVER_RESPONSE_COMPLETE_FINAL) {
     std::unique_ptr<
         TRITONBACKEND_ResponseFactory, backend::ResponseFactoryDeleter>
-        lresponse_factory(reinterpret_cast<TRITONBACKEND_ResponseFactory*>(response_factory));
+        lresponse_factory(
+            reinterpret_cast<TRITONBACKEND_ResponseFactory*>(response_factory));
   }
 }

@@ -1280,12 +1281,15 @@ ModelInstanceState::ProcessRequests(
     Stub()->StubMessageQueue()->Push(ipc_message->ShmHandle());
     bi::managed_external_buffer::handle_t response_message;
     Stub()->ReceiveMessageFromStub(response_message);
-    response = IPCMessage::LoadFromSharedMemory(Stub()->ShmPool(), response_message);
+    response =
+        IPCMessage::LoadFromSharedMemory(Stub()->ShmPool(), response_message);
   }
-  char* ipc_message_shm = reinterpret_cast<char*>(response->GetAllocatedSharedMemory().data_.get());;
+  char* ipc_message_shm =
+      reinterpret_cast<char*>(response->GetAllocatedSharedMemory().data_.get());
+  ;
   ResponseBatch* response_batch_shm_ptr =
       reinterpret_cast<ResponseBatch*>(ipc_message_shm + sizeof(IPCMessageShm));
-
+
   uint64_t compute_end_ns = 0;
   SET_TIMESTAMP(compute_end_ns);
   reporter.SetComputeEndNs(compute_end_ns);
@@ -1304,10 +1308,10 @@ ModelInstanceState::ProcessRequests(
   }

   if (response_batch_shm_ptr->batch_size > 0) {
-    std::shared_ptr<std::vector<TRITONBACKEND_Response*>> responses(
-        new std::vector<TRITONBACKEND_Response*>());
+    std::shared_ptr<std::vector<TRITONBACKEND_Response*>> responses(
+        new std::vector<TRITONBACKEND_Response*>());
     responses->reserve(request_count);
-    for (size_t i = 0; i < request_count; i++) {
+    for (size_t i = 0; i < request_count; i++) {
       TRITONBACKEND_Response* response;
       auto err = TRITONBACKEND_ResponseNew(&response, requests[i]);
       if (err == nullptr) {
@@ -1324,7 +1328,6 @@ ModelInstanceState::ProcessRequests(

   // If the output provided by the model is in GPU, we will pass the list of
   // buffers provided by Triton to the stub process.
-  // bool has_gpu_output = false;
   std::vector<bool> requires_deferred_callback;

   bool has_gpu_output = false;