Skip to content

Commit fe0f318

Browse files
dyastremskymc-nv
authored andcommitted
Use original device for created batch inputs
1 parent fe45fb0 commit fe0f318

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

src/libtorch.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1884,7 +1884,9 @@ ModelInstanceState::SetInputTensors(
18841884
const auto torch_dtype =
18851885
ConvertDataTypeToTorchType(batch_input.DataType());
18861886
torch::TensorOptions options{torch_dtype.second};
1887-
auto updated_options = options.device(torch::kCPU);
1887+
auto updated_options = (memory_type == TRITONSERVER_MEMORY_GPU)
1888+
? options.device(torch::kCUDA, device_.index())
1889+
: options.device(torch::kCPU);
18881890

18891891
torch::Tensor input_tensor = torch::from_blob(
18901892
const_cast<char*>(dst_buffer), shape, updated_options);

0 commit comments

Comments
 (0)