|
17 | 17 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
18 | 18 | #include "mlir/Dialect/StandardOps/IR/Ops.h"
|
19 | 19 | #include "mlir/IR/Attributes.h"
|
| 20 | +#include "mlir/IR/BlockAndValueMapping.h" |
20 | 21 | #include "mlir/IR/Builders.h"
|
21 | 22 | #include "mlir/IR/MLIRContext.h"
|
22 | 23 | #include "mlir/IR/Module.h"
|
@@ -2746,6 +2747,104 @@ struct AtomicCmpXchgOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
|
2746 | 2747 | }
|
2747 | 2748 | };
|
2748 | 2749 |
|
| 2750 | +/// Wrap an llvm.cmpxchg operation in a while loop so that the operation can be |
| 2751 | +/// retried until it succeeds in atomically storing a new value into memory. |
| 2752 | +/// |
| 2753 | +/// +---------------------------------+ |
| 2754 | +/// | <code before the AtomicRMWOp> | |
| 2755 | +/// | <compute initial %loaded> | |
| 2756 | +/// | br loop(%loaded) | |
| 2757 | +/// +---------------------------------+ |
| 2758 | +/// | |
| 2759 | +/// -------| | |
| 2760 | +/// | v v |
| 2761 | +/// | +--------------------------------+ |
| 2762 | +/// | | loop(%loaded): | |
| 2763 | +/// | | <body contents> | |
| 2764 | +/// | | %pair = cmpxchg | |
| 2765 | +/// | | %new = %pair[0] | |
| 2766 | +/// | | %ok = %pair[1] | |
| 2767 | +/// | | cond_br %ok, end, loop(%new) | |
| 2768 | +/// | +--------------------------------+ |
| 2769 | +/// | | | |
| 2770 | +/// |----------- | |
| 2771 | +/// v |
| 2772 | +/// +--------------------------------+ |
| 2773 | +/// | end: | |
| 2774 | +/// | <code after the AtomicRMWOp> | |
| 2775 | +/// +--------------------------------+ |
| 2776 | +/// |
| 2777 | +struct GenericAtomicRMWOpLowering |
| 2778 | + : public LoadStoreOpLowering<GenericAtomicRMWOp> { |
| 2779 | + using Base::Base; |
| 2780 | + |
| 2781 | + LogicalResult |
| 2782 | + matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
| 2783 | + ConversionPatternRewriter &rewriter) const override { |
| 2784 | + auto atomicOp = cast<GenericAtomicRMWOp>(op); |
| 2785 | + |
| 2786 | + auto loc = op->getLoc(); |
| 2787 | + OperandAdaptor<GenericAtomicRMWOp> adaptor(operands); |
| 2788 | + LLVM::LLVMType valueType = |
| 2789 | + typeConverter.convertType(atomicOp.getResult().getType()) |
| 2790 | + .cast<LLVM::LLVMType>(); |
| 2791 | + |
| 2792 | + // Split the block into initial, loop, and ending parts. |
| 2793 | + auto *initBlock = rewriter.getInsertionBlock(); |
| 2794 | + auto initPosition = rewriter.getInsertionPoint(); |
| 2795 | + auto *loopBlock = rewriter.splitBlock(initBlock, initPosition); |
| 2796 | + auto loopArgument = loopBlock->addArgument(valueType); |
| 2797 | + auto loopPosition = rewriter.getInsertionPoint(); |
| 2798 | + auto *endBlock = rewriter.splitBlock(loopBlock, loopPosition); |
| 2799 | + |
| 2800 | + // Compute the loaded value and branch to the loop block. |
| 2801 | + rewriter.setInsertionPointToEnd(initBlock); |
| 2802 | + auto memRefType = atomicOp.memref().getType().cast<MemRefType>(); |
| 2803 | + auto dataPtr = getDataPtr(loc, memRefType, adaptor.memref(), |
| 2804 | + adaptor.indices(), rewriter, getModule()); |
| 2805 | + Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr); |
| 2806 | + rewriter.create<LLVM::BrOp>(loc, init, loopBlock); |
| 2807 | + |
| 2808 | + // Prepare the body of the loop block. |
| 2809 | + rewriter.setInsertionPointToStart(loopBlock); |
| 2810 | + auto boolType = LLVM::LLVMType::getInt1Ty(&getDialect()); |
| 2811 | + |
| 2812 | + // Clone the GenericAtomicRMWOp region and extract the result. |
| 2813 | + BlockAndValueMapping mapping; |
| 2814 | + mapping.map(atomicOp.getCurrentValue(), loopArgument); |
| 2815 | + Block &entryBlock = atomicOp.body().front(); |
| 2816 | + for (auto &nestedOp : entryBlock.without_terminator()) { |
| 2817 | + Operation *clone = rewriter.clone(nestedOp, mapping); |
| 2818 | + mapping.map(nestedOp.getResults(), clone->getResults()); |
| 2819 | + } |
| 2820 | + Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0)); |
| 2821 | + |
| 2822 | + // Prepare the epilog of the loop block. |
| 2823 | + rewriter.setInsertionPointToEnd(loopBlock); |
| 2824 | + // Append the cmpxchg op to the end of the loop block. |
| 2825 | + auto successOrdering = LLVM::AtomicOrdering::acq_rel; |
| 2826 | + auto failureOrdering = LLVM::AtomicOrdering::monotonic; |
| 2827 | + auto pairType = LLVM::LLVMType::getStructTy(valueType, boolType); |
| 2828 | + auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>( |
| 2829 | + loc, pairType, dataPtr, loopArgument, result, successOrdering, |
| 2830 | + failureOrdering); |
| 2831 | + // Extract the %new_loaded (index 0) and %ok (index 1) values from the pair. |
| 2832 | + Value newLoaded = rewriter.create<LLVM::ExtractValueOp>( |
| 2833 | + loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0})); |
| 2834 | + Value ok = rewriter.create<LLVM::ExtractValueOp>( |
| 2835 | + loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1})); |
| 2836 | + |
| 2837 | + // Conditionally branch to the end or back to the loop depending on %ok. |
| 2838 | + rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(), |
| 2839 | + loopBlock, newLoaded); |
| 2840 | + |
| 2841 | + // The 'result' of the GenericAtomicRMWOp is the newly loaded value. |
| 2842 | + rewriter.replaceOp(op, {newLoaded}); |
| 2843 | + |
| 2844 | + return success(); |
| 2845 | + } |
| 2846 | +}; |
| 2847 | + |
2749 | 2848 | } // namespace
|
2750 | 2849 |
|
2751 | 2850 | /// Collect a set of patterns to convert from the Standard dialect to LLVM.
|
@@ -2775,6 +2874,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
|
2775 | 2874 | DivFOpLowering,
|
2776 | 2875 | ExpOpLowering,
|
2777 | 2876 | Exp2OpLowering,
|
| 2877 | + GenericAtomicRMWOpLowering, |
2778 | 2878 | LogOpLowering,
|
2779 | 2879 | Log10OpLowering,
|
2780 | 2880 | Log2OpLowering,
|
|
0 commit comments