Skip to content

Commit d741d2c

Browse files
Fix: aten::matmul converter behavior with 1d tensors (#2450)
1 parent e834034 commit d741d2c

File tree

2 files changed

+160
-11
lines changed

2 files changed

+160
-11
lines changed

core/conversion/converters/impl/matrix_multiply.cpp

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,49 @@ auto mm_registrations TORCHTRT_UNUSED =
1616
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
1717
auto self = args[0].ITensorOrFreeze(ctx);
1818
auto other = args[1].ITensorOrFreeze(ctx);
19+
20+
auto selfDims = self->getDimensions().nbDims;
21+
auto otherDims = other->getDimensions().nbDims;
22+
23+
bool squeezeFront = false;
24+
bool squeezeBack = false;
25+
26+
if (selfDims == 1 && selfDims < otherDims) {
27+
squeezeFront = true;
28+
} else if (otherDims == 1 && otherDims < selfDims) {
29+
// Append a 1 to the end of the shape before padding front to match self
30+
other = addPadding(ctx, n, other, 2, true, false);
31+
otherDims = other->getDimensions().nbDims;
32+
squeezeBack = true;
33+
}
34+
1935
// Ensure self and other tensors have same nbDims by expanding the dimensions (from 0 axis) if
2036
// necessary.
21-
if (self->getDimensions().nbDims < other->getDimensions().nbDims) {
22-
self = addPadding(ctx, n, self, other->getDimensions().nbDims, false, false);
23-
} else {
24-
other = addPadding(ctx, n, other, self->getDimensions().nbDims, false, false);
37+
if (selfDims < otherDims) {
38+
self = addPadding(ctx, n, self, otherDims, false, false);
39+
} else if (otherDims < selfDims) {
40+
other = addPadding(ctx, n, other, selfDims, false, false);
2541
}
2642

2743
auto mm_layer = ctx->net->addMatrixMultiply(
2844
*self, nvinfer1::MatrixOperation::kNONE, *other, nvinfer1::MatrixOperation::kNONE);
2945

3046
TORCHTRT_CHECK(mm_layer, "Unable to create matrix multiplication node: " << *n);
3147
mm_layer->setName(util::node_info(n).c_str());
32-
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], mm_layer->getOutput(0));
48+
auto out = mm_layer->getOutput(0);
49+
50+
if (squeezeFront || squeezeBack) {
51+
auto squeezeDimOffset = squeezeFront ? 2 : 1;
52+
auto reshapeDims =
53+
util::squeezeDims(out->getDimensions(), out->getDimensions().nbDims - squeezeDimOffset);
54+
auto shuffle_layer = ctx->net->addShuffle(*out);
55+
LOG_DEBUG("Squeezing matmul output for 1d correction: " << reshapeDims);
56+
TORCHTRT_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n);
57+
shuffle_layer->setReshapeDimensions(reshapeDims);
58+
shuffle_layer->setName((util::node_info(n) + "_squeeze").c_str());
59+
out = shuffle_layer->getOutput(0);
60+
}
61+
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out);
3362

3463
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
3564
return true;

tests/core/conversion/converters/test_matrix_multiply.cpp

Lines changed: 126 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,8 @@ TEST(Converters, ATenMMConvertsCorrectly) {
2121

2222
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
2323
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2});
24-
auto trt = trt_results[0].reshape_as(jit_results[0]);
2524

26-
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
25+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
2726
}
2827

2928
TEST(Converters, ATenMMWithDiffShapesConvertsCorrectly) {
@@ -42,9 +41,131 @@ TEST(Converters, ATenMMWithDiffShapesConvertsCorrectly) {
4241

4342
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
4443
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2});
45-
auto trt = trt_results[0].reshape_as(jit_results[0]);
4644

