Skip to content

[mlir] Remove convertible identity restriction for memref.atomic_rmw … #72262

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/MathExtras.h"
Expand Down Expand Up @@ -1562,12 +1563,14 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
LogicalResult
matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(match(atomicOp)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually dont understand this check..... This looks OK to me, but what was stopped by this change?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks okay to me. I don't know who should be involved in the review even I looked at git blame. It looks like very very few people update the pattern and op...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually dont understand this check..... This looks OK to me, but what was stopped by this change?

It fails when the memref type has an offset because it checks that the memref type has an identity map. But the pattern can handle offsets, so it should be fine to remove the check

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like this only works for strided memref types. Can you check that getStridesAndOffset on the source type returns success, i.e. return failure when that method fails. I looked into the implementation of getStridedElementPtr and it asserts if thats not the case, so letting that through would lead to an assert

return failure();
auto maybeKind = matchSimpleAtomicOp(atomicOp);
if (!maybeKind)
return failure();
auto memRefType = atomicOp.getMemRefType();
SmallVector<int64_t> strides;
int64_t offset;
if (failed(getStridesAndOffset(memRefType, strides, offset)))
return failure();
auto dataPtr =
getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
adaptor.getIndices(), rewriter);
Expand Down
18 changes: 18 additions & 0 deletions mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,24 @@ func.func @atomic_rmw(%I : memref<10xi32>, %ival : i32, %F : memref<10xf32>, %fv

// -----

func.func @atomic_rmw_with_offset(%I : memref<10xi32, strided<[1], offset: 5>>, %ival : i32, %i : index) {
memref.atomic_rmw andi %ival, %I[%i] : (i32, memref<10xi32, strided<[1], offset: 5>>) -> i32
return
}
// CHECK-LABEL: func @atomic_rmw_with_offset
// CHECK-SAME: %[[ARG0:.+]]: memref<10xi32, strided<[1], offset: 5>>
// CHECK-SAME: %[[ARG1:.+]]: i32
// CHECK-SAME: %[[ARG2:.+]]: index
// CHECK: %[[MEMREF_STRUCT:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<10xi32, strided<[1], offset: 5>> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[INDEX:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i64
// CHECK: %[[BASE_PTR:.+]] = llvm.extractvalue %[[MEMREF_STRUCT]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[OFFSET:.+]] = llvm.mlir.constant(5 : index) : i64
// CHECK: %[[OFFSET_PTR:.+]] = llvm.getelementptr %[[BASE_PTR]][%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32
// CHECK: %[[PTR:.+]] = llvm.getelementptr %[[OFFSET_PTR]][%[[INDEX]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32
// CHECK: llvm.atomicrmw _and %[[PTR]], %[[ARG1]] acq_rel

// -----

// CHECK-LABEL: func @generic_atomic_rmw
func.func @generic_atomic_rmw(%I : memref<10xi32>, %i : index) {
%x = memref.generic_atomic_rmw %I[%i] : memref<10xi32> {
Expand Down