
Commit 1b65137

fix: Bugfix in Linear-to-AddMM Fusion lowering pass
Fix two bugs in the linear-to-addmm lowering pass:

- The lowering pass did not explore nested sub-blocks of a node, such as those contained in `prim::If`, when `bias=None`.
- The lowering pass did not insert the fused linear code inside sub-blocks of `prim::If`, even when the original function call occurred within such a block.
- The latter causes issues when control flow switches between two versions of `aten::linear`, only one of which is a valid operation: evaluating both branches can crash compilation, since invalid Tensor shapes can be encountered.
- Update the implementation to run recursively through all nested blocks within all nodes (see the traversal sketch below).
- Update the implementation to remove the use of the `RegisterRewritePattern` paradigm for Tensor biases, as that rewrite does not always place the subgraph in the desired location.
- Add regression test cases to isolate both bugs.
1 parent 09bef46 commit 1b65137
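
For context, the heart of the fix is an ordinary recursive block walk over the JIT IR: a pass that only iterates over graph->block() never sees nodes nested inside prim::If or prim::Loop. A minimal standalone sketch of the pattern follows (hand-written for illustration; processNode is a hypothetical stand-in for the actual aten::linear rewrite in the diff below):

#include "torch/csrc/jit/ir/ir.h"

// Sketch of the recursive traversal the fix adopts: every node's sub-blocks
// must be visited, not just the top-level block of the graph.
void walkBlocks(torch::jit::Block* block) {
  for (auto it = block->nodes().begin(); it != block->nodes().end(); it++) {
    // Descend into sub-blocks first (e.g., both branches of a prim::If) so
    // that rewrites land inside the branch that owns the call.
    for (auto* sub : it->blocks()) {
      walkBlocks(sub);
    }
    // processNode(*it);  // hypothetical rewrite hook; the real logic is below
  }
}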

File tree: 2 files changed, +119 −27 lines


core/lowering/passes/linear_to_addmm.cpp

Lines changed: 41 additions & 27 deletions
@@ -3,6 +3,7 @@
 #include "core/util/prelude.h"
 #include "torch/csrc/jit/api/function_impl.h"
 #include "torch/csrc/jit/ir/alias_analysis.h"
+#include "torch/csrc/jit/ir/irparser.h"
 #include "torch/csrc/jit/jit_log.h"
 #include "torch/csrc/jit/passes/constant_propagation.h"
 #include "torch/csrc/jit/passes/dead_code_elimination.h"
@@ -16,26 +17,58 @@ namespace core {
 namespace lowering {
 namespace passes {
 
-void replaceLinearWithBiasNonePattern(std::shared_ptr<torch::jit::Graph> graph) {
+void replaceLinear(torch::jit::Block* block) {
   // Define the decomposition function for aten::linear for the case where bias (mat2) is None.
   static torch::jit::CompilationUnit decompose_funcs(R"SCRIPT(
     def linear(self: Tensor, mat1: Tensor, mat2: Tensor):
         return torch.matmul(self, mat1.t())
   )SCRIPT");
 
-  // Iterate through nodes and search for aten::linear nodes where bias is not a Tensor (includes bias=None case)
-  auto block = graph->block();
+  // Define graph format for aten::linear with Tensor-type bias
+  std::string fused_linear = R"IR(
+    graph(%input, %weight, %bias):
+      %1: int = prim::Constant[value=1]()
+      %weight = aten::t(%weight)
+      %mm: Tensor = aten::matmul(%input, %weight)
+      %b_f: Tensor = trt::const(%bias)
+      %out: Tensor = aten::add(%b_f, %mm, %1)
+      return (%out))IR";
+
+  // Iterate through nodes in block, searching for aten::linear
   for (auto it = block->nodes().begin(); it != block->nodes().end(); it++) {
     auto n = *it;
-    if (n->kind().toQualString() == std::string("aten::linear")) {
+
+    // Recursively explore nested blocks, such as those arising from prim::If
+    for (auto block : n->blocks()) {
+      replaceLinear(block);
+    }
+
+    if ((n->kind().toQualString() == std::string("aten::linear")) && (n->inputs().size() >= 3)) {
       auto input_values = n->inputs();
-      // input_values[2] is the bias. If none, replace it with the decomposed linear graph.
+
+      // input_values[2] is the bias
+      // If Tensor, replace with the fused-bias decomposed graph
+      // Otherwise, replace it with the no-bias decomposed linear graph
       if (input_values[2]->type()->isSubtypeOf(c10::TensorType::get())) {
-        continue;
+        torch::jit::WithInsertPoint guard(*it);
+
+        // Initialize new fused subgraph from IR code above
+        auto fused_g = std::make_shared<torch::jit::Graph>();
+        torch::jit::parseIR(fused_linear, fused_g.get());
+
+        // Insert subgraph in place of aten::linear, replacing inputs and outputs accordingly
+        torch::jit::Value* new_output = insertGraph(*it->owningGraph(), *fused_g, it->inputs()).at(0);
+        new_output->setType(it->output()->type());
+        it->output()->replaceAllUsesWith(new_output);
+        it.destroyCurrent();
       } else {
         torch::jit::WithInsertPoint guard(*it);
+
+        // Initialize decomposed graph without bias term
         std::shared_ptr<torch::jit::Graph> d_graph = toGraphFunction(decompose_funcs.get_function("linear")).graph();
         torch::jit::Value* new_output = insertGraph(*it->owningGraph(), *d_graph, it->inputs()).at(0);
+
+        // Insert function in place of aten::linear, replacing inputs and outputs accordingly
         new_output->setType(it->output()->type());
         it->output()->replaceAllUsesWith(new_output);
         it.destroyCurrent();
@@ -45,27 +78,8 @@ void replaceLinearWithBiasNonePattern(std::shared_ptr<torch::jit::Graph> graph)
 }
 
 void LinearToAddMM(std::shared_ptr<torch::jit::Graph>& graph) {
-  // TensorRT implicitly adds a flatten layer in front of FC layers if necessary
-  std::string flatten_linear_pattern = R"IR(
-    graph(%input, %weight, %bias):
-      %res = aten::linear(%input, %weight, %bias)
-      return (%res))IR";
-
-  std::string fused_linear = R"IR(
-    graph(%input, %weight_t, %bias):
-      %1: int = prim::Constant[value=1]()
-      %weight = aten::t(%weight_t)
-      %mm: Tensor = aten::matmul(%input, %weight)
-      %b_f: Tensor = trt::const(%bias)
-      %out: Tensor = aten::add(%b_f, %mm, %1)
-      return (%out))IR";
-
-  // First find and replace aten::linear nodes with non-tensor bias values.
-  replaceLinearWithBiasNonePattern(graph);
-
-  torch::jit::SubgraphRewriter flatten_linear_to_linear;
-  flatten_linear_to_linear.RegisterRewritePattern(flatten_linear_pattern, fused_linear);
-  flatten_linear_to_linear.runOnGraph(graph);
+  // Recursively find and replace all instances of aten::linear with the corresponding decomposed form
+  replaceLinear(graph->block());
 }
 
 } // namespace passes
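
To see the rewritten pass in action outside the test harness, here is a minimal driver sketch (the include path for passes.h is assumed from the repository layout; parseIR and LinearToAddMM are the entry points exercised by this diff):

#include <memory>
#include <string>
#include "core/lowering/passes/passes.h"
#include "torch/csrc/jit/ir/ir.h"
#include "torch/csrc/jit/ir/irparser.h"

// Sketch: parse a small graph containing aten::linear with bias=None and run
// the pass directly; afterwards the call is decomposed into matmul form.
void demoLinearToAddMM() {
  const std::string src = R"IR(
    graph(%input, %weight):
      %bias : None = prim::Constant()
      %res = aten::linear(%input, %weight, %bias)
      return (%res))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(src, g.get());

  torch_tensorrt::core::lowering::passes::LinearToAddMM(g);
  g->dump();  // aten::linear is gone; aten::matmul (plus aten::add for a Tensor bias) remains
}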

tests/core/lowering/test_linear_to_addmm.cpp

Lines changed: 78 additions & 0 deletions
@@ -57,3 +57,81 @@ TEST(LoweringPasses, LinearToAddMMBiasNone) {
 
   ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
 }
+
+TEST(LoweringPasses, LinearToAddMMBiasNoneGraphRun) {
+  std::string source_graph = R"IR(
+    graph(%input, %weight):
+      %biasNone : None = prim::Constant()
+      %true : bool = prim::Constant[value=1]()
+      %invalid_weight : Tensor = aten::t(%weight)
+      %4 : Tensor = prim::If(%true)
+        block0():
+          %res = aten::linear(%input, %weight, %biasNone)
+          -> (%res)
+        block1():
+          %res = aten::linear(%input, %invalid_weight, %biasNone)
+          -> (%res)
+      return (%4))IR";
+
+  // This regression test case ensures the Linear-to-AddMM lowering pass satisfies two constraints for non-Tensor bias:
+  // 1. It recursively resolves sub-blocks within the node, replacing sub-blocks to be converted as well
+  // 2. It does not pre-evaluate branches of the block which may have invalid operations
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, g.get());
+
+  auto in_0 = at::rand({8, 7}, {at::kCUDA});
+  auto in_1 = at::rand({8, 7}, {at::kCUDA});
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in_0, in_1});
+
+  torch_tensorrt::core::lowering::passes::LinearToAddMM(g);
+
+  LOG_DEBUG(g);
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in_0, in_1});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}
+
+TEST(LoweringPasses, LinearToAddMMBiasGraphRun) {
+  std::string source_graph = R"IR(
+    graph(%input, %weight, %bias):
+      %true : bool = prim::Constant[value=1]()
+      %invalid_weight : Tensor = aten::t(%weight)
+      %4 : Tensor = prim::If(%true)
+        block0():
+          %res = aten::linear(%input, %weight, %bias)
+          -> (%res)
+        block1():
+          %res = aten::linear(%input, %invalid_weight, %bias)
+          -> (%res)
+      return (%4))IR";
+
+  // This regression test case ensures the Linear-to-AddMM lowering pass satisfies two constraints for Tensor bias:
+  // 1. It recursively resolves sub-blocks within the node, replacing sub-blocks to be converted as well
+  // 2. It does not pre-evaluate branches of the block which may have invalid operations
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, g.get());
+
+  auto in_0 = at::rand({8, 7}, {at::kCUDA});
+  auto in_1 = at::rand({8, 7}, {at::kCUDA});
+  auto in_2 = at::rand({8, 8}, {at::kCUDA});
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in_0, in_1, in_2});
+
+  torch_tensorrt::core::lowering::passes::LinearToAddMM(g);
+
+  LOG_DEBUG(g);
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in_0, in_1, in_2});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}
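
One note on the shapes these tests choose: with %input and %weight both [8, 7], the taken branch of the prim::If is a valid linear call, while the branch using %invalid_weight is shape-mismatched, which is exactly why a pass that pre-evaluates both branches would crash. A small ATen sketch of that arithmetic (illustration only, not part of the commit):

#include <ATen/ATen.h>

// linear(x, w) computes x.matmul(w.t()). With x: [8, 7] and w: [8, 7],
// the valid branch is [8, 7] x [7, 8] -> [8, 8]. The invalid_weight branch
// uses w.t() ([7, 8]), whose transpose [8, 7] cannot be right-multiplied
// onto an [8, 7] input, so evaluating it throws a shape-mismatch error.
void shapeDemo() {
  auto x = at::rand({8, 7});
  auto w = at::rand({8, 7});
  auto ok = at::linear(x, w);  // valid: yields an [8, 8] Tensor
  // at::linear(x, w.t());     // would throw: dimensions do not match
  (void)ok;
}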