47-
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
45+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
46+
}
47+
48+
// 1-D x 2-D aten::matmul: the converter must produce an output whose shape
// matches PyTorch exactly (no reshape_as() fixup on the TRT result).
TEST(Converters, ATenMM1d2dConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%0 : Tensor, %1 : Tensor):
        %2 : Tensor = aten::matmul(%0, %1)
        return (%2))IR";

  auto parsed_g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, parsed_g.get());

  // Small integer-valued random inputs keep the JIT/TRT comparison tight.
  auto lhs = at::randint(0, 5, {10}, {at::kCUDA});
  auto rhs = at::randint(0, 5, {10, 1}, {at::kCUDA});

  auto static_params = torch_tensorrt::core::ir::get_static_params(parsed_g->inputs(), {});
  auto jit_out = torch_tensorrt::tests::util::RunGraph(parsed_g, static_params, {lhs, rhs});

  static_params = torch_tensorrt::core::ir::get_static_params(parsed_g->inputs(), {});
  auto trt_out = torch_tensorrt::tests::util::RunGraphEngine(parsed_g, static_params, {lhs, rhs});

  // Compare raw outputs — shape agreement is part of what is under test.
  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_out[0], trt_out[0], 2e-6));
}
67+
68+
// 1-D x 3-D aten::matmul: verifies the converter's 1-D-operand handling
// against PyTorch, comparing the raw TRT output (no reshape fixup).
TEST(Converters, ATenMM1d3dConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%0 : Tensor, %1 : Tensor):
        %2 : Tensor = aten::matmul(%0, %1)
        return (%2))IR";

  auto parsed_g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, parsed_g.get());

  // 1-D vector times a batched (3-D) matrix.
  auto lhs = at::randint(0, 5, {10}, {at::kCUDA});
  auto rhs = at::randint(0, 5, {2, 10, 8}, {at::kCUDA});

  auto static_params = torch_tensorrt::core::ir::get_static_params(parsed_g->inputs(), {});
  auto jit_out = torch_tensorrt::tests::util::RunGraph(parsed_g, static_params, {lhs, rhs});

  static_params = torch_tensorrt::core::ir::get_static_params(parsed_g->inputs(), {});
  auto trt_out = torch_tensorrt::tests::util::RunGraphEngine(parsed_g, static_params, {lhs, rhs});

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_out[0], trt_out[0], 2e-6));
}
87+
88+
// 1-D x 4-D aten::matmul: same 1-D-operand correction, deeper batching.
TEST(Converters, ATenMM1d4dConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%0 : Tensor, %1 : Tensor):
        %2 : Tensor = aten::matmul(%0, %1)
        return (%2))IR";

  auto parsed_g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, parsed_g.get());

  // 1-D vector times a doubly-batched (4-D) matrix.
  auto lhs = at::randint(0, 5, {10}, {at::kCUDA});
  auto rhs = at::randint(0, 5, {2, 3, 10, 8}, {at::kCUDA});

  auto static_params = torch_tensorrt::core::ir::get_static_params(parsed_g->inputs(), {});
  auto jit_out = torch_tensorrt::tests::util::RunGraph(parsed_g, static_params, {lhs, rhs});

  static_params = torch_tensorrt::core::ir::get_static_params(parsed_g->inputs(), {});
  auto trt_out = torch_tensorrt::tests::util::RunGraphEngine(parsed_g, static_params, {lhs, rhs});

  // Raw-output comparison: output rank/shape must already match PyTorch.
  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_out[0], trt_out[0], 2e-6));
}
107+
108+
// 3-D x 1-D aten::matmul: the 1-D operand is on the right-hand side here,
// exercising the converter's back-squeeze path against PyTorch.
TEST(Converters, ATenMM3d1dConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%0 : Tensor, %1 : Tensor):
        %2 : Tensor = aten::matmul(%0, %1)
        return (%2))IR";

  auto parsed_g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, parsed_g.get());

  // Batched (3-D) matrix times a 1-D vector.
  auto lhs = at::randint(0, 5, {2, 10, 8}, {at::kCUDA});
  auto rhs = at::randint(0, 5, {8}, {at::kCUDA});

  auto static_params = torch_tensorrt::core::ir::get_static_params(parsed_g->inputs(), {});
  auto jit_out = torch_tensorrt::tests::util::RunGraph(parsed_g, static_params, {lhs, rhs});

  static_params = torch_tensorrt::core::ir::get_static_params(parsed_g->inputs(), {});
  auto trt_out = torch_tensorrt::tests::util::RunGraphEngine(parsed_g, static_params, {lhs, rhs});

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_out[0], trt_out[0], 2e-6));
}
128+
129+
// 2-D x 1-D aten::matmul: matrix-vector product with a raw-output comparison
// (no reshape fixup), so the converter's output shape is verified too.
TEST(Converters, ATenMM2d1dConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%0 : Tensor, %1 : Tensor):
        %2 : Tensor = aten::matmul(%0, %1)
        return (%2))IR";

  auto parsed_g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, parsed_g.get());

  // Single-row matrix times a 1-D vector.
  auto lhs = at::randint(0, 5, {1, 10}, {at::kCUDA});
  auto rhs = at::randint(0, 5, {10}, {at::kCUDA});

  auto static_params = torch_tensorrt::core::ir::get_static_params(parsed_g->inputs(), {});
  auto jit_out = torch_tensorrt::tests::util::RunGraph(parsed_g, static_params, {lhs, rhs});

  static_params = torch_tensorrt::core::ir::get_static_params(parsed_g->inputs(), {});
  auto trt_out = torch_tensorrt::tests::util::RunGraphEngine(parsed_g, static_params, {lhs, rhs});

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_out[0], trt_out[0], 2e-6));
}
149+
150+
// 4-D x 1-D aten::matmul: doubly-batched matrix times a vector, checking
// the converter's right-hand 1-D correction at higher rank.
TEST(Converters, ATenMM4d1dConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%0 : Tensor, %1 : Tensor):
        %2 : Tensor = aten::matmul(%0, %1)
        return (%2))IR";

  auto parsed_g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, parsed_g.get());

  // Doubly-batched (4-D) matrix times a 1-D vector.
  auto lhs = at::randint(0, 5, {2, 3, 10, 8}, {at::kCUDA});
  auto rhs = at::randint(0, 5, {8}, {at::kCUDA});

  auto static_params = torch_tensorrt::core::ir::get_static_params(parsed_g->inputs(), {});
  auto jit_out = torch_tensorrt::tests::util::RunGraph(parsed_g, static_params, {lhs, rhs});

  static_params = torch_tensorrt::core::ir::get_static_params(parsed_g->inputs(), {});
  auto trt_out = torch_tensorrt::tests::util::RunGraphEngine(parsed_g, static_params, {lhs, rhs});

  // Raw-output comparison: shape agreement with PyTorch is part of the test.
  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_out[0], trt_out[0], 2e-6));
}
49170

50171
TEST(Converters, ATenBMMConvertsCorrectly) {
@@ -63,9 +184,8 @@ TEST(Converters, ATenBMMConvertsCorrectly) {
63184

64185
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
65186
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2});
66-
auto trt = trt_results[0].reshape_as(jit_results[0]);
67187

68-
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
188+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
69189
}
70190

71191
TEST(Converters, ATenBADDBMMConvertsCorrectly) {

0 commit comments

Comments
 (0)