Skip to content

Commit b0c9bb6

Browse files
author
Anurag Dixit
committed
fix: Using toGraphFunction API to get graph
Signed-off-by: Anurag Dixit <[email protected]>
1 parent e3286fd commit b0c9bb6

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

core/lowering/passes/linear_to_addmm.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "torch/csrc/jit/passes/guard_elimination.h"
88
#include "torch/csrc/jit/passes/peephole.h"
99
#include "torch/csrc/jit/runtime/graph_executor.h"
10+
#include "torch/csrc/jit/api/function_impl.h"
1011

1112
#include "core/util/prelude.h"
1213
#include "torch/csrc/jit/passes/subgraph_rewrite.h"
@@ -34,7 +35,7 @@ void replaceLinearWithBiasNonePattern(std::shared_ptr<torch::jit::Graph> graph)
3435
continue;
3536
} else {
3637
torch::jit::WithInsertPoint guard(*it);
37-
std::shared_ptr<torch::jit::Graph> d_graph = decompose_funcs.get_function("linear").graph();
38+
std::shared_ptr<torch::jit::Graph> d_graph = toGraphFunction(decompose_funcs.get_function("linear")).graph();
3839
torch::jit::Value* new_output = insertGraph(*it->owningGraph(), *d_graph, it->inputs()).at(0);
3940
new_output->setType(it->output()->type());
4041
it->output()->replaceAllUsesWith(new_output);

0 commit comments

Comments
 (0)