Skip to content

Commit f1b9720

Browse files
author
Nicolas Vasilache
committed
[mlir][Linalg] Start a LinalgToStandard pass and move conversion to library calls.
This revision starts decoupling the include the kitchen sink behavior of Linalg to LLVM lowering by inserting a -convert-linalg-to-std pass. The lowering of linalg ops to function calls was previously lowering to memref descriptors by having both linalg -> std and std -> LLVM patterns in the same rewrite. When separating this step, a new issue occurred: the layout is automatically type-erased by this process. This revision therefore introduces memref casts to perform these type erasures explicitly. To connect everything end-to-end, the LLVM lowering of MemRefCastOp is relaxed because it is artificially more restricted than the op semantics. The op semantics already guarantee that source and target MemRefTypes are cast-compatible. An invalid lowering test now becomes valid and is removed. Differential Revision: https://reviews.llvm.org/D79468
1 parent 940d949 commit f1b9720

File tree

15 files changed

+489
-365
lines changed

15 files changed

+489
-365
lines changed
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
//===- LinalgToStandard.h - Utils to convert from the linalg dialect ------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_CONVERSION_LINALGTOSTANDARD_LINALGTOSTANDARD_H_
10+
#define MLIR_CONVERSION_LINALGTOSTANDARD_LINALGTOSTANDARD_H_
11+
12+
#include "mlir/Transforms/DialectConversion.h"
13+
14+
namespace mlir {
15+
class MLIRContext;
16+
class ModuleOp;
17+
template <typename T>
18+
class OperationPass;
19+
20+
/// Populate the given list with patterns that convert from Linalg to Standard.
21+
void populateLinalgToStandardConversionPatterns(
22+
OwningRewritePatternList &patterns, MLIRContext *ctx);
23+
24+
/// Create a pass to convert Linalg operations to the Standard dialect.
25+
std::unique_ptr<OperationPass<ModuleOp>> createConvertLinalgToStandardPass();
26+
27+
} // namespace mlir
28+
29+
#endif // MLIR_CONVERSION_LINALGTOSTANDARD_LINALGTOSTANDARD_H_

mlir/include/mlir/Conversion/Passes.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,16 @@ def ConvertLinalgToLLVM : Pass<"convert-linalg-to-llvm", "ModuleOp"> {
141141
let constructor = "mlir::createConvertLinalgToLLVMPass()";
142142
}
143143

