Commit 92e5ff8

Merge pull request #499 from NVIDIA/vit
fix: Fix linear lowering pass, lift layer_norm scale layer restriction and matmul layer nbdims restriction
2 parents 7daed40 + e4e4f8c · commit 92e5ff8

8 files changed (+156, −30 lines)

core/compiler.cpp

Lines changed: 4 additions & 4 deletions
@@ -182,8 +182,8 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
  torch::jit::script::Module new_mod(mod._ivalue()->name() + "_trt");
  std::vector<std::shared_ptr<torch::jit::Graph>> graphs;
  for (const torch::jit::script::Method& method : mod.get_methods()) {
-    // Don't convert hidden methods
-    if (method.name().rfind("_", 0)) {
+    // Compile only forward methods. forward method contains the entire graph.
+    if (method.name().compare("forward") == 0) {
      auto new_g = std::make_shared<torch::jit::Graph>();
      auto graph_and_parameters = lowering::Lower(mod, method.name());

@@ -256,8 +256,8 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
  torch::jit::script::Module new_mod(mod._ivalue()->name() + "_trt");
  std::vector<std::shared_ptr<torch::jit::Graph>> graphs;
  for (const torch::jit::script::Method& method : mod.get_methods()) {
-    // Don't convert hidden methods
-    if (method.name().rfind("_", 0)) {
+    // Compile only forward methods. forward method contains the entire graph.
+    if (method.name().compare("forward") == 0) {
      auto engine = ConvertGraphToTRTEngine(mod, method.name(), cfg);
      auto new_g = std::make_shared<torch::jit::Graph>();
      AddEngineToGraph(new_mod, new_g, engine);
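Note on the new filter: a scripted module's forward graph contains (and can inline) the calls to its other methods, so compiling only forward still covers the whole module, as the new comment states. A minimal PyTorch sketch of that assumption (the Toy module below is hypothetical, not part of this change):

    import torch

    class Toy(torch.nn.Module):
        def helper(self, x):
            return torch.relu(x)

        def forward(self, x):
            # helper appears inside forward's graph as a method call that is
            # inlined during lowering, so converting forward alone covers it.
            return self.helper(x) + 1

    scripted = torch.jit.script(Toy())
    print(scripted.graph)          # forward's graph, with a call to helper
    print(scripted.inlined_graph)  # the same graph with helper inlined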

core/conversion/converters/impl/layer_norm.cpp

Lines changed: 24 additions & 5 deletions
@@ -117,12 +117,31 @@ auto layer_norm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().
      }

      auto power = Weights(ctx, at::ones(expand_size));
-      auto scale_nd = ctx->net->addScaleNd(
-          *div_out, nvinfer1::ScaleMode::kELEMENTWISE, beta_weights.data, gamma_weights.data, power.data, 1);
-      scale_nd->setName((util::node_info(n) + "_scale_nd").c_str());
-      auto scale_nd_out = scale_nd->getOutput(0);

-      ctx->AssociateValueAndTensor(n->outputs()[0], scale_nd_out);
+      auto gamma_tensor = ctx->net->addConstant(gamma_weights.shape, gamma_weights.data)->getOutput(0);
+      auto scale_l = add_elementwise(
+          ctx, nvinfer1::ElementWiseOperation::kPROD, div_out, gamma_tensor, (util::node_info(n) + "_scale").c_str());
+
+      auto beta_tensor = ctx->net->addConstant(beta_weights.shape, beta_weights.data)->getOutput(0);
+      auto shift_l = add_elementwise(
+          ctx,
+          nvinfer1::ElementWiseOperation::kSUM,
+          scale_l->getOutput(0),
+          beta_tensor,
+          (util::node_info(n) + "_shift").c_str());
+
+      auto power_tensor = ctx->net->addConstant(power.shape, power.data)->getOutput(0);
+      auto power_l = add_elementwise(
+          ctx,
+          nvinfer1::ElementWiseOperation::kPOW,
+          shift_l->getOutput(0),
+          power_tensor,
+          (util::node_info(n) + "_power").c_str());
+
+      power_l->setName((util::node_info(n) + "_scale_nd").c_str());
+      auto power_l_out = power_l->getOutput(0);
+
+      ctx->AssociateValueAndTensor(n->outputs()[0], power_l_out);
      return true;
    }});
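The three elementwise layers simply spell out layer norm's affine step (multiply by gamma, add beta, then raise to a power of ones, which is a no-op) without the scale-layer rank restriction mentioned in the commit message. A rough PyTorch sketch of the equivalence, using the shapes from the new test added below (values here are illustrative, not the converter itself):

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 197, 768)
    gamma, beta, eps = torch.randn(768), torch.randn(768), 1e-5

    normalized = (x - x.mean(-1, keepdim=True)) / torch.sqrt(x.var(-1, unbiased=False, keepdim=True) + eps)
    # PROD with gamma, SUM with beta, POW with ones: the same affine step as layer_norm.
    decomposed = ((normalized * gamma) + beta) ** torch.ones(768)

    print(torch.allclose(decomposed, F.layer_norm(x, [768], gamma, beta, eps), atol=1e-5))  # expect True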

core/conversion/converters/impl/matrix_multiply.cpp

Lines changed: 9 additions & 4 deletions
@@ -1,3 +1,4 @@
+#include "core/conversion/converters/converter_util.h"
#include "core/conversion/converters/converters.h"
#include "core/util/prelude.h"

@@ -13,10 +14,14 @@ auto mm_registrations TRTORCH_UNUSED =
        .pattern({"aten::matmul(Tensor self, Tensor other) -> (Tensor)",
                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
                    auto self = args[0].ITensorOrFreeze(ctx);
-                   LOG_DEBUG("self tensor shape: " << self->getDimensions());
-
                    auto other = args[1].ITensorOrFreeze(ctx);
-                   LOG_DEBUG("other tensor shape: " << other->getDimensions());
+                   // Ensure self and other tensors have same nbDims by expanding the dimensions (from 0 axis) if
+                   // necessary.
+                   if (self->getDimensions().nbDims < other->getDimensions().nbDims) {
+                     self = addPadding(ctx, n, self, other->getDimensions().nbDims, false, false);
+                   } else {
+                     other = addPadding(ctx, n, other, self->getDimensions().nbDims, false, false);
+                   }

                    auto mm_layer = ctx->net->addMatrixMultiply(
                        *self, nvinfer1::MatrixOperation::kNONE, *other, nvinfer1::MatrixOperation::kNONE);

@@ -73,4 +78,4 @@ auto mm_registrations TRTORCH_UNUSED =
} // namespace converters
} // namespace conversion
} // namespace core
-} // namespace trtorch
+} // namespace trtorch
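The padding only prepends size-1 dimensions to the lower-rank operand, which matches matmul's broadcasting rules, so the product is unchanged. A small PyTorch sketch with the same ranks as the new test added below (the tensors are illustrative):

    import torch

    a = torch.randn(2, 3)     # rank 2, like in1 in the new test
    b = torch.randn(3, 3, 2)  # rank 3, like in2 in the new test

    # Prepending a 1-sized batch dim to the lower-rank operand (what addPadding
    # does here) leaves the broadcasted matmul result unchanged.
    print(torch.allclose(torch.matmul(a, b), torch.matmul(a.reshape(1, 2, 3), b)))  # expect True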

core/lowering/passes/linear_to_addmm.cpp

Lines changed: 41 additions & 16 deletions
@@ -1,23 +1,55 @@
-#include "torch/csrc/jit/passes/subgraph_rewrite.h"
+
+#include <torch/csrc/jit/runtime/operator.h>
+#include "torch/csrc/jit/ir/alias_analysis.h"
+#include "torch/csrc/jit/jit_log.h"
+#include "torch/csrc/jit/passes/constant_propagation.h"
+#include "torch/csrc/jit/passes/dead_code_elimination.h"
+#include "torch/csrc/jit/passes/guard_elimination.h"
+#include "torch/csrc/jit/passes/peephole.h"
+#include "torch/csrc/jit/runtime/graph_executor.h"

#include "core/util/prelude.h"
+#include "torch/csrc/jit/passes/subgraph_rewrite.h"

namespace trtorch {
namespace core {
namespace lowering {
namespace passes {

+void replaceLinearWithBiasNonePattern(std::shared_ptr<torch::jit::Graph> graph) {
+  // Define the decomposition function for aten::linear for the case where bias (mat2) is None.
+  static torch::jit::CompilationUnit decompose_funcs(R"SCRIPT(
+    def linear(self: Tensor, mat1: Tensor, mat2: Tensor):
+        return torch.matmul(self, mat1.t())
+    )SCRIPT");
+
+  // Iterate through nodes and search for aten::linear nodes where bias is not a Tensor (includes bias=None case)
+  auto block = graph->block();
+  for (auto it = block->nodes().begin(); it != block->nodes().end(); it++) {
+    auto n = *it;
+    if (n->kind().toQualString() == std::string("aten::linear")) {
+      auto input_values = n->inputs();
+      // input_values[2] is the bias. If none, replace it with the decomposed linear graph.
+      if (input_values[2]->type()->isSubtypeOf(c10::TensorType::get())) {
+        continue;
+      } else {
+        torch::jit::WithInsertPoint guard(*it);
+        std::shared_ptr<torch::jit::Graph> d_graph = decompose_funcs.get_function("linear").graph();
+        torch::jit::Value* new_output = insertGraph(*it->owningGraph(), *d_graph, it->inputs()).at(0);
+        new_output->setType(it->output()->type());
+        it->output()->replaceAllUsesWith(new_output);
+        it.destroyCurrent();
+      }
+    }
+  }
+}
+
void LinearToAddMM(std::shared_ptr<torch::jit::Graph>& graph) {
  // TensorRT implicitly adds a flatten layer infront of FC layers if necessary
  std::string flatten_linear_pattern = R"IR(
    graph(%input, %weight, %bias):
      %res = aten::linear(%input, %weight, %bias)
      return (%res))IR";
-  std::string flatten_linear_bias_none_pattern = R"IR(
-    graph(%input, %weight):
-      %bias: Tensor? = prim::Constant()
-      %res = aten::linear(%input, %weight, %bias)
-      return (%res))IR";

  std::string fused_linear = R"IR(
    graph(%input, %weight_t, %bias):

@@ -27,20 +59,13 @@ void LinearToAddMM(std::shared_ptr<torch::jit::Graph>& graph) {
      %b_f: Tensor = trt::const(%bias)
      %out: Tensor = aten::add(%b_f, %mm, %1)
      return (%out))IR";
-  std::string fused_linear_bias_none = R"IR(
-    graph(%input, %weight_t):
-      %weight = aten::t(%weight_t)
-      %mm: Tensor = aten::matmul(%input, %weight)
-      return (%mm))IR";
+
+  // First find and replace aten::linear nodes with non-tensor bias values.
+  replaceLinearWithBiasNonePattern(graph);

  torch::jit::SubgraphRewriter flatten_linear_to_linear;
  flatten_linear_to_linear.RegisterRewritePattern(flatten_linear_pattern, fused_linear);
  flatten_linear_to_linear.runOnGraph(graph);
-
-  torch::jit::SubgraphRewriter flatten_linear_bias_none_to_linear;
-  flatten_linear_bias_none_to_linear.RegisterRewritePattern(flatten_linear_bias_none_pattern, fused_linear_bias_none);
-  flatten_linear_bias_none_to_linear.runOnGraph(graph);
-  LOG_GRAPH("Post linear to addmm: " << *graph);
}

} // namespace passes
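The TorchScript decomposition registered above relies on the identity linear(x, W, None) == matmul(x, W.t()); the old SubgraphRewriter pattern could miss the bias=None case, which appears to be why the pass now walks the graph directly. A quick numeric check of that identity in PyTorch (shapes are arbitrary):

    import torch

    x = torch.randn(4, 16)
    w = torch.randn(8, 16)

    # aten::linear with bias=None is exactly matmul with the transposed weight.
    print(torch.allclose(torch.nn.functional.linear(x, w, None), torch.matmul(x, w.t())))  # expect True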

tests/core/conversion/converters/test_layer_norm.cpp

Lines changed: 28 additions & 0 deletions
@@ -118,3 +118,31 @@ TEST(Converters, ATenLayerNormConvertsCorrectlyLast1Dims) {

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}
+
+TEST(Converters, ATenLayerNormConvertsCorrectly3dInput1dNormalizedShape) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor,
+          %gamma: Float(197, 768),
+          %beta: Float(197, 768)):
+      %1: int = prim::Constant[value=768]()
+      %4 : int[] = prim::ListConstruct(%1)
+      %7 : bool = prim::Constant[value=0]()
+      %8 : float = prim::Constant[value=1.0000000000000001e-05]()
+      %9 : Tensor = aten::layer_norm(%0, %4, %gamma, %beta, %8, %7)
+      return (%9))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {1, 197, 768}, {at::kCUDA});
+  auto gamma = at::randint(1, 10, {768}, {at::kCUDA});
+  auto beta = at::randint(1, 10, {768}, {at::kCUDA});
+
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {gamma, beta});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
+
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {gamma, beta});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}

tests/core/conversion/converters/test_matrix_multiply.cpp

Lines changed: 21 additions & 0 deletions
@@ -26,6 +26,27 @@ TEST(Converters, ATenMMConvertsCorrectly) {
  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

+TEST(Converters, ATenMMWithDiffShapesConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor, %1 : Tensor):
+      %2 : Tensor = aten::matmul(%0, %1)
+      return (%2))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in1 = at::randint(0, 5, {2, 3}, {at::kCUDA});
+  auto in2 = at::randint(0, 5, {3, 3, 2}, {at::kCUDA});
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in1, in2});
+
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in1, in2});
+  auto trt = trt_results[0].reshape_as(jit_results[0]);
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
TEST(Converters, ATenBMMConvertsCorrectly) {
  const auto graph = R"IR(
    graph(%0 : Tensor, %1 : Tensor):

tests/core/lowering/test_linear_to_addmm.cpp

Lines changed: 24 additions & 1 deletion
@@ -31,4 +31,27 @@ TEST(LoweringPasses, LinearToAddMM) {
  torch::jit::parseIR(target_graph, &*tg);

  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
-}
+}
+
+TEST(LoweringPasses, LinearToAddMMBiasNone) {
+  std::string source_graph = R"IR(
+    graph(%input, %weight):
+      %bias : None = prim::Constant()
+      %res = aten::linear(%input, %weight, %bias)
+      return (%res))IR";
+  std::string target_graph = R"IR(
+    graph(%input, %weight_t):
+      %weight = aten::t(%weight_t)
+      %mm: Tensor = aten::matmul(%input, %weight)
+      return (%mm))IR";
+
+  trtorch::core::util::logging::get_logger().set_reportable_log_level(trtorch::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, &*sg);
+  trtorch::core::lowering::passes::LinearToAddMM(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, &*tg);
+
+  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
+}

tests/modules/hub.py

Lines changed: 5 additions & 0 deletions
@@ -2,6 +2,7 @@
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
+import timm

models = {
    "alexnet": {
@@ -64,6 +65,10 @@
    "faster_rcnn": {
        "model": models.detection.fasterrcnn_resnet50_fpn(pretrained=True),
        "path": "script"
+    },
+    "vit": {
+        "model": timm.create_model('efficientnet_b0', pretrained=True),
+        "path": "script"
    }
}
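A minimal, hypothetical standalone sketch of exporting the newly registered model the way its "path": "script" entry implies (the output file name is illustrative; scripting assumes timm is installed and the model is TorchScript-compatible):

    import torch
    import timm

    # Mirrors the new "vit" table entry: build the timm model, script it, and
    # save the TorchScript module for the conversion tests to load.
    model = timm.create_model('efficientnet_b0', pretrained=True).eval()
    scripted = torch.jit.script(model)
    torch.jit.save(scripted, "vit_scripted.jit.pt")  # illustrative file name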
