Skip to content

Commit 21caba5

Browse files
committed
[MLIR] Lower GenericAtomicRMWOp to llvm.cmpxchg.
Summary: Lowering is pretty much a copy of AtomicRMWOp -> llvm.cmpxchg pattern. Differential Revision: https://reviews.llvm.org/D78647
1 parent 47ef09e commit 21caba5

File tree

3 files changed

+127
-0
lines changed

3 files changed

+127
-0
lines changed

mlir/include/mlir/Dialect/StandardOps/IR/Ops.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -539,6 +539,9 @@ def GenericAtomicRMWOp : Std_Op<"generic_atomic_rmw", [
539539
// Returns argument 0 of the first block of the op's body, i.e. the block
// argument through which the body reads the current value.
Value getCurrentValue() {
540540
return body().front().getArgument(0);
541541
}
542+
// Returns the type of `memref()` cast to MemRefType.
MemRefType getMemRefType() {
543+
return memref().getType().cast<MemRefType>();
544+
}
542545
}];
543546
}
544547

mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1818
#include "mlir/Dialect/StandardOps/IR/Ops.h"
1919
#include "mlir/IR/Attributes.h"
20+
#include "mlir/IR/BlockAndValueMapping.h"
2021
#include "mlir/IR/Builders.h"
2122
#include "mlir/IR/MLIRContext.h"
2223
#include "mlir/IR/Module.h"
@@ -2746,6 +2747,104 @@ struct AtomicCmpXchgOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
27462747
}
27472748
};
27482749

2750+
/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
2751+
/// retried until it succeeds in atomically storing a new value into memory.
2752+
///
2753+
/// +---------------------------------+
2754+
/// | <code before the GenericAtomicRMWOp> |
2755+
/// | <compute initial %loaded> |
2756+
/// | br loop(%loaded) |
2757+
/// +---------------------------------+
2758+
/// |
2759+
/// -------| |
2760+
/// | v v
2761+
/// | +--------------------------------+
2762+
/// | | loop(%loaded): |
2763+
/// | | <body contents> |
2764+
/// | | %pair = cmpxchg |
2765+
/// | | %new = %pair[0] |
2766+
/// | | %ok = %pair[1] |
2767+
/// | | cond_br %ok, end, loop(%new) |
2768+
/// | +--------------------------------+
2769+
/// | | |
2770+
/// |----------- |
2771+
/// v
2772+
/// +--------------------------------+
2773+
/// | end: |
2774+
/// | <code after the GenericAtomicRMWOp> |
2775+
/// +--------------------------------+
2776+
///
2777+
/// Conversion pattern for GenericAtomicRMWOp: emits an initial llvm.load, then
/// a loop block that re-executes a clone of the op's body followed by an
/// llvm.cmpxchg, branching back to the loop until the exchange succeeds.
struct GenericAtomicRMWOpLowering
2778+
: public LoadStoreOpLowering<GenericAtomicRMWOp> {
2779+
using Base::Base;
2780+
2781+
/// Rewrites `op` (cast to GenericAtomicRMWOp) into the init/loop/end block
/// structure. `operands` is wrapped in an OperandAdaptor to read the memref
/// and indices. The only return in this function is success().
LogicalResult
2782+
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
2783+
ConversionPatternRewriter &rewriter) const override {
2784+
auto atomicOp = cast<GenericAtomicRMWOp>(op);
2785+
2786+
auto loc = op->getLoc();
2787+
OperandAdaptor<GenericAtomicRMWOp> adaptor(operands);
2788+
// LLVM type converted from the op's result type; used for the loop block
// argument and for the value half of the cmpxchg result pair.
LLVM::LLVMType valueType =
2789+
typeConverter.convertType(atomicOp.getResult().getType())
2790+
.cast<LLVM::LLVMType>();
2791+
2792+
// Split the block into initial, loop, and ending parts.
// NOTE(review): the second split uses the rewriter's insertion point as it
// stands after the first splitBlock call -- confirm which operations end up
// in loopBlock versus endBlock.
2793+
auto *initBlock = rewriter.getInsertionBlock();
2794+
auto initPosition = rewriter.getInsertionPoint();
2795+
auto *loopBlock = rewriter.splitBlock(initBlock, initPosition);
2796+
// The loop block argument carries the most recently loaded value.
auto loopArgument = loopBlock->addArgument(valueType);
2797+
auto loopPosition = rewriter.getInsertionPoint();
2798+
auto *endBlock = rewriter.splitBlock(loopBlock, loopPosition);
2799+
2800+
// Compute the loaded value and branch to the loop block.
2801+
rewriter.setInsertionPointToEnd(initBlock);
2802+
// Same cast as performed by the op's getMemRefType() accessor.
auto memRefType = atomicOp.memref().getType().cast<MemRefType>();
2803+
auto dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
2804+
adaptor.indices(), rewriter, getModule());
2805+
// The initial load seeds the loop block argument on first entry.
Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
2806+
rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
2807+
2808+
// Prepare the body of the loop block.
2809+
rewriter.setInsertionPointToStart(loopBlock);
2810+
// i1 type of the cmpxchg success flag.
auto boolType = LLVM::LLVMType::getInt1Ty(&getDialect());
2811+
2812+
// Clone the GenericAtomicRMWOp region and extract the result.
// The body's block argument is mapped to the loop argument so that cloned
// ops read the value carried by the loop.
2813+
BlockAndValueMapping mapping;
2814+
mapping.map(atomicOp.getCurrentValue(), loopArgument);
2815+
Block &entryBlock = atomicOp.body().front();
2816+
for (auto &nestedOp : entryBlock.without_terminator()) {
2817+
Operation *clone = rewriter.clone(nestedOp, mapping);
2818+
// Record the clone's results so later ops and the terminator operand
// resolve to the cloned values.
mapping.map(nestedOp.getResults(), clone->getResults());
2819+
}
2820+
// Operand 0 of the body's terminator is the value to be stored.
Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));
2821+
2822+
// Prepare the epilog of the loop block.
2823+
rewriter.setInsertionPointToEnd(loopBlock);
2824+
// Append the cmpxchg op to the end of the loop block.
2825+
auto successOrdering = LLVM::AtomicOrdering::acq_rel;
2826+
auto failureOrdering = LLVM::AtomicOrdering::monotonic;
2827+
// Result pair layout: element 0 has valueType, element 1 is the i1 flag.
auto pairType = LLVM::LLVMType::getStructTy(valueType, boolType);
2828+
// Expected value is the loop argument; replacement is the body's result.
auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
2829+
loc, pairType, dataPtr, loopArgument, result, successOrdering,
2830+
failureOrdering);
2831+
// Extract the %new_loaded (index 0) and %ok (index 1) values from the pair.
2832+
Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
2833+
loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
2834+
Value ok = rewriter.create<LLVM::ExtractValueOp>(
2835+
loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));
2836+
2837+
// Conditionally branch to the end or back to the loop depending on %ok.
// On the back edge, the newly loaded value becomes the loop argument.
2838+
rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
2839+
loopBlock, newLoaded);
2840+
2841+
// The result of the generic_atomic_rmw op is replaced by the newly loaded
// value extracted from the cmpxchg pair.
2842+
rewriter.replaceOp(op, {newLoaded});
2843+
2844+
return success();
2845+
}
2846+
};
2847+
27492848
} // namespace
27502849

