Skip to content

Commit 9c4184b

Browse files
author
Ferdinand Lemaire
committed
Merge conflicts
2 parents af893da + f799862 commit 9c4184b

File tree

23 files changed

+1620
-125
lines changed

23 files changed

+1620
-125
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,14 @@ void populateDynamicDimSizes(OpBuilder &b, Location loc, Value shapedValue,
5353
/// This function returns `failure()` in case of unsupported casts. E.g., casts
5454
/// with differing element types or memory spaces.
5555
FailureOr<Value> castOrReallocMemRefValue(OpBuilder &b, Value value,
56-
MemRefType type);
56+
MemRefType type,
57+
const BufferizationOptions &options);
5758

5859
/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
5960
/// to_memref op are different, a memref.cast is needed.
6061
LogicalResult foldToMemrefToTensorPair(RewriterBase &rewriter,
61-
ToMemrefOp toMemref);
62+
ToMemrefOp toMemref,
63+
const BufferizationOptions &options);
6264

6365
/// Add the canonicalization patterns for bufferization.dealloc to the given
6466
/// pattern set to make them available to other passes (such as

mlir/include/mlir/Dialect/Func/Transforms/Passes.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,15 @@ def DuplicateFunctionEliminationPass : Pass<"duplicate-function-elimination",
5151
let constructor = "mlir::func::createDuplicateFunctionEliminationPass()";
5252
}
5353

54+
def AnnotateFunctionType: Pass<"annotate-function-type", "func::FuncOp"> {
55+
let summary = "Annotate the function type as type attributes";
56+
let description = [{
57+
Annotates all the inputs and outputs of func.func operators with a type
58+
attribute. The type attribute mirrors the actual type of the inputs/outputs.
59+
60+
This pass can be used to trace back the original types of func.func
61+
operators in case they need to be modified.
62+
}];
63+
}
64+
5465
#endif // MLIR_DIALECT_FUNC_TRANSFORMS_PASSES_TD

mlir/include/mlir/Dialect/PDL/IR/Builtins.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,19 @@ namespace pdl {
2626
void registerBuiltins(PDLPatternModule &pdlPattern);
2727

2828
namespace builtin {
29+
enum class BinaryOpKind {
30+
add,
31+
sub,
32+
mul,
33+
div,
34+
mod,
35+
};
36+
37+
enum class UnaryOpKind {
38+
log2,
39+
exp2,
40+
};
41+
2942
LogicalResult createDictionaryAttr(PatternRewriter &rewriter,
3043
PDLResultList &results,
3144
ArrayRef<PDLValue> args);
@@ -35,8 +48,27 @@ LogicalResult addEntryToDictionaryAttr(PatternRewriter &rewriter,
3548
Attribute createArrayAttr(PatternRewriter &rewriter);
3649
Attribute addElemToArrayAttr(PatternRewriter &rewriter, Attribute attr,
3750
Attribute element);
51+
template <BinaryOpKind T>
52+
LogicalResult binaryOp(PatternRewriter &rewriter, PDLResultList &results,
53+
llvm::ArrayRef<PDLValue> args);
54+
LogicalResult mul(PatternRewriter &rewriter, PDLResultList &results,
55+
llvm::ArrayRef<PDLValue> args);
56+
LogicalResult div(PatternRewriter &rewriter, PDLResultList &results,
57+
llvm::ArrayRef<PDLValue> args);
58+
LogicalResult mod(PatternRewriter &rewriter, PDLResultList &results,
59+
llvm::ArrayRef<PDLValue> args);
3860
LogicalResult add(PatternRewriter &rewriter, PDLResultList &results,
3961
llvm::ArrayRef<PDLValue> args);
62+
LogicalResult sub(PatternRewriter &rewriter, PDLResultList &results,
63+
llvm::ArrayRef<PDLValue> args);
64+
LogicalResult log2(PatternRewriter &rewriter, PDLResultList &results,
65+
llvm::ArrayRef<PDLValue> args);
66+
LogicalResult exp2(PatternRewriter &rewriter, PDLResultList &results,
67+
llvm::ArrayRef<PDLValue> args);
68+
69+
template <BinaryOpKind T>
70+
LogicalResult binaryOp(PatternRewriter &rewriter, PDLResultList &results,
71+
llvm::ArrayRef<PDLValue> args);
4072
} // namespace builtin
4173
} // namespace pdl
4274
} // namespace mlir

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ using namespace mlir::bufferization;
2323
// Helper functions
2424
//===----------------------------------------------------------------------===//
2525

26-
FailureOr<Value>
27-
mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
28-
MemRefType destType) {
26+
FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
27+
OpBuilder &b, Value value, MemRefType destType,
28+
const BufferizationOptions &options) {
2929
auto srcType = llvm::cast<MemRefType>(value.getType());
3030

3131
// Element type, rank and memory space must match.
@@ -73,18 +73,23 @@ mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
7373
Value size = b.create<memref::DimOp>(loc, value, i);
7474
dynamicOperands.push_back(size);
7575
}
76-
// TODO: Use alloc/memcpy callback from BufferizationOptions if called via
77-
// BufferizableOpInterface impl of ToMemrefOp.
78-
Value copy = b.create<memref::AllocOp>(loc, destType, dynamicOperands);
79-
b.create<memref::CopyOp>(loc, value, copy);
76+
77+
FailureOr<Value> copy =
78+
options.createAlloc(b, loc, destType, dynamicOperands);
79+
if (failed(copy)) {
80+
return failure();
81+
}
82+
if (failed(options.createMemCpy(b, loc, value, *copy))) {
83+
return failure();
84+
}
8085
return copy;
8186
}
8287

