Skip to content

Commit e0dc3db

Browse files
[mlir][Linalg] NFC - Cleanup explicitly instantiated paterns 1/n - LinalgToStandard.cpp
This revision belongs to a series of patches that reduce reliance of Linalg transformations on templated rewrite and conversion patterns. Instead, this uses a MatchAnyTag pattern for the vast majority of cases and dispatches internally. Differential Revision: https://reviews.llvm.org/D89133
1 parent df295fa commit e0dc3db

File tree

5 files changed

+161
-166
lines changed

5 files changed

+161
-166
lines changed

mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,10 @@
1212
#include "mlir/Transforms/DialectConversion.h"
1313

1414
namespace mlir {
15-
class MLIRContext;
1615
class ModuleOp;
1716
template <typename T>
1817
class OperationPass;
1918

20-
/// Populate the given list with patterns that convert from Linalg to Standard.
21-
void populateLinalgToStandardConversionPatterns(
22-
OwningRewritePatternList &patterns, MLIRContext *ctx);
23-
2419
/// Create a pass to convert Linalg operations to the Standard dialect.
2520
std::unique_ptr<OperationPass<ModuleOp>> createConvertLinalgToStandardPass();
2621

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -502,8 +502,9 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, [
502502
getIteratorTypesAttrName(), getSymbolSourceAttrName()
503503
};
504504
}
505-
StringRef getLibraryCallName() {
506-
return library_call().hasValue() ? library_call().getValue() : "";
505+
std::string getLibraryCallName() {
506+
return library_call().hasValue() ?
507+
library_call()->str() : "op_has_no_registered_library_name";
507508
}
508509
llvm::Optional<unsigned> getSymbolSource() {
509510
auto ss = symbol_source();

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -594,6 +594,19 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
594594
llvm::all_of(this->getOperation()->getResults(), isTensorType);
595595
}]
596596
>,
597+
InterfaceMethod<
598+
/*desc=*/[{
599+
Return the name registered for this op when lowering to an external
600+
library call.
601+
}],
602+
/*retTy=*/"std::string",
603+
/*methodName=*/"getLibraryCallName",
604+
/*args=*/(ins),
605+
/*methodBody=*/"",
606+
/*defaultImplementation=*/[{
607+
return $_op.getLibraryCallName();
608+
}]
609+
>,
597610

598611
//===------------------------------------------------------------------===//
599612
// Other static interface methods.

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -347,9 +347,7 @@ struct LinalgTilingOptions {
347347
/// values must not fold away when tiling. Otherwise, use a more robust
348348
/// `tileSizeComputationFunction`.
349349
LinalgTilingOptions &setTileSizes(SmallVector<Value, 4> ts) {
350-
tileSizeComputationFunction = [=](OpBuilder &, Operation *) {
351-
return ts;
352-
};
350+
tileSizeComputationFunction = [=](OpBuilder &, Operation *) { return ts; };
353351
return *this;
354352
}
355353
/// Convenience function to set the `tileSizeComputationFunction` to a
@@ -749,6 +747,56 @@ class ConvOpVectorization : public OpRewritePattern<ConvOp> {
749747
PatternRewriter &rewriter) const override;
750748
};
751749

750+
//===----------------------------------------------------------------------===//
751+
// Patterns to convert a LinalgOp to std.call @external library implementation.
752+
//===----------------------------------------------------------------------===//
753+
// Create a new call to the type-canonicalized `LinalgOp::getLibraryCallName()`
754+
// function. The implementation of the function can be either in the same module
755+
// or in an externally linked library.
756+
// This is a generic entry point for all LinalgOp, except for CopyOp and
757+
// IndexedGenericOp, for which omre specialized patterns are provided.
758+
class LinalgOpToLibraryCallRewrite : public RewritePattern {
759+
public:
760+
LinalgOpToLibraryCallRewrite()
761+
: RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {}
762+
763+
LogicalResult matchAndRewrite(Operation *op,
764+
PatternRewriter &rewriter) const override;
765+
};
766+
767+
/// Rewrite pattern specialization for CopyOp, kicks in when both input and
768+
/// output permutations are left unspecified or are the identity.
769+
class CopyOpToLibraryCallRewrite : public OpRewritePattern<CopyOp> {
770+
public:
771+
using OpRewritePattern<CopyOp>::OpRewritePattern;
772+
LogicalResult matchAndRewrite(CopyOp op,
773+
PatternRewriter &rewriter) const override;
774+
};
775+
776+
/// Rewrite CopyOp with permutations into a sequence of TransposeOp and
777+
/// permutation-free CopyOp. This interplays with TransposeOpConversion and
778+
/// LinalgConversion<CopyOp> to create a path to the LLVM dialect.
779+
class CopyTransposeRewrite : public OpRewritePattern<CopyOp> {
780+
public:
781+
using OpRewritePattern<CopyOp>::OpRewritePattern;
782+
LogicalResult matchAndRewrite(CopyOp op,
783+
PatternRewriter &rewriter) const override;
784+
};
785+
786+
/// Conversion pattern specialization for IndexedGenericOp, has special handling
787+
/// for the extra index operands.
788+
class IndexedGenericOpToLibraryCallRewrite
789+
: public OpRewritePattern<IndexedGenericOp> {
790+
public:
791+
using OpRewritePattern<IndexedGenericOp>::OpRewritePattern;
792+
LogicalResult matchAndRewrite(IndexedGenericOp op,
793+
PatternRewriter &rewriter) const override;
794+
};
795+
796+
/// Populate the given list with patterns that convert from Linalg to Standard.
797+
void populateLinalgToStandardConversionPatterns(
798+
OwningRewritePatternList &patterns, MLIRContext *ctx);
799+
752800
//===----------------------------------------------------------------------===//
753801
// Support for staged pattern application.
754802
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)