Commit bfca5d0

HennerM authored and mc-nv committed
Add support for batch_input
1 parent fd29c6e commit bfca5d0

File tree

1 file changed: +9 −20 lines

src/libtorch.cc

Lines changed: 9 additions & 20 deletions
@@ -25,7 +25,6 @@
 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 #include <stdint.h>
-
 #include <cstdint>
 #include <exception>

@@ -506,8 +505,9 @@ class ModelInstanceState : public BackendModelInstance {
       const std::string& control_kind, bool required, bool* have_control);
   TRITONSERVER_Error* ValidateInputs(const size_t expected_input_cnt);
   void AddInputToMap(
-      NamingConvention naming_convention,
-      const std::vector<std::string> allowed_inputs, const std::string& io_name,
+      NamingConvention naming_convention,
+      const std::vector<std::string> allowed_inputs,
+      const std::string &io_name,
       const uint32_t index);
   TRITONSERVER_Error* ValidateOutputs();
   void Execute(

@@ -771,12 +771,7 @@ ModelInstanceState::ValidateTypedSequenceControl(
   return nullptr;  // success
 }
 
-void
-ModelInstanceState::AddInputToMap(
-    NamingConvention naming_convention,
-    const std::vector<std::string> allowed_inputs, const std::string& io_name,
-    const uint32_t index)
-{
+void ModelInstanceState::AddInputToMap(NamingConvention naming_convention, const std::vector<std::string> allowed_inputs, const std::string &io_name, const uint32_t index) {
   std::string deliminator = "__";
 
   if (is_dict_input_) {
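
AddInputToMap records the position of each named tensor in the argument list the backend later assembles for the forward call. As a rough, stand-alone illustration (a toy sketch with made-up names, not the backend's own code), a name that follows the <name>__<index> convention can be resolved to its trailing index via the "__" delimiter, while other names fall back to a caller-supplied position:

// Toy sketch: resolve a "__"-delimited tensor name such as "INPUT__0"
// to an argument index; names without the delimiter use a fallback index.
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>

int main() {
  std::unordered_map<std::string, uint32_t> input_index_map;
  const std::string deliminator = "__";

  auto add_input = [&](const std::string& io_name, uint32_t fallback_index) {
    const size_t pos = io_name.rfind(deliminator);
    if (pos != std::string::npos) {
      // "<name>__<index>": take the numeric suffix as the index.
      input_index_map[io_name] = static_cast<uint32_t>(
          std::stoul(io_name.substr(pos + deliminator.size())));
    } else {
      // No delimiter: fall back to the supplied position.
      input_index_map[io_name] = fallback_index;
    }
  };

  add_input("INPUT__0", 0);
  add_input("INPUT__1", 1);
  add_input("extra_batch_input", 2);  // e.g. a batch-input target name

  for (const auto& kv : input_index_map) {
    std::cout << kv.first << " -> " << kv.second << "\n";
  }
  return 0;
}
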
@@ -929,13 +924,11 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
   }
 
   triton::common::TritonJson::Value batch_inputs;
-  RETURN_IF_ERROR(
-      model_state_->ModelConfig().MemberAsArray("batch_input", &batch_inputs));
+  RETURN_IF_ERROR(model_state_->ModelConfig().MemberAsArray("batch_input", &batch_inputs));
   size_t i = 0;
   for (const auto& batch_input : StateForModel()->BatchInputs()) {
     for (const auto& input_name : batch_input.TargetNames()) {
-      AddInputToMap(
-          naming_convention, allowed_inputs, input_name, i + ios.ArraySize());
+      AddInputToMap(naming_convention, allowed_inputs, input_name, i + ios.ArraySize());
       i++;
     }
   }
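
For reference, the "batch_input" array read above comes from the model's config.pbtxt; each entry names one or more target tensors that Triton generates per batch, and this hunk registers those target names as additional inputs after the model's regular ones. A sketch of what such an entry could look like (field names follow Triton's ModelBatchInput configuration message; the kind, tensor names, and data type below are illustrative placeholders):

batch_input [
  {
    kind: BATCH_ACCUMULATED_ELEMENT_COUNT
    target_name: "accumulated_element_count"
    data_type: TYPE_FP32
    source_input: "INPUT__0"
  }
]
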
@@ -1883,16 +1876,12 @@ ModelInstanceState::SetInputTensors(
           &dst_buffer, &dst_buffer_byte_size, &dst_memory_type,
           &dst_memory_type_id));
 
-      const auto torch_dtype =
-          ConvertDataTypeToTorchType(batch_input.DataType());
+      const auto torch_dtype = ConvertDataTypeToTorchType(batch_input.DataType());
       torch::TensorOptions options{torch_dtype.second};
-      auto updated_options = (dst_memory_type == TRITONSERVER_MEMORY_GPU)
-                                 ? options.device(torch::kCUDA, device_.index())
-                                 : options.device(torch::kCPU);
+      auto updated_options = options.device(torch::kCPU);
 
       torch::Tensor input_tensor = torch::from_blob(
-          const_cast<char*>(dst_buffer), shape,
-          updated_options.dtype(torch_dtype.second));
+          const_cast<char*>(dst_buffer), shape, updated_options);
       (*input_tensors)[input_index_map_[input_name]] = input_tensor;
     }
   }
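
The change above always hands torch::from_blob a CPU-pinned TensorOptions for the batch-input buffer. A minimal, stand-alone sketch of that pattern (assuming libtorch is available; the names and values here are made up and this is not the backend's code):

// from_blob wraps an existing buffer without copying, so the buffer must
// outlive the tensor; the options place the tensor on CPU.
#include <torch/torch.h>

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int32_t> accumulated = {0, 4, 6, 13};  // example batch-input values

  torch::TensorOptions options{torch::kInt32};        // dtype of the batch input
  auto updated_options = options.device(torch::kCPU); // keep the tensor on CPU

  torch::Tensor input_tensor = torch::from_blob(
      accumulated.data(), {static_cast<int64_t>(accumulated.size())},
      updated_options);

  std::cout << input_tensor << std::endl;
  return 0;
}
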
