Skip to content

Commit 0fb364a

Browse files
author
Tobias Gysi
committed
[mlir][linalg] Remove IndexedGenericOp support from LinalgToStandard...
after introducing the IndexedGenericOp to GenericOp canonicalization (https://reviews.llvm.org/D101612). Differential Revision: https://reviews.llvm.org/D102236
1 parent 96100f1 commit 0fb364a

File tree

3 files changed

+7
-67
lines changed

3 files changed

+7
-67
lines changed

mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ namespace linalg {
2828
// Create a new call to the type-canonicalized `LinalgOp::getLibraryCallName()`
2929
// function. The implementation of the function can be either in the same module
3030
// or in an externally linked library.
31-
// This is a generic entry point for all LinalgOp, except for CopyOp and
32-
// IndexedGenericOp, for which more specialized patterns are provided.
31+
// This is a generic entry point for all LinalgOp, except for CopyOp, for which
32+
// more specialized patterns are provided.
3333
class LinalgOpToLibraryCallRewrite
3434
: public OpInterfaceRewritePattern<LinalgOp> {
3535
public:
@@ -58,16 +58,6 @@ class CopyTransposeRewrite : public OpRewritePattern<CopyOp> {
5858
PatternRewriter &rewriter) const override;
5959
};
6060

61-
/// Conversion pattern specialization for IndexedGenericOp, has special handling
62-
/// for the extra index operands.
63-
class IndexedGenericOpToLibraryCallRewrite
64-
: public OpRewritePattern<IndexedGenericOp> {
65-
public:
66-
using OpRewritePattern<IndexedGenericOp>::OpRewritePattern;
67-
LogicalResult matchAndRewrite(IndexedGenericOp op,
68-
PatternRewriter &rewriter) const override;
69-
};
70-
7161
/// Populate the given list with patterns that convert from Linalg to Standard.
7262
void populateLinalgToStandardConversionPatterns(RewritePatternSet &patterns);
7363

mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp

Lines changed: 5 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,6 @@ using namespace mlir::linalg;
2626
static SmallVector<Type, 4> extractOperandTypes(Operation *op) {
2727
SmallVector<Type, 4> result;
2828
result.reserve(op->getNumOperands());
29-
if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(op)) {
30-
auto *ctx = op->getContext();
31-
auto numLoops = indexedGenericOp.getNumLoops();
32-
result.reserve(op->getNumOperands() + numLoops);
33-
result.assign(numLoops, IndexType::get(ctx));
34-
}
3529
for (auto type : op->getOperandTypes()) {
3630
// The underlying descriptor type (e.g. LLVM) does not have layout
3731
// information. Canonicalizing the type at the level of std when going into
@@ -103,7 +97,11 @@ createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc,
10397
LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite(
10498
LinalgOp op, PatternRewriter &rewriter) const {
10599
// Only LinalgOp for which there is no specialized pattern go through this.
106-
if (isa<CopyOp>(op) || isa<IndexedGenericOp>(op))
100+
if (isa<CopyOp>(op))
101+
return failure();
102+
103+
// Canonicalize indexed generic operations before library call conversion.
104+
if (isa<IndexedGenericOp>(op))
107105
return failure();
108106

109107
auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
@@ -167,31 +165,6 @@ LogicalResult mlir::linalg::CopyTransposeRewrite::matchAndRewrite(
167165
return success();
168166
}
169167

170-
LogicalResult
171-
mlir::linalg::IndexedGenericOpToLibraryCallRewrite::matchAndRewrite(
172-
IndexedGenericOp op, PatternRewriter &rewriter) const {
173-
auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
174-
if (!libraryCallName)
175-
return failure();
176-
177-
// TODO: Use induction variables values instead of zeros, when
178-
// IndexedGenericOp is tiled.
179-
auto zero = rewriter.create<mlir::ConstantOp>(
180-
op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
181-
auto indexedGenericOp = cast<IndexedGenericOp>(op);
182-
auto numLoops = indexedGenericOp.getNumLoops();
183-
SmallVector<Value, 4> operands;
184-
operands.reserve(numLoops + op.getNumOperands());
185-
for (unsigned i = 0; i < numLoops; ++i)
186-
operands.push_back(zero);
187-
for (auto operand : op.getOperands())
188-
operands.push_back(operand);
189-
rewriter.replaceOpWithNewOp<mlir::CallOp>(
190-
op, libraryCallName.getValue(), TypeRange(),
191-
createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), operands));
192-
return success();
193-
}
194-
195168
/// Populate the given list with patterns that convert from Linalg to Standard.
196169
void mlir::linalg::populateLinalgToStandardConversionPatterns(
197170
RewritePatternSet &patterns) {
@@ -201,7 +174,6 @@ void mlir::linalg::populateLinalgToStandardConversionPatterns(
201174
patterns.add<
202175
CopyOpToLibraryCallRewrite,
203176
CopyTransposeRewrite,
204-
IndexedGenericOpToLibraryCallRewrite,
205177
LinalgOpToLibraryCallRewrite>(patterns.getContext());
206178
// clang-format on
207179
}

mlir/test/Dialect/Linalg/standard.mlir

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -95,25 +95,3 @@ func @matmul_vec_impl(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C
9595
}
9696
// CHECK-LABEL: func @matmul_vec_impl(
9797
// CHECK: call @external_outerproduct_matmul(%{{.*}}) :
98-
99-
#indexed_matmul_trait = {
100-
iterator_types = ["parallel", "parallel", "reduction"],
101-
indexing_maps = #matmul_accesses,
102-
library_call = "external_indexed_outerproduct_matmul"
103-
}
104-
func @matmul_vec_indexed(%A: !matrix_type_A,
105-
%B: !matrix_type_B,
106-
%C: !matrix_type_C) {
107-
linalg.indexed_generic #indexed_matmul_trait
108-
ins(%A, %B : !matrix_type_A, !matrix_type_B)
109-
outs(%C : !matrix_type_C) {
110-
^bb0(%i: index, %j: index, %k: index,
111-
%a: !vector_type_A, %b: !vector_type_B, %c: !vector_type_C):
112-
%d = vector.outerproduct %a, %b, %c: !vector_type_A, !vector_type_B
113-
linalg.yield %d: !vector_type_C
114-
}
115-
return
116-
}
117-
// CHECK-LABEL: func @matmul_vec_indexed(
118-
// CHECK: %[[ZERO:.*]] = constant 0 : index
119-
// CHECK: call @external_indexed_outerproduct_matmul(%[[ZERO]], %[[ZERO]], %[[ZERO]], %{{.*}})

0 commit comments

Comments
 (0)