Skip to content

Commit fe45fb0

Browse files
dyastremsky authored and mc-nv committed
Auto-format, create batch input on current device
1 parent bfca5d0 commit fe45fb0

File tree

1 file changed

+26
-20
lines changed

1 file changed

+26
-20
lines changed

src/libtorch.cc

Lines changed: 26 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -505,9 +505,8 @@ class ModelInstanceState : public BackendModelInstance {
505505
const std::string& control_kind, bool required, bool* have_control);
506506
TRITONSERVER_Error* ValidateInputs(const size_t expected_input_cnt);
507507
void AddInputToMap(
508-
NamingConvention naming_convention,
509-
const std::vector<std::string> allowed_inputs,
510-
const std::string &io_name,
508+
NamingConvention naming_convention,
509+
const std::vector<std::string> allowed_inputs, const std::string& io_name,
511510
const uint32_t index);
512511
TRITONSERVER_Error* ValidateOutputs();
513512
void Execute(
@@ -771,7 +770,12 @@ ModelInstanceState::ValidateTypedSequenceControl(
771770
return nullptr; // success
772771
}
773772

774-
void ModelInstanceState::AddInputToMap(NamingConvention naming_convention, const std::vector<std::string> allowed_inputs, const std::string &io_name, const uint32_t index) {
773+
void
774+
ModelInstanceState::AddInputToMap(
775+
NamingConvention naming_convention,
776+
const std::vector<std::string> allowed_inputs, const std::string& io_name,
777+
const uint32_t index)
778+
{
775779
std::string deliminator = "__";
776780

777781
if (is_dict_input_) {
@@ -924,11 +928,13 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
924928
}
925929

926930
triton::common::TritonJson::Value batch_inputs;
927-
RETURN_IF_ERROR(model_state_->ModelConfig().MemberAsArray("batch_input", &batch_inputs));
931+
RETURN_IF_ERROR(
932+
model_state_->ModelConfig().MemberAsArray("batch_input", &batch_inputs));
928933
size_t i = 0;
929934
for (const auto& batch_input : StateForModel()->BatchInputs()) {
930935
for (const auto& input_name : batch_input.TargetNames()) {
931-
AddInputToMap(naming_convention, allowed_inputs, input_name, i + ios.ArraySize());
936+
AddInputToMap(
937+
naming_convention, allowed_inputs, input_name, i + ios.ArraySize());
932938
i++;
933939
}
934940
}
@@ -1754,6 +1760,16 @@ ModelInstanceState::SetInputTensors(
17541760
RETURN_IF_ERROR(TRITONBACKEND_RequestInputCount(requests[0], &input_count));
17551761

17561762
input_tensors->resize(input_count + batch_input_count_);
1763+
1764+
// The inputs must be in contiguous CPU/GPU memory.
1765+
std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>> alloc_perference;
1766+
if (device_.is_cpu()) {
1767+
alloc_perference = {{TRITONSERVER_MEMORY_CPU_PINNED, 0},
1768+
{TRITONSERVER_MEMORY_CPU, 0}};
1769+
} else {
1770+
alloc_perference = {{TRITONSERVER_MEMORY_GPU, device_.index()}};
1771+
}
1772+
17571773
for (uint32_t input_idx = 0; input_idx < input_count; input_idx++) {
17581774
TRITONBACKEND_Input* input;
17591775
RETURN_IF_ERROR(
@@ -1797,15 +1813,6 @@ ModelInstanceState::SetInputTensors(
17971813
}
17981814
}
17991815

1800-
// The input must be in contiguous CPU/GPU memory.
1801-
std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>> alloc_perference;
1802-
if (device_.is_cpu()) {
1803-
alloc_perference = {{TRITONSERVER_MEMORY_CPU_PINNED, 0},
1804-
{TRITONSERVER_MEMORY_CPU, 0}};
1805-
} else {
1806-
alloc_perference = {{TRITONSERVER_MEMORY_GPU, device_.index()}};
1807-
}
1808-
18091816
const char* input_buffer;
18101817
size_t batchn_byte_size;
18111818
TRITONSERVER_MemoryType memory_type;
@@ -1868,15 +1875,14 @@ ModelInstanceState::SetInputTensors(
18681875
TRITONSERVER_MemoryType dst_memory_type;
18691876
int64_t dst_memory_type_id;
18701877

1871-
// Batch inputs are always created on CPU
18721878
RESPOND_ALL_AND_SET_NULL_IF_ERROR(
18731879
(*responses), responses->size(),
18741880
collector->ProcessBatchInput(
1875-
batch_input, nullptr, 0, {{TRITONSERVER_MEMORY_CPU, 0}},
1876-
&dst_buffer, &dst_buffer_byte_size, &dst_memory_type,
1877-
&dst_memory_type_id));
1881+
batch_input, nullptr, 0, alloc_perference, &dst_buffer,
1882+
&dst_buffer_byte_size, &dst_memory_type, &dst_memory_type_id));
18781883

1879-
const auto torch_dtype = ConvertDataTypeToTorchType(batch_input.DataType());
1884+
const auto torch_dtype =
1885+
ConvertDataTypeToTorchType(batch_input.DataType());
18801886
torch::TensorOptions options{torch_dtype.second};
18811887
auto updated_options = options.device(torch::kCPU);
18821888

0 commit comments

Comments (0)