
fix: Device casting issues with certain aten operators #1416


Merged
merged 2 commits on Nov 14, 2022
1 change: 1 addition & 0 deletions core/conversion/conversionctx/BUILD
@@ -21,6 +21,7 @@ cc_library(
deps = [
"@tensorrt//:nvinfer",
"//core/util:prelude",
"//core/ir",
] + select({
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
"//conditions:default": ["@libtorch//:libtorch"],
11 changes: 2 additions & 9 deletions core/conversion/conversionctx/ConversionCtx.h
@@ -9,28 +9,21 @@
#include "torch/csrc/jit/ir/ir.h"

#include <cuda_runtime.h>
#include "core/ir/ir.h"
#include "core/util/prelude.h"

namespace torch_tensorrt {
namespace core {
namespace conversion {

struct Device {
nvinfer1::DeviceType device_type;
int64_t gpu_id;
int64_t dla_core;
bool allow_gpu_fallback;
Device() : device_type(nvinfer1::DeviceType::kGPU), gpu_id(0), dla_core(0), allow_gpu_fallback(false) {}
};

struct BuilderSettings {
std::set<nvinfer1::DataType> enabled_precisions = {};
bool sparse_weights = false;
bool disable_tf32 = false;
bool refit = false;
bool debug = false;
bool truncate_long_and_double = false;
Device device;
ir::Device device;
nvinfer1::EngineCapability capability = TRT_ENGINE_CAPABILITY_STANDARD;
nvinfer1::IInt8Calibrator* calibrator = nullptr;
uint64_t num_avg_timing_iters = 1;
8 changes: 8 additions & 0 deletions core/ir/ir.h
@@ -11,6 +11,14 @@ namespace torch_tensorrt {
namespace core {
namespace ir {

struct Device {
nvinfer1::DeviceType device_type;
int64_t gpu_id;
int64_t dla_core;
bool allow_gpu_fallback;
Device() : device_type(nvinfer1::DeviceType::kGPU), gpu_id(0), dla_core(0), allow_gpu_fallback(false) {}
};

struct Input : torch::CustomClassHolder {
Input(){};
Input(
1 change: 1 addition & 0 deletions core/lowering/BUILD
@@ -24,6 +24,7 @@ cc_library(
deps = [
"//core/lowering/passes",
"//core/util:prelude",
"//core/ir",
] + select({
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
"//conditions:default": ["@libtorch//:libtorch"],
5 changes: 4 additions & 1 deletion core/lowering/CMakeLists.txt
@@ -15,6 +15,8 @@ set(HEADER_FILES
target_sources(${lib_name}
PRIVATE
${CXX_SRCS}
PUBLIC
$<TARGET_OBJECTS:core_ir>
$<TARGET_OBJECTS:core_util>
)

@@ -25,8 +27,9 @@ target_include_directories(${lib_name}

target_link_libraries(${lib_name}
PUBLIC
TensorRT::nvinfer
torch
PRIVATE
core_ir
core_util
)

4 changes: 4 additions & 0 deletions core/lowering/lowering.cpp
@@ -70,6 +70,10 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::I
passes::SiluToSigmoidMultipication(g);
passes::RemoveSingleUse0DTensors(g);
passes::RemoveUnnecessaryCasts(g);
passes::UnpackAndCastMaskedFill(g, lower_info.getGPUDeviceString());
passes::UnpackAndCastNumToTensor(g, lower_info.getGPUDeviceString());
passes::UnpackAndCastFull(g, lower_info.getGPUDeviceString());
passes::ReplaceScalarImplicit(g);
passes::RewriteInputsWithParams(g, params);
LOG_GRAPH(*g);
}
6 changes: 6 additions & 0 deletions core/lowering/lowering.h
@@ -1,5 +1,6 @@
#pragma once
#include <memory>
#include "core/ir/ir.h"
#include "torch/csrc/jit/ir/ir.h"

namespace torch_tensorrt {
@@ -15,8 +16,13 @@ struct LowerInfo {
// Since these QDQ nodes will be identical as they share same input, one of them is eliminated due to CSE lowering
// pass. Disable this in order to not disturb TensorRT's QAT optimizations.
bool disable_cse = false;
ir::Device target_device;
std::vector<std::string> forced_fallback_modules;
friend std::ostream& operator<<(std::ostream& os, const LowerInfo& l);

std::string getGPUDeviceString() {
return "cuda:" + std::to_string(target_device.gpu_id);
};
};

void LowerBlock(torch::jit::Block* b);
1 change: 1 addition & 0 deletions core/lowering/passes/BUILD
@@ -14,6 +14,7 @@ cc_library(
name = "passes",
srcs = [
"convNd_to_convolution.cpp",
"device_casting.cpp",
"exception_elimination.cpp",
"fuse_addmm_branches.cpp",
"linear_to_addmm.cpp",
1 change: 1 addition & 0 deletions core/lowering/passes/CMakeLists.txt
@@ -1,5 +1,6 @@
target_sources(${lib_name}
PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/convNd_to_convolution.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/device_casting.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/exception_elimination.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/fuse_addmm_branches.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/linear_to_addmm.cpp"
121 changes: 121 additions & 0 deletions core/lowering/passes/device_casting.cpp
@@ -0,0 +1,121 @@
#include "torch/csrc/jit/ir/constants.h"
#include "torch/csrc/jit/passes/subgraph_rewrite.h"

#include "core/util/prelude.h"

namespace torch_tensorrt {
namespace core {
namespace lowering {
namespace passes {

void UnpackAndCastMaskedFill(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name) {
std::string masked_fill_pattern = R"IR(
graph(%self, %mask, %value):
%out: Tensor = aten::masked_fill_(%self, %mask, %value)
return (%out))IR";

// Calls to masked_fill_ often utilize CPU tensors, and as such
// should be moved to gpu to avoid device mismatch errors

// Separate string into portions to insert device name
std::string clean_pattern_part_1 = R"IR(
graph(%self, %mask, %value):
%device: Device = prim::Constant[value=")IR";

std::string clean_pattern_part_2 = R"IR("]()
%dtype: NoneType = prim::Constant()
%false: bool = prim::Constant[value=0]()
%mask_cuda: Tensor = aten::to(%mask, %device, %dtype, %false, %false)
%self_cuda: Tensor = aten::to(%self, %device, %dtype, %false, %false)
%out: Tensor = aten::masked_fill(%self_cuda, %mask_cuda, %value)
return (%out))IR";

auto unpacked_pattern = clean_pattern_part_1 + target_device_name + clean_pattern_part_2;

torch::jit::SubgraphRewriter masked_fill_rewriter;
masked_fill_rewriter.RegisterRewritePattern(masked_fill_pattern, unpacked_pattern);
masked_fill_rewriter.runOnGraph(graph);
LOG_GRAPH("After unpack and cast masked_fill_: " << *graph);
}

void UnpackAndCastNumToTensor(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name) {
std::string num_to_tensor_cast_pattern = R"IR(
graph(%1: Scalar):
%2: Tensor = prim::NumToTensor(%1)
return (%2))IR";

// 0D Tensors are initialized on cpu, and need to be moved to gpu
// to avoid device mismatch issues

// Separate string into portions to insert device name
std::string clean_pattern_part_1 = R"IR(
graph(%1: Scalar):
%2: Tensor = prim::NumToTensor(%1)
%device: Device = prim::Constant[value=")IR";

std::string clean_pattern_part_2 = R"IR("]()
%dtype: NoneType = prim::Constant()
%false: bool = prim::Constant[value=0]()
%3: Tensor = aten::to(%2, %device, %dtype, %false, %false)
return (%3))IR";

auto num_to_tensor_clean_pattern = clean_pattern_part_1 + target_device_name + clean_pattern_part_2;
Comment on lines +51 to +62
Collaborator Author

Had to use this paradigm instead of snprintf because the % symbols in the IR are registered as formatting for snprintf, which made it difficult to insert the device string
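
For illustration only (not part of this diff), a minimal sketch of the trade-off under hypothetical names and buffer sizes: the literal % tokens in the IR templates collide with printf-style conversion specifiers, while plain concatenation keeps the raw IR text verbatim.

#include <cstdio>
#include <string>

int main() {
  std::string device_name = "cuda:0";  // hypothetical target device string

  // printf-style formatting: every %token in the IR (%self, %device, ...) is parsed
  // as a conversion specifier, so the template needs each '%' escaped as '%%'.
  char buf[256];
  std::snprintf(buf, sizeof(buf), "%%device: Device = prim::Constant[value=\"%s\"]()", device_name.c_str());

  // Plain concatenation (the paradigm used in this pass): no format parsing at all,
  // the IR text is kept as-is and only the device string is spliced in.
  std::string part_1 = R"IR(%device: Device = prim::Constant[value=")IR";
  std::string part_2 = R"IR("]())IR";
  std::string pattern = part_1 + device_name + part_2;
  return 0;
}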


torch::jit::SubgraphRewriter num_to_tensor_cast_rewriter;
num_to_tensor_cast_rewriter.RegisterRewritePattern(num_to_tensor_cast_pattern, num_to_tensor_clean_pattern);
num_to_tensor_cast_rewriter.runOnGraph(graph);

LOG_GRAPH("After unpack and cast NumToTensor: " << *graph);
}

void UnpackAndCastFull(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name) {
std::string full_cast_pattern = R"IR(
graph(%1, %2, %3, %4, %5, %6):
%out: Tensor = aten::full(%1, %2, %3, %4, %5, %6)
return (%out))IR";

// Tensors created via aten::full are initialized on cpu, and need to be casted to gpu
// to avoid device mismatch issues

// Separate string into portions to insert device name
std::string clean_pattern_part_1 = R"IR(
graph(%1, %2, %3, %4, %5, %6):
%device: Device = prim::Constant[value=")IR";

std::string clean_pattern_part_2 = R"IR("]()
%out: Tensor = aten::full(%1, %2, %3, %4, %device, %6)
return (%out))IR";

auto full_clean_pattern = clean_pattern_part_1 + target_device_name + clean_pattern_part_2;

torch::jit::SubgraphRewriter full_cast_rewriter;
full_cast_rewriter.RegisterRewritePattern(full_cast_pattern, full_clean_pattern);
full_cast_rewriter.runOnGraph(graph);

LOG_GRAPH("After unpack and cast full: " << *graph);
}

void ReplaceScalarImplicit(std::shared_ptr<torch::jit::Graph>& graph) {
std::string scalar_implicit_cast_pattern = R"IR(
graph(%1: Tensor):
%2: Scalar = aten::ScalarImplicit(%1)
return (%2))IR";

// ScalarImplicit can only unpack 0D tensors, whereas Tensors operated on by
// TensorRT are padded to 1 dimension. aten::item() resolves this conflict
std::string scalar_implicit_clean_pattern = R"IR(
graph(%1: Tensor):
%2: Scalar = aten::item(%1)
return (%2))IR";

torch::jit::SubgraphRewriter scalar_implicit_cast_rewriter;
scalar_implicit_cast_rewriter.RegisterRewritePattern(scalar_implicit_cast_pattern, scalar_implicit_clean_pattern);
scalar_implicit_cast_rewriter.runOnGraph(graph);

LOG_GRAPH("After replace ScalarImplicit: " << *graph);
}

} // namespace passes
} // namespace lowering
} // namespace core
} // namespace torch_tensorrt
4 changes: 4 additions & 0 deletions core/lowering/passes/passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ void SiluToSigmoidMultipication(std::shared_ptr<torch::jit::Graph>& graph);
void UnpackHardSwish(std::shared_ptr<torch::jit::Graph>& graph);
void RewriteInputsWithParams(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::IValue>& params);
void UnpackHardSigmoid(std::shared_ptr<torch::jit::Graph>& graph);
void UnpackAndCastMaskedFill(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name);
void UnpackAndCastNumToTensor(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name);
void UnpackAndCastFull(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name);
void ReplaceScalarImplicit(std::shared_ptr<torch::jit::Graph>& graph);

} // namespace passes
} // namespace lowering
Expand Down
26 changes: 25 additions & 1 deletion core/runtime/execute_engine.cpp
@@ -63,16 +63,40 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
CudaDevice curr_device = get_current_device();
LOG_DEBUG("Current Device: " << curr_device);

// Generic Target Device Prefix
std::string target_device = "cuda:";

if (is_switch_required(curr_device, compiled_engine->device_info)) {
// Scan through available CUDA devices and set the CUDA device context correctly
CudaDevice device = select_cuda_device(compiled_engine->device_info);
set_cuda_device(device);

std::string target_device = "cuda:" + std::to_string(device.id);
// Target device is new device
target_device += std::to_string(device.id);

for (auto& in : inputs) {
in = in.to(torch::Device(target_device));
}
} else {
Collaborator

Doesn't need to be an else, could just be a second check.

Collaborator Author

Updated the else block to just assign the CUDA target device name; the runtime device check is now applied as a second check.

// Target device is current device
target_device += std::to_string(curr_device.id);
}

// For each input, ensure its current device is the desired target device
for (size_t i = 0; i < inputs.size(); i++) {
at::Tensor* in = &inputs[i];
std::string current_tensor_device = in->device().str();

// If current device string does not match target device, display warning and move tensor accordingly
if (current_tensor_device != target_device) {
LOG_WARNING(
"Input " << i << " of engine " << compiled_engine->name << " was found to be on " << current_tensor_device
<< " but should be on " << target_device << ". This tensor is being moved by the runtime but "
<< "for performance considerations, ensure your inputs are all on GPU "
<< "and open an issue here (https://github.com/pytorch/TensorRT/issues) if this "
<< "warning persists.");
*in = in->to(torch::Device(target_device));
}
}

std::vector<void*> gpu_handles;
5 changes: 5 additions & 0 deletions cpp/src/compile_spec.cpp
@@ -110,6 +110,7 @@ torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) {
internal.convert_info.engine_settings.debug = external.debug;
internal.convert_info.engine_settings.truncate_long_and_double = external.truncate_long_and_double;
internal.convert_info.engine_settings.device.allow_gpu_fallback = external.device.allow_gpu_fallback;
internal.lower_info.target_device.allow_gpu_fallback = external.device.allow_gpu_fallback;

TORCHTRT_CHECK(
!(external.require_full_compilation && (external.torch_executed_ops.size() > 0)),
@@ -130,10 +131,12 @@ torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) {
switch (external.device.device_type) {
case Device::DeviceType::kDLA:
internal.convert_info.engine_settings.device.device_type = nvinfer1::DeviceType::kDLA;
internal.lower_info.target_device.device_type = nvinfer1::DeviceType::kDLA;
break;
case Device::DeviceType::kGPU:
default:
internal.convert_info.engine_settings.device.device_type = nvinfer1::DeviceType::kGPU;
internal.lower_info.target_device.device_type = nvinfer1::DeviceType::kGPU;
}

switch (external.capability) {
@@ -150,6 +153,8 @@

internal.convert_info.engine_settings.device.gpu_id = external.device.gpu_id;
internal.convert_info.engine_settings.device.dla_core = external.device.dla_core;
internal.lower_info.target_device.gpu_id = external.device.gpu_id;
internal.lower_info.target_device.dla_core = external.device.dla_core;
internal.convert_info.engine_settings.num_avg_timing_iters = external.num_avg_timing_iters;
internal.convert_info.engine_settings.workspace_size = external.workspace_size;
internal.convert_info.engine_settings.dla_sram_size = external.dla_sram_size;
5 changes: 5 additions & 0 deletions tests/core/lowering/BUILD
@@ -31,6 +31,10 @@ lowering_test(
name = "test_conv1d_pass",
)

lowering_test(
name = "test_device_casting",
)

lowering_test(
name = "test_exception_elimination_pass",
)
@@ -95,6 +99,7 @@ test_suite(
name = "lowering_tests",
tests = [
":test_conv1d_pass",
":test_device_casting",
":test_exception_elimination_pass",
":test_linear_to_addmm",
":test_module_fallback_passes",