
[mlir][MemRef] Changed AssumeAlignment into a Pure ViewLikeOp #139521

Merged
31 changes: 23 additions & 8 deletions mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -142,22 +142,37 @@ class AllocLikeOp<string mnemonic,
// AssumeAlignmentOp
//===----------------------------------------------------------------------===//

def AssumeAlignmentOp : MemRef_Op<"assume_alignment"> {
def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
Pure,
ViewLikeOpInterface,
SameOperandsAndResultType
]> {
let summary =
"assertion that gives alignment information to the input memref";
"assumption that gives alignment information to the input memref";
let description = [{
The `assume_alignment` operation takes a memref and an integer of alignment
value, and internally annotates the buffer with the given alignment. If
the buffer isn't aligned to the given alignment, the behavior is undefined.
The `assume_alignment` operation takes a memref and an integer alignment
value. It returns a new SSA value of the same memref type, but associated
with the assumption that the underlying buffer is aligned to the given
alignment.

This operation doesn't affect the semantics of a correct program. It's for
optimization only, and the optimization is best-effort.
If the buffer isn't aligned to the given alignment, the result of this
operation is poison.
This operation doesn't affect the semantics of a program where the
alignment assumption holds true. It is intended for optimization purposes,
allowing the compiler to generate more efficient code based on the
alignment assumption. The optimization is best-effort.
}];
let arguments = (ins AnyMemRef:$memref,
ConfinedAttr<I32Attr, [IntPositive]>:$alignment);
let results = (outs);
let results = (outs AnyMemRef:$result);

let assemblyFormat = "$memref `,` $alignment attr-dict `:` type($memref)";
let extraClassDeclaration = [{
MemRefType getType() { return ::llvm::cast<MemRefType>(getResult().getType()); }

Value getViewSource() { return getMemref(); }
}];

let hasVerifier = 1;
}

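To make the interface change concrete, here is a minimal before/after usage sketch (the value names are illustrative and mirror the test updates further down):

```mlir
// Previous form: no result; the op only annotated the buffer.
memref.assume_alignment %buf, 64 : memref<128xf32>
%v0 = memref.load %buf[%i] : memref<128xf32>

// New form: the op returns a value of the same memref type, and users are
// expected to read through that result to pick up the alignment assumption.
%buf_aligned = memref.assume_alignment %buf, 64 : memref<128xf32>
%v1 = memref.load %buf_aligned[%i] : memref<128xf32>
```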
3 changes: 1 addition & 2 deletions mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -432,8 +432,7 @@ struct AssumeAlignmentOpLowering
createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
rewriter.create<LLVM::AssumeOp>(loc, trueCond, LLVM::AssumeAlignTag(), ptr,
alignmentConst);

rewriter.eraseOp(op);
rewriter.replaceOp(op, memref);
return success();
}
};
11 changes: 0 additions & 11 deletions mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp
@@ -44,13 +44,6 @@ using namespace mlir::gpu;
// The functions below provide interface-like verification, but are too specific
// to barrier elimination to become interfaces.

/// Implement the MemoryEffectsOpInterface in the suitable way.
static bool isKnownNoEffectsOpWithoutInterface(Operation *op) {
// memref::AssumeAlignment is conceptually pure, but marking it as such would
// make DCE immediately remove it.
return isa<memref::AssumeAlignmentOp>(op);
}

/// Returns `true` if the op defines the parallel region that is subject to
/// barrier synchronization.
static bool isParallelRegionBoundary(Operation *op) {
@@ -101,10 +94,6 @@ collectEffects(Operation *op,
if (ignoreBarriers && isa<BarrierOp>(op))
return true;

// Skip over ops that we know have no effects.
if (isKnownNoEffectsOpWithoutInterface(op))
return true;

// Collect the effect instances of the operation. Note that the implementation
// of getEffects erases all effect instances that have a type other than the
// template parameter so we collect them first in a local buffer and then
5 changes: 5 additions & 0 deletions mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -527,6 +527,11 @@ LogicalResult AssumeAlignmentOp::verify() {
return success();
}

void AssumeAlignmentOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "assume_align");
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//
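With `getAsmResultNames` in place, the printer suggests a readable default name for the new result. A small sketch of the printed form (assuming nothing else overrides the name):

```mlir
%assume_align = memref.assume_alignment %src, 16 : memref<4x4xf16>
```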
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -229,7 +229,7 @@ struct ConvertMemRefAssumeAlignment final
}

rewriter.replaceOpWithNewOp<memref::AssumeAlignmentOp>(
op, adaptor.getMemref(), adaptor.getAlignmentAttr());
op, newTy, adaptor.getMemref(), adaptor.getAlignmentAttr());
return success();
}
};
31 changes: 31 additions & 0 deletions mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -919,6 +919,35 @@ struct ExtractStridedMetadataOpGetGlobalFolder
}
};

/// Pattern to replace `extract_strided_metadata(assume_alignment)`
///
/// With
/// \verbatim
/// extract_strided_metadata(memref)
/// \endverbatim
///
/// Since `assume_alignment` is a view-like op that does not modify the
/// underlying buffer, offset, sizes, or strides, extracting strided metadata
/// from its result is equivalent to extracting it from its source. This
/// canonicalization removes the unnecessary indirection.
struct ExtractStridedMetadataOpAssumeAlignmentFolder
: public OpRewritePattern<memref::ExtractStridedMetadataOp> {
public:
using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;

LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
PatternRewriter &rewriter) const override {
auto assumeAlignmentOp =
op.getSource().getDefiningOp<memref::AssumeAlignmentOp>();
if (!assumeAlignmentOp)
return failure();

rewriter.replaceOpWithNewOp<memref::ExtractStridedMetadataOp>(
op, assumeAlignmentOp.getViewSource());
return success();
}
};

/// Rewrite memref.extract_aligned_pointer_as_index of a ViewLikeOp to the
/// source of the ViewLikeOp.
class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
@@ -1185,6 +1214,7 @@ void memref::populateExpandStridedMetadataPatterns(
ExtractStridedMetadataOpSubviewFolder,
ExtractStridedMetadataOpCastFolder,
ExtractStridedMetadataOpMemorySpaceCastFolder,
ExtractStridedMetadataOpAssumeAlignmentFolder,
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
patterns.getContext());
}
@@ -1201,6 +1231,7 @@ void memref::populateResolveExtractStridedMetadataPatterns(
ExtractStridedMetadataOpReinterpretCastFolder,
ExtractStridedMetadataOpCastFolder,
ExtractStridedMetadataOpMemorySpaceCastFolder,
ExtractStridedMetadataOpAssumeAlignmentFolder,
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
patterns.getContext());
}
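A short before/after sketch of what the new folder does (names and shapes are illustrative):

```mlir
// Before: the metadata is extracted from the assume_alignment result.
%aligned = memref.assume_alignment %src, 64 : memref<?xf32>
%base, %offset, %size, %stride = memref.extract_strided_metadata %aligned
    : memref<?xf32> -> memref<f32>, index, index, index

// After the rewrite: the metadata is extracted directly from the source,
// since assume_alignment changes neither base, offset, sizes, nor strides.
%base2, %offset2, %size2, %stride2 = memref.extract_strided_metadata %src
    : memref<?xf32> -> memref<f32>, index, index, index
```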
@@ -683,7 +683,7 @@ func.func @load_and_assume(
%arg0: memref<?x?xf32, strided<[?, ?], offset: ?>>,
%i0: index, %i1: index)
-> f32 {
memref.assume_alignment %arg0, 16 : memref<?x?xf32, strided<[?, ?], offset: ?>>
%2 = memref.load %arg0[%i0, %i1] : memref<?x?xf32, strided<[?, ?], offset: ?>>
%arg0_align = memref.assume_alignment %arg0, 16 : memref<?x?xf32, strided<[?, ?], offset: ?>>
%2 = memref.load %arg0_align[%i0, %i1] : memref<?x?xf32, strided<[?, ?], offset: ?>>
func.return %2 : f32
}
@@ -10,4 +10,11 @@ func.func @func_with_assert(%arg0: index, %arg1: index) {
%0 = arith.cmpi slt, %arg0, %arg1 : index
cf.assert %0, "%arg0 must be less than %arg1"
return
}

// CHECK-LABEL: func @func_with_assume_alignment(
// CHECK: %0 = memref.assume_alignment %arg0, 64 : memref<128xi8>
func.func @func_with_assume_alignment(%arg0: memref<128xi8>) {
%0 = memref.assume_alignment %arg0, 64 : memref<128xi8>
return
}
28 changes: 14 additions & 14 deletions mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -63,8 +63,8 @@ func.func @memref_load_i4(%arg0: index) -> i4 {

func.func @memref_load_i4_rank2(%arg0: index, %arg1: index) -> i4 {
%0 = memref.alloc() : memref<3x125xi4>
memref.assume_alignment %0, 64 : memref<3x125xi4>
%1 = memref.load %0[%arg0,%arg1] : memref<3x125xi4>
%align0 = memref.assume_alignment %0, 64 : memref<3x125xi4>
%1 = memref.load %align0[%arg0,%arg1] : memref<3x125xi4>
return %1 : i4
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 2)>
@@ -73,9 +73,9 @@ func.func @memref_load_i4_rank2(%arg0: index, %arg1: index) -> i4 {
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<188xi8>
// CHECK: memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
// CHECK: %[[ASSUME:.+]] = memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
// CHECK: %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
// CHECK: %[[LOAD:.+]] = memref.load %[[ASSUME]][%[[INDEX]]]
// CHECK: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
// CHECK: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
// CHECK: %[[SHIFTRT:.+]] = arith.shrsi %[[LOAD]], %[[CAST]]
@@ -88,9 +88,9 @@ func.func @memref_load_i4_rank2(%arg0: index, %arg1: index) -> i4 {
// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<47xi32>
// CHECK32: memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
// CHECK32: %[[ASSUME:.+]] = memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
// CHECK32: %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
// CHECK32: %[[LOAD:.+]] = memref.load %[[ASSUME]][%[[INDEX]]]
// CHECK32: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
// CHECK32: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
// CHECK32: %[[SHIFTRT:.+]] = arith.shrsi %[[LOAD]], %[[CAST]]
@@ -350,16 +350,16 @@ func.func @memref_store_i4(%arg0: index, %arg1: i4) -> () {

func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
%0 = memref.alloc() : memref<3x125xi4>
memref.assume_alignment %0, 64 : memref<3x125xi4>
memref.store %arg2, %0[%arg0,%arg1] : memref<3x125xi4>
%align0 = memref.assume_alignment %0, 64 : memref<3x125xi4>
memref.store %arg2, %align0[%arg0,%arg1] : memref<3x125xi4>
return
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 2)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 500 + s1 * 4 - ((s0 * 125 + s1) floordiv 2) * 8)>
// CHECK: func @memref_store_i4_rank2(
// CHECK-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i4
// CHECK-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<188xi8>
// CHECK-DAG: memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
// CHECK-DAG: %[[ASSUME:.+]] = memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
// CHECK-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG2]] : i4 to i8
// CHECK-DAG: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
// CHECK-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
@@ -369,16 +369,16 @@ func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
// CHECK-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
// CHECK-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
// CHECK-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
// CHECK: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
// CHECK: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ASSUME]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ASSUME]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
// CHECK: return

// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 8)>
// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 500 + s1 * 4 - ((s0 * 125 + s1) floordiv 8) * 32)>
// CHECK32: func @memref_store_i4_rank2(
// CHECK32-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i4
// CHECK32-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<47xi32>
// CHECK32-DAG: memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
// CHECK32-DAG: %[[ASSUME:.+]] = memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
// CHECK32-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG2]] : i4 to i32
// CHECK32-DAG: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
// CHECK32-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
@@ -388,8 +388,8 @@ func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
// CHECK32-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i32
// CHECK32-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i32
// CHECK32-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I32]] : i32
// CHECK32: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
// CHECK32: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ASSUME]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ASSUME]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
// CHECK32: return

// -----
4 changes: 2 additions & 2 deletions mlir/test/Dialect/MemRef/invalid.mlir
@@ -878,7 +878,7 @@ func.func @invalid_memref_cast() {
// alignment is not power of 2.
func.func @assume_alignment(%0: memref<4x4xf16>) {
// expected-error@+1 {{alignment must be power of 2}}
memref.assume_alignment %0, 12 : memref<4x4xf16>
%1 = memref.assume_alignment %0, 12 : memref<4x4xf16>
return
}

@@ -887,7 +887,7 @@ func.func @assume_alignment(%0: memref<4x4xf16>) {
// 0 alignment value.
func.func @assume_alignment(%0: memref<4x4xf16>) {
// expected-error@+1 {{attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}}
memref.assume_alignment %0, 0 : memref<4x4xf16>
%1 = memref.assume_alignment %0, 0 : memref<4x4xf16>
return
}

2 changes: 1 addition & 1 deletion mlir/test/Dialect/MemRef/ops.mlir
@@ -284,7 +284,7 @@ func.func @memref_view(%arg0 : index, %arg1 : index, %arg2 : index) {
// CHECK-SAME: %[[MEMREF:.*]]: memref<4x4xf16>
func.func @assume_alignment(%0: memref<4x4xf16>) {
// CHECK: memref.assume_alignment %[[MEMREF]], 16 : memref<4x4xf16>
memref.assume_alignment %0, 16 : memref<4x4xf16>
%1 = memref.assume_alignment %0, 16 : memref<4x4xf16>
return
}