144+
//===----------------------------------------------------------------------===//
145+
// LinalgToStandard
146+
//===----------------------------------------------------------------------===//
147+
148+
def ConvertLinalgToStandard : Pass<"convert-linalg-to-std", "ModuleOp"> {
149+
let summary = "Convert the operations from the linalg dialect into the "
150+
"Standard dialect";
151+
let constructor = "mlir::createConvertLinalgToStandardPass()";
152+
}
153+
144154
//===----------------------------------------------------------------------===//
145155
// LinalgToSPIRV
146156
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/StandardOps/IR/Ops.td

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1745,12 +1745,16 @@ def MemRefCastOp : CastOp<"memref_cast"> {
17451745
The `memref_cast` operation converts a memref from one type to an equivalent
17461746
type with a compatible shape. The source and destination types are
17471747
compatible if:
1748-
a. both are ranked memref types with the same element type, affine mappings,
1749-
address space, and rank but where the individual dimensions may add or
1750-
remove constant dimensions from the memref type.
1748+
1749+
a. Both are ranked memref types with the same element type, address space,
1750+
and rank and:
1751+
1. Both have the same layout or both have compatible strided layouts.
1752+
2. The individual sizes (resp. offset and strides in the case of strided
1753+
memrefs) may convert constant dimensions to dynamic dimensions and
1754+
vice-versa.
17511755

17521756
If the cast converts any dimensions from an unknown to a known size, then it
1753-
acts as an assertion that fails at runtime of the dynamic dimensions
1757+
acts as an assertion that fails at runtime if the dynamic dimensions
17541758
disagree with resultant destination size.
17551759

17561760
Example:
@@ -1772,7 +1776,7 @@ def MemRefCastOp : CastOp<"memref_cast"> {
17721776
memref<12x4xf32, offset:?, strides: [?, ?]>
17731777
```
17741778

1775-
b. either or both memref types are unranked with the same element type, and
1779+
b. Either or both memref types are unranked with the same element type, and
17761780
address space.
17771781

17781782
Example:

mlir/include/mlir/IR/StandardTypes.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -723,6 +723,10 @@ AffineMap makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, int64_t offset,
723723
/// `t` with simplified layout.
724724
MemRefType canonicalizeStridedLayout(MemRefType t);
725725

726+
/// Return a version of `t` with a layout that has all dynamic offset and
727+
/// strides. This is used to erase the static layout.
728+
MemRefType eraseStridedLayout(MemRefType t);
729+
726730
/// Given MemRef `sizes` that are either static or dynamic, returns the
727731
/// canonical "contiguous" strides AffineExpr. Strides are multiplicative and
728732
/// once a dynamic dimension is encountered, all canonical strides become

mlir/include/mlir/InitAllPasses.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
2323
#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
2424
#include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.h"
25+
#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
2526
#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
2627
#include "mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h"
2728
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ add_subdirectory(GPUToSPIRV)
77
add_subdirectory(GPUToVulkan)
88
add_subdirectory(LinalgToLLVM)
99
add_subdirectory(LinalgToSPIRV)
10+
add_subdirectory(LinalgToStandard)
1011
add_subdirectory(LoopsToGPU)
1112
add_subdirectory(LoopToStandard)
1213
add_subdirectory(StandardToLLVM)

mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp

Lines changed: 0 additions & 200 deletions
Original file line numberDiff line numberDiff line change
@@ -349,205 +349,6 @@ class YieldOpConversion : public ConvertToLLVMPattern {
349349
};
350350
} // namespace
351351

352-
template <typename LinalgOp>
353-
static SmallVector<Type, 4> ExtractOperandTypes(Operation *op) {
354-
return SmallVector<Type, 4>{op->getOperandTypes()};
355-
}
356-
357-
template <>
358-
SmallVector<Type, 4> ExtractOperandTypes<IndexedGenericOp>(Operation *op) {
359-
auto ctx = op->getContext();
360-
auto indexedGenericOp = cast<IndexedGenericOp>(op);
361-
auto numLoops = indexedGenericOp.getNumLoops();
362-
363-
SmallVector<Type, 4> result;
364-
result.reserve(numLoops + op->getNumOperands());
365-
for (unsigned i = 0; i < numLoops; ++i) {
366-
result.push_back(IndexType::get(ctx));
367-
}
368-
for (auto type : op->getOperandTypes()) {
369-
result.push_back(type);
370-
}
371-
return result;
372-
}
373-
374-
// Get a SymbolRefAttr containing the library function name for the LinalgOp.
375-
// If the library function does not exist, insert a declaration.
376-
template <typename LinalgOp>
377-
static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
378-
PatternRewriter &rewriter) {
379-
auto linalgOp = cast<LinalgOp>(op);
380-
auto fnName = linalgOp.getLibraryCallName();
381-
if (fnName.empty()) {
382-
op->emitWarning("No library call defined for: ") << *op;
383-
return {};
384-
}
385-
386-
// fnName is a dynamic std::String, unique it via a SymbolRefAttr.
387-
FlatSymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr(fnName);
388-
auto module = op->getParentOfType<ModuleOp>();
389-
if (module.lookupSymbol(fnName)) {
390-
return fnNameAttr;
391-
}
392-
393-
SmallVector<Type, 4> inputTypes(ExtractOperandTypes<LinalgOp>(op));
394-
assert(op->getNumResults() == 0 &&
395-
"Library call for linalg operation can be generated only for ops that "
396-
"have void return types");
397-
auto libFnType = FunctionType::get(inputTypes, {}, rewriter.getContext());
398-
399-
OpBuilder::InsertionGuard guard(rewriter);
400-
// Insert before module terminator.
401-
rewriter.setInsertionPoint(module.getBody(),
402-
std::prev(module.getBody()->end()));
403-
FuncOp funcOp =
404-
rewriter.create<FuncOp>(op->getLoc(), fnNameAttr.getValue(), libFnType,
405-
ArrayRef<NamedAttribute>{});
406-
// Insert a function attribute that will trigger the emission of the
407-
// corresponding `_mlir_ciface_xxx` interface so that external libraries see
408-
// a normalized ABI. This interface is added during std to llvm conversion.
409-
funcOp.setAttr("llvm.emit_c_interface", UnitAttr::get(op->getContext()));
410-
return fnNameAttr;
411-
}
412-
413-
namespace {
414-
415-
// LinalgOpConversion<LinalgOp> creates a new call to the
416-
// `LinalgOp::getLibraryCallName()` function.
417-
// The implementation of the function can be either in the same module or in an
418-
// externally linked library.
419-
template <typename LinalgOp>
420-
class LinalgOpConversion : public OpRewritePattern<LinalgOp> {
421-
public:
422-
using OpRewritePattern<LinalgOp>::OpRewritePattern;
423-
424-
LogicalResult matchAndRewrite(LinalgOp op,
425-
PatternRewriter &rewriter) const override {
426-
auto libraryCallName = getLibraryCallSymbolRef<LinalgOp>(op, rewriter);
427-
if (!libraryCallName)
428-
return failure();
429-
430-
rewriter.replaceOpWithNewOp<mlir::CallOp>(
431-
op, libraryCallName.getValue(), ArrayRef<Type>{}, op.getOperands());
432-
return success();
433-
}
434-
};
435-
436-
/// Conversion pattern specialization for CopyOp. This kicks in when both input
437-
/// and output permutations are left unspecified or are the identity.
438-
template <> class LinalgOpConversion<CopyOp> : public OpRewritePattern<CopyOp> {
439-
public:
440-
using OpRewritePattern<CopyOp>::OpRewritePattern;
441-
442-
LogicalResult matchAndRewrite(CopyOp op,
443-
PatternRewriter &rewriter) const override {
444-
auto inputPerm = op.inputPermutation();
445-
if (inputPerm.hasValue() && !inputPerm->isIdentity())
446-
return failure();
447-
auto outputPerm = op.outputPermutation();
448-
if (outputPerm.hasValue() && !outputPerm->isIdentity())
449-
return failure();
450-
451-
auto libraryCallName = getLibraryCallSymbolRef<CopyOp>(op, rewriter);
452-
if (!libraryCallName)
453-
return failure();
454-
455-
rewriter.replaceOpWithNewOp<mlir::CallOp>(
456-
op, libraryCallName.getValue(), ArrayRef<Type>{}, op.getOperands());
457-
return success();
458-
}
459-
};
460-
461-
/// Conversion pattern specialization for IndexedGenericOp.
462-
template <>
463-
class LinalgOpConversion<IndexedGenericOp>
464-
: public OpRewritePattern<IndexedGenericOp> {
465-
public:
466-
using OpRewritePattern<IndexedGenericOp>::OpRewritePattern;
467-
468-
LogicalResult matchAndRewrite(IndexedGenericOp op,
469-
PatternRewriter &rewriter) const override {
470-
auto libraryCallName =
471-
getLibraryCallSymbolRef<IndexedGenericOp>(op, rewriter);
472-
if (!libraryCallName)
473-
return failure();
474-
475-
// TODO(pifon, ntv): Use induction variables values instead of zeros, when
476-
// IndexedGenericOp is tiled.
477-
auto zero = rewriter.create<mlir::ConstantOp>(
478-
op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
479-
auto indexedGenericOp = cast<IndexedGenericOp>(op);
480-
auto numLoops = indexedGenericOp.getNumLoops();
481-
SmallVector<Value, 4> operands;
482-
operands.reserve(numLoops + op.getNumOperands());
483-
for (unsigned i = 0; i < numLoops; ++i) {
484-
operands.push_back(zero);
485-
}
486-
for (auto operand : op.getOperands()) {
487-
operands.push_back(operand);
488-
}
489-
rewriter.replaceOpWithNewOp<mlir::CallOp>(op, libraryCallName.getValue(),
490-
ArrayRef<Type>{}, operands);
491-
return success();
492-
}
493-
};
494-
495-
/// A non-conversion rewrite pattern kicks in to convert CopyOp with
496-
/// permutations into a sequence of TransposeOp and permutation-free CopyOp.
497-
/// This interplays together with TransposeOpConversion and
498-
/// LinalgConversion<CopyOp> to create a path to the LLVM dialect.
499-
class CopyTransposeConversion : public OpRewritePattern<CopyOp> {
500-
public:
501-
using OpRewritePattern<CopyOp>::OpRewritePattern;
502-
503-
LogicalResult matchAndRewrite(CopyOp op,
504-
PatternRewriter &rewriter) const override {
505-
Value in = op.input(), out = op.output();
506-
507-
// If either inputPerm or outputPerm are non-identities, insert transposes.
508-
auto inputPerm = op.inputPermutation();
509-
if (inputPerm.hasValue() && !inputPerm->isIdentity())
510-
in = rewriter.create<linalg::TransposeOp>(op.getLoc(), in,
511-
AffineMapAttr::get(*inputPerm));
512-
auto outputPerm = op.outputPermutation();
513-
if (outputPerm.hasValue() && !outputPerm->isIdentity())
514-
out = rewriter.create<linalg::TransposeOp>(
515-
op.getLoc(), out, AffineMapAttr::get(*outputPerm));
516-
517-
// If nothing was transposed, fail and let the conversion kick in.
518-
if (in == op.input() && out == op.output())
519-
return failure();
520-
521-
rewriter.replaceOpWithNewOp<CopyOp>(op, in, out);
522-
return success();
523-
}
524-
};
525-
526-
/// Populate the given list with patterns that convert from Linalg to Standard.
527-
static void
528-
populateLinalgToStandardConversionPatterns(OwningRewritePatternList &patterns,
529-
MLIRContext *ctx) {
530-
// TODO(ntv) ConvOp conversion needs to export a descriptor with relevant
531-
// attribute values such as kernel striding and dilation.
532-
// clang-format off
533-
patterns.insert<
534-
CopyTransposeConversion,
535-
LinalgOpConversion<ConvOp>,
536-
LinalgOpConversion<PoolingMaxOp>,
537-
LinalgOpConversion<PoolingMinOp>,
538-
LinalgOpConversion<PoolingSumOp>,
539-
LinalgOpConversion<CopyOp>,
540-
LinalgOpConversion<DotOp>,
541-
LinalgOpConversion<FillOp>,
542-
LinalgOpConversion<GenericOp>,
543-
LinalgOpConversion<IndexedGenericOp>,
544-
LinalgOpConversion<MatmulOp>,
545-
LinalgOpConversion<MatvecOp>>(ctx);
546-
// clang-format on
547-
}
548-
549-
} // namespace
550-
551352
/// Populate the given list with patterns that convert from Linalg to LLVM.
552353
void mlir::populateLinalgToLLVMConversionPatterns(
553354
LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
@@ -579,7 +380,6 @@ void ConvertLinalgToLLVMPass::runOnOperation() {
579380
populateVectorToLoopsConversionPatterns(patterns, &getContext());
580381
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
581382
populateVectorToLLVMConversionPatterns(converter, patterns);
582-
populateLinalgToStandardConversionPatterns(patterns, &getContext());
583383
populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext());
584384

585385
LLVMConversionTarget target(getContext());
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
add_mlir_conversion_library(MLIRLinalgToStandard
2+
LinalgToStandard.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/LinalgToStandard
6+
7+
DEPENDS
8+
MLIRConversionPassIncGen
9+
)
10+
11+
target_link_libraries(MLIRLinalgToStandard
12+
PUBLIC
13+
MLIREDSC
14+
MLIRIR
15+
MLIRLinalgOps
16+
MLIRSCF
17+
LLVMCore
18+
LLVMSupport
19+
)

0 commit comments

Comments
 (0)