8388
/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
8489
/// to_memref op are different, a memref.cast is needed.
85-
LogicalResult
86-
mlir::bufferization::foldToMemrefToTensorPair(RewriterBase &rewriter,
87-
ToMemrefOp toMemref) {
90+
LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
91+
RewriterBase &rewriter, ToMemrefOp toMemref,
92+
const BufferizationOptions &options) {
8893
auto memrefToTensor = toMemref.getTensor().getDefiningOp<ToTensorOp>();
8994
if (!memrefToTensor)
9095
return failure();
@@ -105,7 +110,7 @@ mlir::bufferization::foldToMemrefToTensorPair(RewriterBase &rewriter,
105110
// Ranked memref -> Ranked memref cast.
106111
if (rankedSrcType && rankedDestType) {
107112
FailureOr<Value> replacement = castOrReallocMemRefValue(
108-
rewriter, memrefToTensor.getMemref(), rankedDestType);
113+
rewriter, memrefToTensor.getMemref(), rankedDestType, options);
109114
if (failed(replacement))
110115
return failure();
111116

@@ -802,7 +807,7 @@ struct ToMemrefToTensorFolding : public OpRewritePattern<ToMemrefOp> {
802807

803808
LogicalResult matchAndRewrite(ToMemrefOp toMemref,
804809
PatternRewriter &rewriter) const final {
805-
return foldToMemrefToTensorPair(rewriter, toMemref);
810+
return foldToMemrefToTensorPair(rewriter, toMemref, {});
806811
}
807812
};
808813

@@ -850,7 +855,7 @@ void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
850855
LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
851856
const BufferizationOptions &options) {
852857
// Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
853-
(void)foldToMemrefToTensorPair(rewriter, *this);
858+
(void)foldToMemrefToTensorPair(rewriter, *this, options);
854859
// Note: The return value of `bufferize` indicates whether there was an error
855860
// or not. (And not whether the pattern matched or not.)
856861
return success();

mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
7575
if (!rankedDestType)
7676
return nullptr;
7777
FailureOr<Value> replacement =
78-
castOrReallocMemRefValue(builder, inputs[0], rankedDestType);
78+
castOrReallocMemRefValue(builder, inputs[0], rankedDestType, {});
7979
if (failed(replacement))
8080
return nullptr;
8181
return *replacement;
@@ -509,8 +509,8 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
509509
// Fold all to_memref(to_tensor(x)) pairs.
510510
for (Operation *op : toMemrefOps) {
511511
rewriter.setInsertionPoint(op);
512-
(void)bufferization::foldToMemrefToTensorPair(rewriter,
513-
cast<ToMemrefOp>(op));
512+
(void)bufferization::foldToMemrefToTensorPair(
513+
rewriter, cast<ToMemrefOp>(op), options);
514514
}
515515

516516
// Remove all dead to_tensor ops.
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
//===- AnnotateInputTypes.cpp - Type attribute annotation for func ops ----===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements a pass that creates type attributes for func parameters,
10+
// that mirror the actual type. This is useful when the func op input types
11+
// might change.
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#include "mlir/Dialect/Func/Transforms/Passes.h"
16+
17+
#include "mlir/Dialect/Func/IR/FuncOps.h"
18+
#include "mlir/IR/BuiltinAttributes.h"
19+
#include "mlir/Pass/Pass.h"
20+
21+
using namespace mlir;
22+
23+
namespace mlir::func {
24+
#define GEN_PASS_DEF_ANNOTATEFUNCTIONTYPE
25+
#include "mlir/Dialect/Func/Transforms/Passes.h.inc"
26+
} // namespace mlir::func
27+
28+
namespace {
29+
struct AnnotateFunctionTypePass
30+
: public mlir::func::impl::AnnotateFunctionTypeBase<
31+
AnnotateFunctionTypePass> {
32+
33+
void runOnOperation() override {
34+
func::FuncOp func = getOperation();
35+
auto inputs = func.getArgumentTypes();
36+
auto results = func.getResultTypes();
37+
38+
for (const auto [argNum, type] : llvm::enumerate(inputs)) {
39+
func.setArgAttr(argNum, "func.orig_type", TypeAttr::get(type));
40+
}
41+
42+
for (const auto [resultNum, type] : llvm::enumerate(results)) {
43+
func.setResultAttr(resultNum, "func.orig_type", TypeAttr::get(type));
44+
}
45+
}
46+
};
47+
} // namespace

mlir/lib/Dialect/Func/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
add_mlir_dialect_library(MLIRFuncTransforms
2+
AnnotateFunctionType.cpp
23
DecomposeCallGraphTypes.cpp
34
DuplicateFunctionElimination.cpp
45
FuncBufferize.cpp

0 commit comments

Comments
 (0)