Skip to content

Commit 1dd2c93

Browse files
committed
[mlir][linalg] move isElementwise() to Linalg/Utils (NFC)
Differential Revision: https://reviews.llvm.org/D128398
1 parent f4a3df1 commit 1dd2c93

File tree

5 files changed

+47
-42
lines changed

5 files changed

+47
-42
lines changed

mlir/include/mlir/Dialect/Linalg/Utils/Utils.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,15 @@ class LinalgDependenceGraph;
3232
// General utilities
3333
//===----------------------------------------------------------------------===//
3434

35+
/// Check if all indexing maps are projected permutations.
36+
bool allIndexingsAreProjectedPermutation(LinalgOp op);
37+
38+
/// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
39+
bool hasOnlyScalarElementwiseOp(Region &r);
40+
41+
/// Check if a LinalgOp is an element-wise operation.
42+
bool isElementwise(LinalgOp op);
43+
3544
/// Check if `permutation` is a permutation of the range
3645
/// `[0, permutation.size())`.
3746
bool isPermutation(ArrayRef<int64_t> permutation);

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -417,48 +417,6 @@ vectorizeOneOp(OpBuilder &b, LinalgOp linalgOp, Operation *op,
417417
llvm::to_vector<4>(returnTypes), op->getAttrs())};
418418
}
419419

420-
/// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
421-
static bool hasOnlyScalarElementwiseOp(Region &r) {
422-
if (!llvm::hasSingleElement(r))
423-
return false;
424-
for (Operation &op : r.front()) {
425-
if (!(isa<arith::ConstantOp, func::ConstantOp, linalg::YieldOp,
426-
linalg::IndexOp>(op) ||
427-
OpTrait::hasElementwiseMappableTraits(&op)) ||
428-
llvm::any_of(op.getResultTypes(),
429-
[](Type type) { return !type.isIntOrIndexOrFloat(); }))
430-
return false;
431-
}
432-
return true;
433-
}
434-
435-
/// Returns `true` if all indexing maps of the linalg op are projected
436-
/// permutations.
437-
static bool allIndexingsAreProjectedPermutation(LinalgOp op) {
438-
return llvm::all_of(op.getIndexingMaps(), [](AffineMap m) {
439-
return m.isProjectedPermutation(/*allowZeroInResults=*/true);
440-
});
441-
}
442-
443-
// Return true if the op is an element-wise linalg op.
444-
static bool isElementwise(Operation *op) {
445-
auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
446-
if (!linalgOp)
447-
return false;
448-
if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
449-
return false;
450-
451-
if (!allIndexingsAreProjectedPermutation(linalgOp))
452-
return false;
453-
454-
// TODO: relax the restrictions on indexing map.
455-
for (OpOperand *opOperand : linalgOp.getOutputOperands()) {
456-
if (!linalgOp.getTiedIndexingMap(opOperand).isPermutation())
457-
return false;
458-
}
459-
return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0));
460-
}
461-
462420
/// Generic vectorization function that rewrites the body of a `linalgOp` into
463421
/// vector form. Generic vectorization proceeds as follows:
464422
/// 1. Verify the `linalgOp` has one non-empty region.

mlir/lib/Dialect/Linalg/Utils/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ add_mlir_dialect_library(MLIRLinalgUtils
99
MLIRAffineAnalysis
1010
MLIRAffineUtils
1111
MLIRArithmeticDialect
12+
MLIRFuncDialect
1213
MLIRIR
1314
MLIRLinalgDialect
1415
MLIRSCFDialect

mlir/lib/Dialect/Linalg/Utils/Utils.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "mlir/Dialect/Affine/LoopUtils.h"
2020
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
2121
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
22+
#include "mlir/Dialect/Func/IR/FuncOps.h"
2223
#include "mlir/Dialect/Linalg/IR/Linalg.h"
2324
#include "mlir/Dialect/MemRef/IR/MemRef.h"
2425
#include "mlir/Dialect/SCF/IR/SCF.h"
@@ -141,6 +142,41 @@ static void unpackRanges(ArrayRef<Range> ranges, SmallVectorImpl<Value> &lbs,
141142
namespace mlir {
142143
namespace linalg {
143144

145+
bool allIndexingsAreProjectedPermutation(LinalgOp op) {
146+
return llvm::all_of(op.getIndexingMaps(), [](AffineMap m) {
147+
return m.isProjectedPermutation(/*allowZeroInResults=*/true);
148+
});
149+
}
150+
151+
bool hasOnlyScalarElementwiseOp(Region &r) {
152+
if (!llvm::hasSingleElement(r))
153+
return false;
154+
for (Operation &op : r.front()) {
155+
if (!(isa<arith::ConstantOp, func::ConstantOp, linalg::YieldOp,
156+
linalg::IndexOp>(op) ||
157+
OpTrait::hasElementwiseMappableTraits(&op)) ||
158+
llvm::any_of(op.getResultTypes(),
159+
[](Type type) { return !type.isIntOrIndexOrFloat(); }))
160+
return false;
161+
}
162+
return true;
163+
}
164+
165+
bool isElementwise(LinalgOp op) {
166+
if (op.getNumLoops() != op.getNumParallelLoops())
167+
return false;
168+
169+
if (!allIndexingsAreProjectedPermutation(op))
170+
return false;
171+
172+
// TODO: relax the restrictions on indexing map.
173+
for (OpOperand *opOperand : op.getOutputOperands()) {
174+
if (!op.getTiedIndexingMap(opOperand).isPermutation())
175+
return false;
176+
}
177+
return hasOnlyScalarElementwiseOp(op->getRegion(0));
178+
}
179+
144180
bool isPermutation(ArrayRef<int64_t> permutation) {
145181
// Count the number of appearances for all indices.
146182
SmallVector<int64_t> indexCounts(permutation.size(), 0);

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7472,6 +7472,7 @@ cc_library(
74727472
":ArithmeticDialect",
74737473
":ArithmeticUtils",
74747474
":DialectUtils",
7475+
":FuncDialect",
74757476
":IR",
74767477
":LinalgAnalysis",
74777478
":LinalgDialect",

0 commit comments

Comments
 (0)