1
- // ===- StdExpandDivs .cpp - Code to prepare Std for lowering Divs to LLVM -===//
1
+ // ===- ExpandDivs .cpp - Expansion patterns for MemRef operations --------- -===//
2
2
//
3
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
4
// See https://llvm.org/LICENSE.txt for license information.
5
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
6
//
7
7
// ===----------------------------------------------------------------------===//
8
- //
9
- // This file Std transformations to expand Divs operation to help for the
10
- // lowering to LLVM. Currently implemented transformations are Ceil and Floor
11
- // for Signed Integers.
12
- //
13
- // ===----------------------------------------------------------------------===//
14
8
15
9
#include " mlir/Dialect/MemRef/Transforms/Passes.h"
16
10
@@ -33,44 +27,6 @@ using namespace mlir;
33
27
34
28
namespace {
35
29
36
- // / Converts `atomic_rmw` that cannot be lowered to a simple atomic op with
37
- // / AtomicRMWOpLowering pattern, such as minimum and maximum operations for
38
- // / floating-point numbers, to `memref.generic_atomic_rmw` with the expanded
39
- // / code.
40
- // /
41
- // / %x = atomic_rmw maximumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
42
- // /
43
- // / will be lowered to
44
- // /
45
- // / %x = memref.generic_atomic_rmw %F[%i] : memref<10xf32> {
46
- // / ^bb0(%current: f32):
47
- // / %1 = arith.maximumf %current, %fval : f32
48
- // / memref.atomic_yield %1 : f32
49
- // / }
50
- struct AtomicRMWOpConverter : public OpRewritePattern <memref::AtomicRMWOp> {
51
- public:
52
- using OpRewritePattern::OpRewritePattern;
53
-
54
- LogicalResult matchAndRewrite (memref::AtomicRMWOp op,
55
- PatternRewriter &rewriter) const final {
56
- auto loc = op.getLoc ();
57
- auto genericOp = rewriter.create <memref::GenericAtomicRMWOp>(
58
- loc, op.getMemref (), op.getIndices ());
59
- OpBuilder bodyBuilder =
60
- OpBuilder::atBlockEnd (genericOp.getBody (), rewriter.getListener ());
61
-
62
- Value lhs = genericOp.getCurrentValue ();
63
- Value rhs = op.getValue ();
64
-
65
- Value arithOp =
66
- mlir::arith::getReductionOp (op.getKind (), bodyBuilder, loc, lhs, rhs);
67
- bodyBuilder.create <memref::AtomicYieldOp>(loc, arithOp);
68
-
69
- rewriter.replaceOp (op, genericOp.getResult ());
70
- return success ();
71
- }
72
- };
73
-
74
30
// / Converts `memref.reshape` that has a target shape of a statically-known
75
31
// / size to `memref.reinterpret_cast`.
76
32
struct MemRefReshapeOpConverter : public OpRewritePattern <memref::ReshapeOp> {
@@ -139,13 +95,6 @@ struct ExpandOpsPass : public memref::impl::ExpandOpsPassBase<ExpandOpsPass> {
139
95
ConversionTarget target (ctx);
140
96
141
97
target.addLegalDialect <arith::ArithDialect, memref::MemRefDialect>();
142
- target.addDynamicallyLegalOp <memref::AtomicRMWOp>(
143
- [](memref::AtomicRMWOp op) {
144
- constexpr std::array shouldBeExpandedKinds = {
145
- arith::AtomicRMWKind::maximumf, arith::AtomicRMWKind::minimumf,
146
- arith::AtomicRMWKind::minnumf, arith::AtomicRMWKind::maxnumf};
147
- return !llvm::is_contained (shouldBeExpandedKinds, op.getKind ());
148
- });
149
98
target.addDynamicallyLegalOp <memref::ReshapeOp>([](memref::ReshapeOp op) {
150
99
return !cast<MemRefType>(op.getShape ().getType ()).hasStaticShape ();
151
100
});
@@ -158,6 +107,5 @@ struct ExpandOpsPass : public memref::impl::ExpandOpsPassBase<ExpandOpsPass> {
158
107
} // namespace
159
108
160
109
void mlir::memref::populateExpandOpsPatterns (RewritePatternSet &patterns) {
161
- patterns.add <AtomicRMWOpConverter, MemRefReshapeOpConverter>(
162
- patterns.getContext ());
110
+ patterns.add <MemRefReshapeOpConverter>(patterns.getContext ());
163
111
}
0 commit comments