Skip to content

Commit b584f7a

Browse files
committed
chore: Add testcase for linear bias none lowering
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 440775b commit b584f7a

File tree

1 file changed

+24
-1
lines changed

1 file changed

+24
-1
lines changed

tests/core/lowering/test_linear_to_addmm.cpp

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,27 @@ TEST(LoweringPasses, LinearToAddMM) {
3131
torch::jit::parseIR(target_graph, &*tg);
3232

3333
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
34-
}
34+
}
35+
36+
TEST(LoweringPasses, LinearToAddMMBiasNone) {
37+
std::string source_graph = R"IR(
38+
graph(%input, %weight):
39+
%bias : None = prim::Constant()
40+
%res = aten::linear(%input, %weight, %bias)
41+
return (%res))IR";
42+
std::string target_graph = R"IR(
43+
graph(%input, %weight_t):
44+
%weight = aten::t(%weight_t)
45+
%mm: Tensor = aten::matmul(%input, %weight)
46+
return (%mm))IR";
47+
48+
trtorch::core::util::logging::get_logger().set_reportable_log_level(trtorch::core::util::logging::LogLevel::kGRAPH);
49+
auto sg = std::make_shared<torch::jit::Graph>();
50+
torch::jit::parseIR(source_graph, &*sg);
51+
trtorch::core::lowering::passes::LinearToAddMM(sg);
52+
53+
auto tg = std::make_shared<torch::jit::Graph>();
54+
torch::jit::parseIR(target_graph, &*tg);
55+
56+
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
57+
}

0 commit comments

Comments
 (0)