
Commit 0b291e4

Auto-format, create batch input on current device
1 parent 5506203 commit 0b291e4
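
The functional part of this change: gathered batch inputs used to be forced into CPU memory; they now use the same allocation preference as the regular inputs, i.e. memory on the device this model instance runs on. A minimal sketch of that preference selection, using only the TRITONSERVER_MemoryType values and the torch::Device member (device_) that appear in the diff below:

    // Prefer pinned, then pageable, CPU memory for CPU instances; otherwise
    // request memory on the instance's GPU (identified by device_.index()).
    std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>> alloc_perference;
    if (device_.is_cpu()) {
      alloc_perference = {{TRITONSERVER_MEMORY_CPU_PINNED, 0},
                          {TRITONSERVER_MEMORY_CPU, 0}};
    } else {
      alloc_perference = {{TRITONSERVER_MEMORY_GPU, device_.index()}};
    }

Computing this list once, ahead of the per-input loop, also avoids rebuilding it for every input; the previous code constructed it inside the loop body.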


src/libtorch.cc

Lines changed: 40 additions & 32 deletions
@@ -504,9 +504,8 @@ class ModelInstanceState : public BackendModelInstance {
       const std::string& control_kind, bool required, bool* have_control);
   TRITONSERVER_Error* ValidateInputs(const size_t expected_input_cnt);
   void AddInputToMap(
-      NamingConvention naming_convention,
-      const std::vector<std::string> allowed_inputs,
-      const std::string &io_name,
+      NamingConvention naming_convention,
+      const std::vector<std::string> allowed_inputs, const std::string& io_name,
       const uint32_t index);
   TRITONSERVER_Error* ValidateOutputs();
   void Execute(
@@ -770,7 +769,12 @@ ModelInstanceState::ValidateTypedSequenceControl(
   return nullptr;  // success
 }

-void ModelInstanceState::AddInputToMap(NamingConvention naming_convention, const std::vector<std::string> allowed_inputs, const std::string &io_name, const uint32_t index) {
+void
+ModelInstanceState::AddInputToMap(
+    NamingConvention naming_convention,
+    const std::vector<std::string> allowed_inputs, const std::string& io_name,
+    const uint32_t index)
+{
   std::string deliminator = "__";

   if (is_dict_input_) {
@@ -923,11 +927,13 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
   }

   triton::common::TritonJson::Value batch_inputs;
-  RETURN_IF_ERROR(model_state_->ModelConfig().MemberAsArray("batch_input", &batch_inputs));
+  RETURN_IF_ERROR(
+      model_state_->ModelConfig().MemberAsArray("batch_input", &batch_inputs));
   size_t i = 0;
   for (const auto& batch_input : StateForModel()->BatchInputs()) {
     for (const auto& input_name : batch_input.TargetNames()) {
-      AddInputToMap(naming_convention, allowed_inputs, input_name, i + ios.ArraySize());
+      AddInputToMap(
+          naming_convention, allowed_inputs, input_name, i + ios.ArraySize());
       i++;
     }
   }
@@ -1338,12 +1344,12 @@ ModelInstanceState::Execute(
      torch::jit::overrideCanFuseOnCPU(false);
      torch::jit::overrideCanFuseOnGPU(false);
      torch::jit::setTensorExprFuserEnabled(false);
-      torch::jit::fuser::cuda::setEnabled(true);
+      torch::jit::fuser::cuda::setEnabled(true);
    } else {
      torch::jit::overrideCanFuseOnCPU(true);
      torch::jit::overrideCanFuseOnGPU(true);
      torch::jit::setTensorExprFuserEnabled(true);
-      torch::jit::fuser::cuda::setEnabled(false);
+      torch::jit::fuser::cuda::setEnabled(false);
    }
  }

@@ -1753,6 +1759,16 @@ ModelInstanceState::SetInputTensors(
   RETURN_IF_ERROR(TRITONBACKEND_RequestInputCount(requests[0], &input_count));

   input_tensors->resize(input_count + batch_input_count_);
+
+  // The inputs must be in contiguous CPU/GPU memory.
+  std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>> alloc_perference;
+  if (device_.is_cpu()) {
+    alloc_perference = {{TRITONSERVER_MEMORY_CPU_PINNED, 0},
+                        {TRITONSERVER_MEMORY_CPU, 0}};
+  } else {
+    alloc_perference = {{TRITONSERVER_MEMORY_GPU, device_.index()}};
+  }
+
   for (uint32_t input_idx = 0; input_idx < input_count; input_idx++) {
     TRITONBACKEND_Input* input;
     RETURN_IF_ERROR(
@@ -1788,23 +1804,14 @@ ModelInstanceState::SetInputTensors(

        batchn_shape[0] += GetElementCount(input_shape, input_dims_count);
      }
-    }
-    else {
-      batchn_shape = std::vector<int64_t>(input_shape, input_shape + input_dims_count);
+    } else {
+      batchn_shape =
+          std::vector<int64_t>(input_shape, input_shape + input_dims_count);
      if (supports_batching_) {
        batchn_shape[0] = total_batch_size;
      }
    }

-    // The input must be in contiguous CPU/GPU memory.
-    std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>> alloc_perference;
-    if (device_.is_cpu()) {
-      alloc_perference = {{TRITONSERVER_MEMORY_CPU_PINNED, 0},
-                          {TRITONSERVER_MEMORY_CPU, 0}};
-    } else {
-      alloc_perference = {{TRITONSERVER_MEMORY_GPU, device_.index()}};
-    }
-
    const char* input_buffer;
    size_t batchn_byte_size;
    TRITONSERVER_MemoryType memory_type;
@@ -1867,15 +1874,14 @@ ModelInstanceState::SetInputTensors(
      TRITONSERVER_MemoryType dst_memory_type;
      int64_t dst_memory_type_id;

-      // Batch inputs are always created on CPU
      RESPOND_ALL_AND_SET_NULL_IF_ERROR(
          (*responses), responses->size(),
          collector->ProcessBatchInput(
-              batch_input, nullptr, 0, {{TRITONSERVER_MEMORY_CPU, 0}},
-              &dst_buffer, &dst_buffer_byte_size, &dst_memory_type,
-              &dst_memory_type_id));
+              batch_input, nullptr, 0, alloc_perference, &dst_buffer,
+              &dst_buffer_byte_size, &dst_memory_type, &dst_memory_type_id));

-      const auto torch_dtype = ConvertDataTypeToTorchType(batch_input.DataType());
+      const auto torch_dtype =
+          ConvertDataTypeToTorchType(batch_input.DataType());
      torch::TensorOptions options{torch_dtype.second};
      auto updated_options = options.device(torch::kCPU);

@@ -1944,9 +1950,11 @@ ModelInstanceState::ReadOutputTensors(

      // Output tensors may not reside on the same device as model
      torch::Device tensor_device = output_flat.device();
-      const auto memory_type = (tensor_device.type() == torch::kCPU) ? TRITONSERVER_MEMORY_CPU
-                                                                     : TRITONSERVER_MEMORY_GPU;
-      const auto memory_id = (tensor_device.type() == torch::kCPU) ? 0 : tensor_device.index();
+      const auto memory_type = (tensor_device.type() == torch::kCPU)
+                                   ? TRITONSERVER_MEMORY_CPU
+                                   : TRITONSERVER_MEMORY_GPU;
+      const auto memory_id =
+          (tensor_device.type() == torch::kCPU) ? 0 : tensor_device.index();

      // Batch output doesn't support string data type yet, as it is not trivial
      // to parse string output
@@ -1963,16 +1971,16 @@ ModelInstanceState::ReadOutputTensors(
        return TRITONSERVER_ErrorNew(
            TRITONSERVER_ERROR_INVALID_ARG,
            (std::string("output '") + name +
-             "' is a scalar which is not supported.")
+             "' is a scalar which is not supported.")
                .c_str());
      }

      responder.ProcessTensor(
-          name, output_dtype, batchn_shape, output_buffer,
-          memory_type, memory_id);
+          name, output_dtype, batchn_shape, output_buffer, memory_type,
+          memory_id);
    } else {
      responder.ProcessBatchOutput(
-          name, *batch_output, output_buffer, memory_type, memory_id);
+          name, *batch_output, output_buffer, memory_type, memory_id);
    }
  } else if (output_tensors[op_index].isList()) {
    // Custom handling for string/bytes tensor...
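
For reference, a sketch of how the preference list is consumed at the batch-input gather shown in the hunks above (RESPOND_ALL_AND_SET_NULL_IF_ERROR and the collector come from the Triton backend utilities; the dst_* variables are declared a few lines earlier in the same hunk):

    // The collector assembles the batch input into dst_buffer, choosing one of
    // the memory type/id pairs from alloc_perference and reporting the choice
    // through dst_memory_type / dst_memory_type_id.
    RESPOND_ALL_AND_SET_NULL_IF_ERROR(
        (*responses), responses->size(),
        collector->ProcessBatchInput(
            batch_input, nullptr, 0, alloc_perference, &dst_buffer,
            &dst_buffer_byte_size, &dst_memory_type, &dst_memory_type_id));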