27512850
/// Collect a set of patterns to convert from the Standard dialect to LLVM.
@@ -2775,6 +2874,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
27752874
DivFOpLowering,
27762875
ExpOpLowering,
27772876
Exp2OpLowering,
2877+
GenericAtomicRMWOpLowering,
27782878
LogOpLowering,
27792879
Log10OpLowering,
27802880
Log2OpLowering,

mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1029,6 +1029,30 @@ func @cmpxchg(%F : memref<10xf32>, %fval : f32, %i : index) -> f32 {
10291029

10301030
// -----
10311031

1032+
// Lowering test for generic_atomic_rmw: expects an initial llvm.load, a branch
// into a loop block that holds the cloned body and an llvm.cmpxchg, and a
// conditional branch that re-enters the loop with the newly extracted value.
// CHECK-LABEL: func @generic_atomic_rmw
1033+
// CHECK32-LABEL: func @generic_atomic_rmw
1034+
func @generic_atomic_rmw(%I : memref<10xf32>, %i : index) -> f32 {
1035+
%x = generic_atomic_rmw %I[%i] : memref<10xf32> {
1036+
^bb0(%old_value : f32):
1037+
// The body does not use %old_value; it always yields the constant 1.0.
%c1 = constant 1.0 : f32
1038+
atomic_yield %c1 : f32
1039+
}
1040+
// CHECK: [[init:%.*]] = llvm.load %{{.*}} : !llvm<"float*">
1041+
// CHECK-NEXT: llvm.br ^bb1([[init]] : !llvm.float)
1042+
// CHECK-NEXT: ^bb1([[loaded:%.*]]: !llvm.float):
1043+
// CHECK-NEXT: [[c1:%.*]] = llvm.mlir.constant(1.000000e+00 : f32)
1044+
// CHECK-NEXT: [[pair:%.*]] = llvm.cmpxchg %{{.*}}, [[loaded]], [[c1]]
1045+
// CHECK-SAME: acq_rel monotonic : !llvm.float
1046+
// CHECK-NEXT: [[new:%.*]] = llvm.extractvalue [[pair]][0]
1047+
// CHECK-NEXT: [[ok:%.*]] = llvm.extractvalue [[pair]][1]
1048+
// CHECK-NEXT: llvm.cond_br [[ok]], ^bb2, ^bb1([[new]] : !llvm.float)
1049+
// CHECK-NEXT: ^bb2:
1050+
return %x : f32
1051+
// CHECK-NEXT: llvm.return [[new]]
1052+
}
1053+
1054+
// -----
1055+
10321056
// CHECK-LABEL: func @assume_alignment
10331057
func @assume_alignment(%0 : memref<4x4xf16>) {
10341058
// CHECK: %[[PTR:.*]] = llvm.extractvalue %[[MEMREF:.*]][1] : !llvm<"{ half*, half*, i64, [2 x i64], [2 x i64] }">

0 commit comments

Comments
 (0)