Skip to content

Commit 8a469da

Browse files
authored
[mlir] remove unnecessary atomic_rmw expansions (#144515)
The expansion of `memref.atomic_rmw` into a `memref.generic_atomic_rmw` for floating-point min/max operations is no longer necessary as those are now supported by the LLVM dialect and LLVM IR. Furthermore, combining this expansion with direct lowering of `generic_atomic_rmw` could lead to invalid LLVM dialect IR with `cmpxchg` operating on floating-point values that it does not support.
1 parent 66d6964 commit 8a469da

File tree

2 files changed

+5
-89
lines changed

2 files changed

+5
-89
lines changed

mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp

Lines changed: 2 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,10 @@
1-
//===- StdExpandDivs.cpp - Code to prepare Std for lowering Divs to LLVM -===//
1+
//===- ExpandOps.cpp - Expansion patterns for MemRef operations -----------===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
8-
//
9-
// This file Std transformations to expand Divs operation to help for the
10-
// lowering to LLVM. Currently implemented transformations are Ceil and Floor
11-
// for Signed Integers.
12-
//
13-
//===----------------------------------------------------------------------===//
148

159
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
1610

@@ -33,44 +27,6 @@ using namespace mlir;
3327

3428
namespace {
3529

36-
/// Converts `atomic_rmw` that cannot be lowered to a simple atomic op with
37-
/// AtomicRMWOpLowering pattern, such as minimum and maximum operations for
38-
/// floating-point numbers, to `memref.generic_atomic_rmw` with the expanded
39-
/// code.
40-
///
41-
/// %x = atomic_rmw maximumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
42-
///
43-
/// will be lowered to
44-
///
45-
/// %x = memref.generic_atomic_rmw %F[%i] : memref<10xf32> {
46-
/// ^bb0(%current: f32):
47-
/// %1 = arith.maximumf %current, %fval : f32
48-
/// memref.atomic_yield %1 : f32
49-
/// }
50-
struct AtomicRMWOpConverter : public OpRewritePattern<memref::AtomicRMWOp> {
51-
public:
52-
using OpRewritePattern::OpRewritePattern;
53-
54-
LogicalResult matchAndRewrite(memref::AtomicRMWOp op,
55-
PatternRewriter &rewriter) const final {
56-
auto loc = op.getLoc();
57-
auto genericOp = rewriter.create<memref::GenericAtomicRMWOp>(
58-
loc, op.getMemref(), op.getIndices());
59-
OpBuilder bodyBuilder =
60-
OpBuilder::atBlockEnd(genericOp.getBody(), rewriter.getListener());
61-
62-
Value lhs = genericOp.getCurrentValue();
63-
Value rhs = op.getValue();
64-
65-
Value arithOp =
66-
mlir::arith::getReductionOp(op.getKind(), bodyBuilder, loc, lhs, rhs);
67-
bodyBuilder.create<memref::AtomicYieldOp>(loc, arithOp);
68-
69-
rewriter.replaceOp(op, genericOp.getResult());
70-
return success();
71-
}
72-
};
73-
7430
/// Converts `memref.reshape` that has a target shape of a statically-known
7531
/// size to `memref.reinterpret_cast`.
7632
struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {
@@ -139,13 +95,6 @@ struct ExpandOpsPass : public memref::impl::ExpandOpsPassBase<ExpandOpsPass> {
13995
ConversionTarget target(ctx);
14096

14197
target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect>();
142-
target.addDynamicallyLegalOp<memref::AtomicRMWOp>(
143-
[](memref::AtomicRMWOp op) {
144-
constexpr std::array shouldBeExpandedKinds = {
145-
arith::AtomicRMWKind::maximumf, arith::AtomicRMWKind::minimumf,
146-
arith::AtomicRMWKind::minnumf, arith::AtomicRMWKind::maxnumf};
147-
return !llvm::is_contained(shouldBeExpandedKinds, op.getKind());
148-
});
14998
target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
15099
return !cast<MemRefType>(op.getShape().getType()).hasStaticShape();
151100
});
@@ -158,6 +107,5 @@ struct ExpandOpsPass : public memref::impl::ExpandOpsPassBase<ExpandOpsPass> {
158107
} // namespace
159108

160109
void mlir::memref::populateExpandOpsPatterns(RewritePatternSet &patterns) {
161-
patterns.add<AtomicRMWOpConverter, MemRefReshapeOpConverter>(
162-
patterns.getContext());
110+
patterns.add<MemRefReshapeOpConverter>(patterns.getContext());
163111
}

mlir/test/Dialect/MemRef/expand-ops.mlir

Lines changed: 3 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,10 @@
11
// RUN: mlir-opt -memref-expand %s -split-input-file | FileCheck %s
22

3-
// CHECK-LABEL: func @atomic_rmw_to_generic
4-
// CHECK-SAME: ([[F:%.*]]: memref<10xf32>, [[f:%.*]]: f32, [[i:%.*]]: index)
5-
func.func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
6-
%a = memref.atomic_rmw maximumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
7-
%b = memref.atomic_rmw minimumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
8-
%c = memref.atomic_rmw maxnumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
9-
%d = memref.atomic_rmw minnumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
10-
return %a : f32
11-
}
12-
// CHECK: [[RESULT:%.*]] = memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
13-
// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
14-
// CHECK: [[MAXIMUM:%.*]] = arith.maximumf [[CUR_VAL]], [[f]] : f32
15-
// CHECK: memref.atomic_yield [[MAXIMUM]] : f32
16-
// CHECK: }
17-
// CHECK: memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
18-
// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
19-
// CHECK: [[MINIMUM:%.*]] = arith.minimumf [[CUR_VAL]], [[f]] : f32
20-
// CHECK: memref.atomic_yield [[MINIMUM]] : f32
21-
// CHECK: }
22-
// CHECK: memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
23-
// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
24-
// CHECK: [[MAXNUM:%.*]] = arith.maxnumf [[CUR_VAL]], [[f]] : f32
25-
// CHECK: memref.atomic_yield [[MAXNUM]] : f32
26-
// CHECK: }
27-
// CHECK: memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
28-
// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
29-
// CHECK: [[MINNUM:%.*]] = arith.minnumf [[CUR_VAL]], [[f]] : f32
30-
// CHECK: memref.atomic_yield [[MINNUM]] : f32
31-
// CHECK: }
32-
// CHECK: return [[RESULT]] : f32
33-
34-
// -----
35-
363
// CHECK-LABEL: func @atomic_rmw_no_conversion
37-
func.func @atomic_rmw_no_conversion(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
4+
func.func @atomic_rmw_no_conversion(%F: memref<10xf32>, %f: f32, %i: index) -> (f32, f32) {
385
%x = memref.atomic_rmw addf %f, %F[%i] : (f32, memref<10xf32>) -> f32
39-
return %x : f32
6+
%y = memref.atomic_rmw maximumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
7+
return %x, %y : f32, f32
408
}
419
// CHECK-NOT: generic_atomic_rmw
4210

0 commit comments

Comments
 (0)