
Commit ce6ef99

[mlir] Remove convertible identity restriction for memref.atomic_rmw to LLVM (#72262)
memref.atomic_rmw fails to convert to LLVM for memref types that have an offset, because such types do not have identity layout maps. That restriction is overly conservative, so this change relaxes it to require only that the memref type be strided.
1 parent 6c2bde9 commit ce6ef99
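
For illustration (not part of the patch), here is a minimal sketch of the kind of IR this unlocks, mirroring the test added below; the function and value names are made up, and the layout strided<[1], offset: 5> is chosen only because it has a non-zero offset and therefore no identity map:

func.func @offset_example(%buf : memref<10xi32, strided<[1], offset: 5>>,
                          %val : i32, %idx : index) -> i32 {
  // The layout is not the identity map, but it is strided, so the
  // lowering can still compute the element address (base + offset + index).
  %res = memref.atomic_rmw addi %val, %buf[%idx]
             : (i32, memref<10xi32, strided<[1], offset: 5>>) -> i32
  return %res : i32
}

Before this change, the identity-map check rejected such types even though getStridedElementPtr can address them.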

2 files changed (+23, -2 lines)
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp

Lines changed: 5 additions & 2 deletions
@@ -22,6 +22,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/MathExtras.h"
@@ -1562,12 +1563,14 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
   LogicalResult
   matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (failed(match(atomicOp)))
-      return failure();
     auto maybeKind = matchSimpleAtomicOp(atomicOp);
     if (!maybeKind)
       return failure();
     auto memRefType = atomicOp.getMemRefType();
+    SmallVector<int64_t> strides;
+    int64_t offset;
+    if (failed(getStridesAndOffset(memRefType, strides, offset)))
+      return failure();
     auto dataPtr =
         getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
                              adaptor.getIndices(), rewriter);

mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir

Lines changed: 18 additions & 0 deletions
@@ -400,6 +400,24 @@ func.func @atomic_rmw(%I : memref<10xi32>, %ival : i32, %F : memref<10xf32>, %fv
 
 // -----
 
+func.func @atomic_rmw_with_offset(%I : memref<10xi32, strided<[1], offset: 5>>, %ival : i32, %i : index) {
+  memref.atomic_rmw andi %ival, %I[%i] : (i32, memref<10xi32, strided<[1], offset: 5>>) -> i32
+  return
+}
+// CHECK-LABEL: func @atomic_rmw_with_offset
+// CHECK-SAME: %[[ARG0:.+]]: memref<10xi32, strided<[1], offset: 5>>
+// CHECK-SAME: %[[ARG1:.+]]: i32
+// CHECK-SAME: %[[ARG2:.+]]: index
+// CHECK: %[[MEMREF_STRUCT:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<10xi32, strided<[1], offset: 5>> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK: %[[INDEX:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i64
+// CHECK: %[[BASE_PTR:.+]] = llvm.extractvalue %[[MEMREF_STRUCT]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK: %[[OFFSET:.+]] = llvm.mlir.constant(5 : index) : i64
+// CHECK: %[[OFFSET_PTR:.+]] = llvm.getelementptr %[[BASE_PTR]][%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32
+// CHECK: %[[PTR:.+]] = llvm.getelementptr %[[OFFSET_PTR]][%[[INDEX]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32
+// CHECK: llvm.atomicrmw _and %[[PTR]], %[[ARG1]] acq_rel
+
+// -----
+
 // CHECK-LABEL: func @generic_atomic_rmw
 func.func @generic_atomic_rmw(%I : memref<10xi32>, %i : index) {
   %x = memref.generic_atomic_rmw %I[%i] : memref<10xi32> {
