We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent fe45fb0 commit fe0f318Copy full SHA for fe0f318
src/libtorch.cc
@@ -1884,7 +1884,9 @@ ModelInstanceState::SetInputTensors(
1884
const auto torch_dtype =
1885
ConvertDataTypeToTorchType(batch_input.DataType());
1886
torch::TensorOptions options{torch_dtype.second};
1887
- auto updated_options = options.device(torch::kCPU);
+ auto updated_options = (memory_type == TRITONSERVER_MEMORY_GPU)
1888
+ ? options.device(torch::kCUDA, device_.index())
1889
+ : options.device(torch::kCPU);
1890
1891
torch::Tensor input_tensor = torch::from_blob(
1892
const_cast<char*>(dst_buffer), shape, updated_options);
0 commit comments