Skip to content

Commit 06bb9cf

Browse files
author
Tobias Gysi
committed
[mlir][linalg] Remove IndexedGenericOp support from LinalgInterchangePattern...
after introducing the IndexedGenericOp to GenericOp canonicalization (https://reviews.llvm.org/D101612). Differential Revision: https://reviews.llvm.org/D102245
1 parent a4db702 commit 06bb9cf

File tree

5 files changed

+50
-103
lines changed

5 files changed

+50
-103
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,8 @@ tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
213213
/// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be
214214
/// integers, in the range 0..`op.rank` without duplications
215215
/// (i.e. `[1,1,2]` is an invalid permutation).
216-
void interchange(PatternRewriter &rewriter, LinalgOp op,
217-
ArrayRef<unsigned> interchangeVector);
216+
void interchangeGenericOp(PatternRewriter &rewriter, GenericOp genericOp,
217+
ArrayRef<unsigned> interchangeVector);
218218

219219
/// Callback function type used to perform the allocation for the promoted
220220
/// `subView`. In `boundingSubViewsize` a best attempt is made to find the
@@ -363,11 +363,11 @@ LogicalResult linalgOpToAffineLoops(PatternRewriter &rewriter,
363363
// Preconditions that ensure the corresponding transformation succeeds and can
364364
// be applied as a rewrite pattern.
365365
//===----------------------------------------------------------------------===//
366-
/// Emits a `generic` or `indexed_generic` operation with the `indexing_maps`
367-
/// and `iterator_types` permutated according to `permutation`.
366+
/// Emits a `generic` operation with the `indexing_maps` and `iterator_types`
367+
/// permutated according to `permutation`.
368368
LogicalResult
369-
interchangeGenericLinalgOpPrecondition(Operation *op,
370-
ArrayRef<unsigned> interchangeVector);
369+
interchangeGenericOpPrecondition(GenericOp genericOp,
370+
ArrayRef<unsigned> interchangeVector);
371371

372372
/// Promote std.subviews feeding linalg operations.
373373
LogicalResult promoteSubviewsPrecondition(Operation *op,
@@ -630,18 +630,18 @@ struct LinalgTileAndFusePattern : public LinalgBaseTileAndFusePattern {
630630
};
631631

632632
///
633-
/// Linalg interchange patterns.
633+
/// Linalg generic interchage pattern.
634634
///
635635
/// Apply the `interchange` transformation as a pattern.
636636
/// `filter` controls LinalgTransformMarker matching and update when specified.
637637
/// See `interchange` for more details.
638-
struct LinalgBaseInterchangePattern : public RewritePattern {
639-
LinalgBaseInterchangePattern(
640-
StringRef opName, MLIRContext *context,
641-
ArrayRef<unsigned> interchangeVector,
638+
struct GenericOpInterchangePattern : public OpRewritePattern<GenericOp> {
639+
using OpRewritePattern<GenericOp>::OpRewritePattern;
640+
GenericOpInterchangePattern(
641+
MLIRContext *context, ArrayRef<unsigned> interchangeVector,
642642
LinalgTransformationFilter filter = LinalgTransformationFilter(),
643643
PatternBenefit benefit = 1);
644-
LogicalResult matchAndRewrite(Operation *op,
644+
LogicalResult matchAndRewrite(GenericOp genericOp,
645645
PatternRewriter &rewriter) const override;
646646

647647
private:
@@ -651,16 +651,6 @@ struct LinalgBaseInterchangePattern : public RewritePattern {
651651
SmallVector<unsigned, 8> interchangeVector;
652652
};
653653

654-
template <typename OpTy>
655-
struct LinalgInterchangePattern : public LinalgBaseInterchangePattern {
656-
LinalgInterchangePattern(
657-
MLIRContext *context, ArrayRef<unsigned> interchangeVector,
658-
LinalgTransformationFilter filter = LinalgTransformationFilter(),
659-
PatternBenefit benefit = 1)
660-
: LinalgBaseInterchangePattern(OpTy::getOperationName(), context,
661-
interchangeVector, filter, benefit) {}
662-
};
663-
664654
///
665655
/// Linalg promotion patterns.
666656
///

mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp

Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -32,68 +32,65 @@
3232
using namespace mlir;
3333
using namespace mlir::linalg;
3434

35-
LogicalResult mlir::linalg::interchangeGenericLinalgOpPrecondition(
36-
Operation *op, ArrayRef<unsigned> interchangeVector) {
37-
// Transformation applies to generic ops only.
38-
if (!isa<GenericOp, IndexedGenericOp>(op))
39-
return failure();
40-
LinalgOp linalgOp = cast<LinalgOp>(op);
35+
LogicalResult mlir::linalg::interchangeGenericOpPrecondition(
36+
GenericOp genericOp, ArrayRef<unsigned> interchangeVector) {
4137
// Interchange vector must be non-empty and match the number of loops.
4238
if (interchangeVector.empty() ||
43-
linalgOp.getNumLoops() != interchangeVector.size())
39+
genericOp.getNumLoops() != interchangeVector.size())
4440
return failure();
4541
// Permutation map must be invertible.
46-
if (!inversePermutation(
47-
AffineMap::getPermutationMap(interchangeVector, op->getContext())))
42+
if (!inversePermutation(AffineMap::getPermutationMap(interchangeVector,
43+
genericOp.getContext())))
4844
return failure();
4945
return success();
5046
}
5147

52-
void mlir::linalg::interchange(PatternRewriter &rewriter, LinalgOp op,
53-
ArrayRef<unsigned> interchangeVector) {
48+
void mlir::linalg::interchangeGenericOp(PatternRewriter &rewriter,
49+
GenericOp genericOp,
50+
ArrayRef<unsigned> interchangeVector) {
5451
// 1. Compute the inverse permutation map.
55-
MLIRContext *context = op.getContext();
52+
MLIRContext *context = genericOp.getContext();
5653
AffineMap permutationMap = inversePermutation(
5754
AffineMap::getPermutationMap(interchangeVector, context));
5855
assert(permutationMap && "expected permutation to be invertible");
59-
assert(interchangeVector.size() == op.getNumLoops() &&
56+
assert(interchangeVector.size() == genericOp.getNumLoops() &&
6057
"expected interchange vector to have entry for every loop");
6158

6259
// 2. Compute the interchanged indexing maps.
6360
SmallVector<Attribute, 4> newIndexingMaps;
64-
ArrayRef<Attribute> indexingMaps = op.indexing_maps().getValue();
65-
for (unsigned i = 0, e = op.getNumShapedOperands(); i != e; ++i) {
61+
ArrayRef<Attribute> indexingMaps = genericOp.indexing_maps().getValue();
62+
for (unsigned i = 0, e = genericOp.getNumShapedOperands(); i != e; ++i) {
6663
AffineMap m = indexingMaps[i].cast<AffineMapAttr>().getValue();
6764
if (!permutationMap.isEmpty())
6865
m = m.compose(permutationMap);
6966
newIndexingMaps.push_back(AffineMapAttr::get(m));
7067
}
71-
op->setAttr(getIndexingMapsAttrName(),
72-
ArrayAttr::get(context, newIndexingMaps));
68+
genericOp->setAttr(getIndexingMapsAttrName(),
69+
ArrayAttr::get(context, newIndexingMaps));
7370

7471
// 3. Compute the interchanged iterator types.
75-
ArrayRef<Attribute> itTypes = op.iterator_types().getValue();
72+
ArrayRef<Attribute> itTypes = genericOp.iterator_types().getValue();
7673
SmallVector<Attribute, 4> itTypesVector;
7774
llvm::append_range(itTypesVector, itTypes);
7875
applyPermutationToVector(itTypesVector, interchangeVector);
79-
op->setAttr(getIteratorTypesAttrName(),
80-
ArrayAttr::get(context, itTypesVector));
76+
genericOp->setAttr(getIteratorTypesAttrName(),
77+
ArrayAttr::get(context, itTypesVector));
8178

8279
// 4. Transform the index operations by applying the permutation map.
83-
if (op.hasIndexSemantics()) {
80+
if (genericOp.hasIndexSemantics()) {
8481
// TODO: Remove the assertion and add a getBody() method to LinalgOp
8582
// interface once every LinalgOp has a body.
86-
assert(op->getNumRegions() == 1 &&
87-
op->getRegion(0).getBlocks().size() == 1 &&
83+
assert(genericOp->getNumRegions() == 1 &&
84+
genericOp->getRegion(0).getBlocks().size() == 1 &&
8885
"expected generic operation to have one block.");
89-
Block &block = op->getRegion(0).front();
86+
Block &block = genericOp->getRegion(0).front();
9087
OpBuilder::InsertionGuard guard(rewriter);
9188
for (IndexOp indexOp :
9289
llvm::make_early_inc_range(block.getOps<IndexOp>())) {
9390
rewriter.setInsertionPoint(indexOp);
9491
SmallVector<Value> allIndices;
95-
allIndices.reserve(op.getNumLoops());
96-
llvm::transform(llvm::seq<uint64_t>(0, op.getNumLoops()),
92+
allIndices.reserve(genericOp.getNumLoops());
93+
llvm::transform(llvm::seq<uint64_t>(0, genericOp.getNumLoops()),
9794
std::back_inserter(allIndices), [&](uint64_t dim) {
9895
return rewriter.create<IndexOp>(indexOp->getLoc(), dim);
9996
});

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -393,30 +393,26 @@ LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
393393
return success();
394394
}
395395

396-
/// Linalg base interchange pattern.
397-
mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern(
398-
StringRef opName, MLIRContext *context,
399-
ArrayRef<unsigned> interchangeVector, LinalgTransformationFilter filter,
400-
PatternBenefit benefit)
401-
: RewritePattern(opName, benefit, context, {}), filter(filter),
396+
/// Linalg generic interchange pattern.
397+
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
398+
MLIRContext *context, ArrayRef<unsigned> interchangeVector,
399+
LinalgTransformationFilter filter, PatternBenefit benefit)
400+
: OpRewritePattern(context, benefit), filter(filter),
402401
interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
403402

404-
LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite(
405-
Operation *op, PatternRewriter &rewriter) const {
406-
LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
407-
if (!linalgOp)
403+
LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
404+
GenericOp genericOp, PatternRewriter &rewriter) const {
405+
if (failed(filter.checkAndNotify(rewriter, genericOp)))
408406
return failure();
409-
if (failed(filter.checkAndNotify(rewriter, linalgOp)))
410-
return failure();
411-
if (failed(interchangeGenericLinalgOpPrecondition(op, interchangeVector)))
407+
if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
412408
return failure();
413409

414410
// TODO: figure out how this interplays with named ops. In particular this
415411
// should break the named op property.
416-
rewriter.updateRootInPlace(op, [&]() {
417-
interchange(rewriter, linalgOp, interchangeVector);
412+
rewriter.updateRootInPlace(genericOp, [&]() {
413+
interchangeGenericOp(rewriter, genericOp, interchangeVector);
418414
// New filter if specified.
419-
filter.replaceLinalgTransformationFilter(rewriter, op);
415+
filter.replaceLinalgTransformationFilter(rewriter, genericOp);
420416
});
421417
return success();
422418
}

mlir/test/Dialect/Linalg/transform-patterns.mlir

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -125,37 +125,6 @@ func @permute_generic(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
125125
// CHECK-SAME: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>
126126
// CHECK-SAME: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>
127127

128-
#indexed_matmul_trait = {
129-
args_in = 2,
130-
args_out = 1,
131-
indexing_maps = #matmul_accesses,
132-
library_call = "linalg_matmul_indexed",
133-
iterator_types = ["parallel", "parallel", "reduction"]
134-
}
135-
func @permute_generic_indexed(
136-
%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
137-
%B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
138-
%C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
139-
linalg.indexed_generic #indexed_matmul_trait
140-
ins(%A, %B : memref<?x?xf32, offset: ?, strides: [?, 1]>,
141-
memref<?x?xf32, offset: ?, strides: [?, 1]>)
142-
outs(%C : memref<?x?xf32, offset: ?, strides: [?, 1]>) {
143-
^bb(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32):
144-
%d = mulf %a, %b: f32
145-
%e = addf %c, %d: f32
146-
linalg.yield %e: f32
147-
}
148-
return
149-
}
150-
// CHECK-LABEL: func @permute_generic_indexed
151-
// CHECK: linalg.indexed_generic {
152-
// CHECK-SAME: indexing_maps = [#[[$kn]], #[[$nm]], #[[$km]]],
153-
// CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel"],
154-
// CHECK-SAME: library_call = "linalg_matmul_indexed"}
155-
// CHECK: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>,
156-
// CHECK-SAME: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>
157-
// CHECK-SAME: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>
158-
159128
func @matvec_perm(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
160129
%x: memref<?xf32, offset: ?, strides: [1]>,
161130
%y: memref<?xf32, offset: ?, strides: [1]>) {

mlir/test/lib/Transforms/TestLinalgTransforms.cpp

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -194,14 +194,9 @@ static void applyPatterns(FuncOp funcOp) {
194194
.addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>());
195195

196196
//===--------------------------------------------------------------------===//
197-
// Linalg generic permutation patterns.
197+
// Linalg generic interchange pattern.
198198
//===--------------------------------------------------------------------===//
199-
patterns.add<LinalgInterchangePattern<GenericOp>>(
200-
ctx,
201-
/*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
202-
LinalgTransformationFilter(ArrayRef<Identifier>{},
203-
Identifier::get("PERMUTED", ctx)));
204-
patterns.add<LinalgInterchangePattern<IndexedGenericOp>>(
199+
patterns.add<GenericOpInterchangePattern>(
205200
ctx,
206201
/*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
207202
LinalgTransformationFilter(ArrayRef<Identifier>{},
@@ -551,7 +546,7 @@ static void applyInterchangePattern(FuncOp funcOp,
551546
ArrayRef<unsigned> interchangeVector) {
552547
MLIRContext *context = funcOp.getContext();
553548
RewritePatternSet interchangePattern(context);
554-
interchangePattern.add<LinalgInterchangePattern<GenericOp>>(
549+
interchangePattern.add<GenericOpInterchangePattern>(
555550
context, interchangeVector,
556551
LinalgTransformationFilter(ArrayRef<Identifier>{},
557552
Identifier::get("interchange", context)));

0 commit comments

Comments
 (0)