File tree Expand file tree Collapse file tree 1 file changed +24
-1
lines changed Expand file tree Collapse file tree 1 file changed +24
-1
lines changed Original file line number Diff line number Diff line change @@ -31,4 +31,27 @@ TEST(LoweringPasses, LinearToAddMM) {
31
31
torch::jit::parseIR (target_graph, &*tg);
32
32
33
33
ASSERT_TRUE (!torch::jit::findPatternMatches (*tg, *sg).empty ());
34
- }
34
+ }
35
+
36
+ TEST (LoweringPasses, LinearToAddMMBiasNone) {
37
+ std::string source_graph = R"IR(
38
+ graph(%input, %weight):
39
+ %bias : None = prim::Constant()
40
+ %res = aten::linear(%input, %weight, %bias)
41
+ return (%res))IR" ;
42
+ std::string target_graph = R"IR(
43
+ graph(%input, %weight_t):
44
+ %weight = aten::t(%weight_t)
45
+ %mm: Tensor = aten::matmul(%input, %weight)
46
+ return (%mm))IR" ;
47
+
48
+ trtorch::core::util::logging::get_logger ().set_reportable_log_level (trtorch::core::util::logging::LogLevel::kGRAPH );
49
+ auto sg = std::make_shared<torch::jit::Graph>();
50
+ torch::jit::parseIR (source_graph, &*sg);
51
+ trtorch::core::lowering::passes::LinearToAddMM (sg);
52
+
53
+ auto tg = std::make_shared<torch::jit::Graph>();
54
+ torch::jit::parseIR (target_graph, &*tg);
55
+
56
+ ASSERT_TRUE (!torch::jit::findPatternMatches (*tg, *sg).empty ());
57
+ }
You can’t perform that action at this time.
0 commit comments