fix: Device casting issues with certain aten operators #1416
@@ -0,0 +1,121 @@
#include "torch/csrc/jit/ir/constants.h" | ||
#include "torch/csrc/jit/passes/subgraph_rewrite.h" | ||
|
||
#include "core/util/prelude.h" | ||
|
||
namespace torch_tensorrt { | ||
namespace core { | ||
namespace lowering { | ||
namespace passes { | ||
|
||
void UnpackAndCastMaskedFill(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name) { | ||
std::string masked_fill_pattern = R"IR( | ||
graph(%self, %mask, %value): | ||
%out: Tensor = aten::masked_fill_(%self, %mask, %value) | ||
return (%out))IR"; | ||
|
||
// Calls to masked_fill_ often utilize CPU tensors, and as such | ||
// should be moved to gpu to avoid device mismatch errors | ||
|
||
// Separate string into portions to insert device name | ||
std::string clean_pattern_part_1 = R"IR( | ||
graph(%self, %mask, %value): | ||
%device: Device = prim::Constant[value=")IR"; | ||
|
||
std::string clean_pattern_part_2 = R"IR("]() | ||
%dtype: NoneType = prim::Constant() | ||
%false: bool = prim::Constant[value=0]() | ||
%mask_cuda: Tensor = aten::to(%mask, %device, %dtype, %false, %false) | ||
%self_cuda: Tensor = aten::to(%self, %device, %dtype, %false, %false) | ||
%out: Tensor = aten::masked_fill(%self_cuda, %mask_cuda, %value) | ||
return (%out))IR"; | ||
|
||
auto unpacked_pattern = clean_pattern_part_1 + target_device_name + clean_pattern_part_2; | ||
|
||
torch::jit::SubgraphRewriter masked_fill_rewriter; | ||
masked_fill_rewriter.RegisterRewritePattern(masked_fill_pattern, unpacked_pattern); | ||
masked_fill_rewriter.runOnGraph(graph); | ||
LOG_GRAPH("After unpack and cast masked_fill_: " << *graph); | ||
} | ||
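As a usage illustration (not part of the diff), the pass can be exercised on a standalone TorchScript graph. This is a minimal sketch assuming the JIT IR parser from torch/csrc/jit/ir/irparser.h; the function name demo_masked_fill_pass and the "cuda:0" target are illustrative:

#include <memory>
#include <string>

#include "torch/csrc/jit/ir/irparser.h"

void demo_masked_fill_pass() {
  // An input graph containing the in-place op the pass matches
  const std::string source = R"IR(
    graph(%self : Tensor, %mask : Tensor, %value : float):
      %out : Tensor = aten::masked_fill_(%self, %mask, %value)
      return (%out))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source, g.get());

  // After the pass, %self and %mask are cast to cuda:0 via aten::to and the
  // in-place masked_fill_ is replaced with out-of-place aten::masked_fill
  torch_tensorrt::core::lowering::passes::UnpackAndCastMaskedFill(g, "cuda:0");
}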
void UnpackAndCastNumToTensor(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name) {
  std::string num_to_tensor_cast_pattern = R"IR(
    graph(%1: Scalar):
      %2: Tensor = prim::NumToTensor(%1)
      return (%2))IR";

  // 0D Tensors are initialized on cpu, and need to be moved to gpu
  // to avoid device mismatch issues

  // Separate string into portions to insert device name
  std::string clean_pattern_part_1 = R"IR(
    graph(%1: Scalar):
      %2: Tensor = prim::NumToTensor(%1)
      %device: Device = prim::Constant[value=")IR";

  std::string clean_pattern_part_2 = R"IR("]()
      %dtype: NoneType = prim::Constant()
      %false: bool = prim::Constant[value=0]()
      %3: Tensor = aten::to(%2, %device, %dtype, %false, %false)
      return (%3))IR";

  auto num_to_tensor_clean_pattern = clean_pattern_part_1 + target_device_name + clean_pattern_part_2;
Comment on lines +51 to +62: Had to use this paradigm instead of …
  torch::jit::SubgraphRewriter num_to_tensor_cast_rewriter;
  num_to_tensor_cast_rewriter.RegisterRewritePattern(num_to_tensor_cast_pattern, num_to_tensor_clean_pattern);
  num_to_tensor_cast_rewriter.runOnGraph(graph);

  LOG_GRAPH("After unpack and cast NumToTensor: " << *graph);
}
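For clarity, with target_device_name equal to "cuda:0" (an illustrative value), the concatenation in UnpackAndCastNumToTensor assembles the replacement pattern below. RegisterRewritePattern accepts only fixed pattern strings, so the device name has to be spliced into the prim::Constant attribute by concatenation:

// Illustration only: the fully assembled replacement pattern for "cuda:0"
const std::string assembled_pattern = R"IR(
    graph(%1: Scalar):
      %2: Tensor = prim::NumToTensor(%1)
      %device: Device = prim::Constant[value="cuda:0"]()
      %dtype: NoneType = prim::Constant()
      %false: bool = prim::Constant[value=0]()
      %3: Tensor = aten::to(%2, %device, %dtype, %false, %false)
      return (%3))IR";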
void UnpackAndCastFull(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name) {
  std::string full_cast_pattern = R"IR(
    graph(%1, %2, %3, %4, %5, %6):
      %out: Tensor = aten::full(%1, %2, %3, %4, %5, %6)
      return (%out))IR";

  // Tensors created via aten::full are initialized on cpu, and need to be cast to gpu
  // to avoid device mismatch issues

  // Separate string into portions to insert device name
  std::string clean_pattern_part_1 = R"IR(
    graph(%1, %2, %3, %4, %5, %6):
      %device: Device = prim::Constant[value=")IR";

  std::string clean_pattern_part_2 = R"IR("]()
      %out: Tensor = aten::full(%1, %2, %3, %4, %device, %6)
      return (%out))IR";

  auto full_clean_pattern = clean_pattern_part_1 + target_device_name + clean_pattern_part_2;

  torch::jit::SubgraphRewriter full_cast_rewriter;
  full_cast_rewriter.RegisterRewritePattern(full_cast_pattern, full_clean_pattern);
  full_cast_rewriter.runOnGraph(graph);

  LOG_GRAPH("After unpack and cast full: " << *graph);
}
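In the aten::full schema (size, fill_value, dtype, layout, device, pin_memory), %5 is the device argument, which is why only that operand is swapped for the constant device. A rough libtorch equivalent of the rewritten call, with illustrative size and fill values:

#include <torch/torch.h>

// Sketch: constructing the tensor directly on the target device, which is
// what the rewritten graph achieves (values here are illustrative)
at::Tensor full_on_target_device() {
  return torch::full(
      /*size=*/{2, 3},
      /*fill_value=*/1.0,
      torch::TensorOptions().device(torch::Device("cuda:0")));
}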
void ReplaceScalarImplicit(std::shared_ptr<torch::jit::Graph>& graph) {
  std::string scalar_implicit_cast_pattern = R"IR(
    graph(%1: Tensor):
      %2: Scalar = aten::ScalarImplicit(%1)
      return (%2))IR";

  // ScalarImplicit can only unpack 0D tensors, whereas Tensors operated on by
  // TensorRT are padded to 1 dimension. aten::item() resolves this conflict
  std::string scalar_implicit_clean_pattern = R"IR(
    graph(%1: Tensor):
      %2: Scalar = aten::item(%1)
      return (%2))IR";

  torch::jit::SubgraphRewriter scalar_implicit_cast_rewriter;
  scalar_implicit_cast_rewriter.RegisterRewritePattern(scalar_implicit_cast_pattern, scalar_implicit_clean_pattern);
  scalar_implicit_cast_rewriter.runOnGraph(graph);

  LOG_GRAPH("After replace ScalarImplicit: " << *graph);
}
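A short sketch of the distinction described in the comment above, assuming only libtorch: item() unpacks any single-element tensor, including the 1-D tensors TensorRT produces, whereas aten::ScalarImplicit accepts only 0-D tensors:

#include <torch/torch.h>

void item_unpacks_padded_tensors() {
  at::Tensor zero_dim = torch::tensor(3.0);   // shape []
  at::Tensor one_dim = torch::tensor({3.0});  // shape [1], as padded by TensorRT
  c10::Scalar a = zero_dim.item();  // fine: 0-D tensor
  c10::Scalar b = one_dim.item();   // also fine: any tensor with numel() == 1
}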
} // namespace passes
} // namespace lowering
} // namespace core
} // namespace torch_tensorrt
@@ -63,16 +63,40 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
  CudaDevice curr_device = get_current_device();
  LOG_DEBUG("Current Device: " << curr_device);

  // Generic Target Device Prefix
  std::string target_device = "cuda:";

  if (is_switch_required(curr_device, compiled_engine->device_info)) {
    // Scan through available CUDA devices and set the CUDA device context correctly
    CudaDevice device = select_cuda_device(compiled_engine->device_info);
    set_cuda_device(device);

    // Target device is new device
    target_device += std::to_string(device.id);

    for (auto& in : inputs) {
      in = in.to(torch::Device(target_device));
    }
  } else {
    // Target device is current device
    target_device += std::to_string(curr_device.id);
  }

Comment: Doesn't need to be an else, could just be a second check.
Reply: Updated the else block to just assign the target device.
  // For each input, ensure its current device is the desired target device
  for (size_t i = 0; i < inputs.size(); i++) {
    at::Tensor* in = &inputs[i];
    std::string current_tensor_device = in->device().str();

    // If current device string does not match target device, display warning and move tensor accordingly
    if (current_tensor_device != target_device) {
      LOG_WARNING(
          "Input " << i << " of engine " << compiled_engine->name << " was found to be on " << current_tensor_device
                   << " but should be on " << target_device << ". This tensor is being moved by the runtime but "
                   << "for performance considerations, ensure your inputs are all on GPU "
                   << "and open an issue here (https://github.com/pytorch/TensorRT/issues) if this "
                   << "warning persists.");
      *in = in->to(torch::Device(target_device));
    }
  }
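Read in isolation, the loop above amounts to the following standalone helper (a sketch assuming only libtorch; the name move_inputs_to_device is illustrative and the warning is omitted):

#include <string>
#include <vector>

#include <torch/torch.h>

// Move each input tensor to target_device only when its current device
// string disagrees; tensors already on the target device are untouched
void move_inputs_to_device(std::vector<at::Tensor>& inputs, const std::string& target_device) {
  for (size_t i = 0; i < inputs.size(); i++) {
    if (inputs[i].device().str() != target_device) {
      inputs[i] = inputs[i].to(torch::Device(target_device));
    }
  }
}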
  std::vector<void*> gpu_handles;