Skip to content

Commit ec65f0e

Browse files
committed
feat: add useMatmulForSingleBatch option to mlir::tosa::addTosaToLinalgPasses.
1 parent 608e8f7 commit ec65f0e

File tree

3 files changed

+16
-6
lines changed

3 files changed

+16
-6
lines changed

mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,16 @@ namespace mlir {
2525
namespace tosa {
2626

2727
std::unique_ptr<Pass> createTosaToLinalg();
28-
std::unique_ptr<Pass> createTosaToLinalgNamed();
28+
std::unique_ptr<Pass>
29+
createTosaToLinalgNamed(bool useMatmulForSingleBatch = false);
2930

3031
/// Populates passes to convert from TOSA to Linalg on buffers. At the end of
3132
/// the pass, the function will only contain linalg ops or standard ops if the
3233
/// pipeline succeeds. The option to disable decompositions is available for
3334
/// benchmarking performance improvements from the canonicalizations.
3435
void addTosaToLinalgPasses(OpPassManager &pm,
35-
bool disableTosaDecompositions = false);
36+
bool disableTosaDecompositions = false,
37+
bool useMatmulForSingleBatch = false);
3638

3739
/// Populates conversion passes from TOSA dialect to Linalg dialect.
3840
void populateTosaToLinalgConversionPatterns(RewritePatternSet *patterns);

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ namespace {
3737
struct TosaToLinalgNamed
3838
: public impl::TosaToLinalgNamedBase<TosaToLinalgNamed> {
3939
public:
40+
TosaToLinalgNamed() = default;
41+
explicit TosaToLinalgNamed(bool useMatmulForSingleBatch) {
42+
this->useMatmulForSingleBatch = useMatmulForSingleBatch;
43+
}
44+
4045
void getDependentDialects(DialectRegistry &registry) const override {
4146
registry
4247
.insert<arith::ArithDialect, linalg::LinalgDialect, math::MathDialect,
@@ -69,6 +74,7 @@ struct TosaToLinalgNamed
6974
};
7075
} // namespace
7176

72-
std::unique_ptr<Pass> mlir::tosa::createTosaToLinalgNamed() {
73-
return std::make_unique<TosaToLinalgNamed>();
77+
std::unique_ptr<Pass>
78+
mlir::tosa::createTosaToLinalgNamed(bool useMatmulForSingleBatch) {
79+
return std::make_unique<TosaToLinalgNamed>(useMatmulForSingleBatch);
7480
}

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,15 +75,17 @@ std::unique_ptr<Pass> mlir::tosa::createTosaToLinalg() {
7575
}
7676

7777
void mlir::tosa::addTosaToLinalgPasses(OpPassManager &pm,
78-
bool disableTosaDecompositions) {
78+
bool disableTosaDecompositions,
79+
bool useMatmulForSingleBatch) {
7980
// Optional decompositions are designed to benefit linalg.
8081
if (!disableTosaDecompositions)
8182
pm.addNestedPass<func::FuncOp>(tosa::createTosaOptionalDecompositions());
8283
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
8384

8485
pm.addNestedPass<func::FuncOp>(tosa::createTosaInferShapesPass());
8586
pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
86-
pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalgNamed());
87+
pm.addNestedPass<func::FuncOp>(
88+
tosa::createTosaToLinalgNamed(useMatmulForSingleBatch));
8789
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
8890
// TODO: Remove pass that operates on const tensor and enable optionality
8991
pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass());

0 commit comments

Comments
 (0)