Commit 440775b

chore: Adding testcases
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 70cfe41 commit 440775b

3 files changed, 57 insertions(+), 6 deletions(-)

core/conversion/converters/impl/matrix_multiply.cpp

Lines changed: 8 additions & 6 deletions
@@ -14,15 +14,17 @@ auto mm_registrations TRTORCH_UNUSED =
     .pattern({"aten::matmul(Tensor self, Tensor other) -> (Tensor)",
               [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
                 auto self = args[0].ITensorOrFreeze(ctx);
-                LOG_DEBUG("self tensor shape: " << self->getDimensions());
-
                 auto other = args[1].ITensorOrFreeze(ctx);
-                // "other" tensor should have same nbDims as self
-                auto wt_tensor = addPadding(ctx, n, other, self->getDimensions().nbDims, false, false);
-                LOG_DEBUG("other tensor shape: " << wt_tensor->getDimensions());
+                // Ensure self and other tensors have same nbDims by expanding the dimensions (from 0 axis) if
+                // necessary.
+                if (self->getDimensions().nbDims < other->getDimensions().nbDims) {
+                  self = addPadding(ctx, n, self, other->getDimensions().nbDims, false, false);
+                } else {
+                  other = addPadding(ctx, n, other, self->getDimensions().nbDims, false, false);
+                }
 
                 auto mm_layer = ctx->net->addMatrixMultiply(
-                    *self, nvinfer1::MatrixOperation::kNONE, *wt_tensor, nvinfer1::MatrixOperation::kNONE);
+                    *self, nvinfer1::MatrixOperation::kNONE, *other, nvinfer1::MatrixOperation::kNONE);
                 TRTORCH_CHECK(mm_layer, "Unable to create matrix multiplication node: " << *n);
                 mm_layer->setName(util::node_info(n).c_str());
                 auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], mm_layer->getOutput(0));

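The converter change above makes the two aten::matmul operands agree on nbDims before the matrix-multiply layer is added: whichever operand has fewer dimensions is padded with leading size-1 axes. As a rough standalone sketch of that alignment (plain C++, no TensorRT; the pad_to_rank helper is hypothetical and only stands in for what addPadding does to the shape):

```cpp
// Minimal sketch of rank alignment for matmul operands.
// pad_to_rank is a hypothetical helper, not TRTorch's addPadding.
#include <cstdint>
#include <iostream>
#include <vector>

// Prepend size-1 axes until the shape has `rank` dimensions,
// mirroring how the lower-rank operand is expanded from axis 0.
std::vector<int64_t> pad_to_rank(std::vector<int64_t> shape, size_t rank) {
  while (shape.size() < rank) {
    shape.insert(shape.begin(), 1);
  }
  return shape;
}

int main() {
  std::vector<int64_t> self{2, 3};      // rank 2
  std::vector<int64_t> other{3, 3, 2};  // rank 3

  // Pad whichever operand has fewer dimensions, as the converter now does.
  if (self.size() < other.size()) {
    self = pad_to_rank(self, other.size());
  } else {
    other = pad_to_rank(other, self.size());
  }

  // self is now {1, 2, 3}; a batched multiply against {3, 3, 2} broadcasts
  // the leading axis and yields {3, 2, 2}.
  for (auto d : self) std::cout << d << ' ';
  std::cout << '\n';
  return 0;
}
```

With the {2, 3} operand padded to {1, 2, 3}, the multiply against {3, 3, 2} broadcasts over the leading batch axis and produces {3, 2, 2}, which is the shape case the new matmul test further down exercises.
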
tests/core/conversion/converters/test_layer_norm.cpp

Lines changed: 28 additions & 0 deletions
@@ -118,3 +118,31 @@ TEST(Converters, ATenLayerNormConvertsCorrectlyLast1Dims) {
 
   ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
 }
+
+TEST(Converters, ATenLayerNormConvertsCorrectly3dInput1dNormalizedShape) {
+  const auto graph = R"IR(
+        graph(%0 : Tensor,
+              %gamma: Float(197, 768),
+              %beta: Float(197, 768)):
+          %1: int = prim::Constant[value=768]()
+          %4 : int[] = prim::ListConstruct(%1)
+          %7 : bool = prim::Constant[value=0]()
+          %8 : float = prim::Constant[value=1.0000000000000001e-05]()
+          %9 : Tensor = aten::layer_norm(%0, %4, %gamma, %beta, %8, %7)
+          return (%9))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {1, 197, 768}, {at::kCUDA});
+  auto gamma = at::randint(1, 10, {768}, {at::kCUDA});
+  auto beta = at::randint(1, 10, {768}, {at::kCUDA});
+
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {gamma, beta});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
+
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {gamma, beta});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}

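The added test feeds a 3-D [1, 197, 768] input to aten::layer_norm with a 1-D normalized_shape of {768}, so statistics are computed over the last axis only and the 1-D gamma/beta of length 768 are applied per element. A minimal eager-mode sketch of the same call in libtorch (CPU tensors here for brevity; the test itself runs on CUDA):

```cpp
// Eager-mode reference for the shape case the new layer_norm test covers.
#include <torch/torch.h>
#include <iostream>

int main() {
  auto input = torch::randn({1, 197, 768});
  auto gamma = torch::randn({768});
  auto beta = torch::randn({768});

  // normalized_shape = {768}: mean/variance are computed over the last axis
  // for every (batch, token) position, then scaled by gamma and shifted by beta.
  auto out = torch::layer_norm(input, {768}, gamma, beta, /*eps=*/1e-5, /*cudnn_enable=*/false);

  std::cout << out.sizes() << '\n';  // [1, 197, 768]
  return 0;
}
```
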
tests/core/conversion/converters/test_matrix_multiply.cpp

Lines changed: 21 additions & 0 deletions
@@ -26,6 +26,27 @@ TEST(Converters, ATenMMConvertsCorrectly) {
   ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
 }
 
+TEST(Converters, ATenMMWithDiffShapesConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor, %1 : Tensor):
+        %2 : Tensor = aten::matmul(%0, %1)
+        return (%2))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in1 = at::randint(0, 5, {2, 3}, {at::kCUDA});
+  auto in2 = at::randint(0, 5, {3, 3, 2}, {at::kCUDA});
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in1, in2});
+
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in1, in2});
+  auto trt = trt_results[0].reshape_as(jit_results[0]);
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
 TEST(Converters, ATenBMMConvertsCorrectly) {
   const auto graph = R"IR(
       graph(%0 : Tensor, %1 : Tensor):

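The new ATenMMWithDiffShapesConvertsCorrectly test multiplies a [2, 3] tensor by a [3, 3, 2] tensor, i.e. operands of different rank, which is exactly the path the converter change targets. For reference, a minimal eager-mode sketch of the expected broadcasting behaviour in libtorch (CPU tensors here; the test itself runs on CUDA):

```cpp
// Eager-mode reference for matmul with operands of different rank.
#include <torch/torch.h>
#include <iostream>

int main() {
  auto a = torch::randn({2, 3});     // rank 2
  auto b = torch::randn({3, 3, 2});  // rank 3 (batched)

  // matmul treats the rank-2 operand as [1, 2, 3], broadcasts it against the
  // batch dimension of [3, 3, 2], and performs a batched multiply, giving a
  // [3, 2, 2] result -- the same alignment the converter performs before
  // addMatrixMultiply.
  auto c = torch::matmul(a, b);

  std::cout << c.sizes() << '\n';  // [3, 2, 2]
  return 0;
}
```
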