Skip to content

Commit e903f6d

Browse files
author
Shay Kleiman
committed
Changed AssumeAlignment into a ViewLikeOp
Made AssumeAlignment a ViewLikeOp that returns a new SSA memref equal to its memref argument, and gave it the NoMemoryEffect trait. This gives the op a defined memory effect that matches what it does in practice and makes it behave well with optimizations, which will only remove it when its result is unused.
1 parent bddbbe9 commit e903f6d

File tree

11 files changed

+73
-45
lines changed

11 files changed

+73
-45
lines changed

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -142,22 +142,34 @@ class AllocLikeOp<string mnemonic,
142142
// AssumeAlignmentOp
143143
//===----------------------------------------------------------------------===//
144144

145-
def AssumeAlignmentOp : MemRef_Op<"assume_alignment",[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
145+
def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
146+
NoMemoryEffect,
147+
ViewLikeOpInterface,
148+
SameOperandsAndResultType
149+
]> {
146150
let summary =
147151
"assertion that gives alignment information to the input memref";
148152
let description = [{
149-
The `assume_alignment` operation takes a memref and an integer of alignment
150-
value, and internally annotates the buffer with the given alignment. If
151-
the buffer isn't aligned to the given alignment, the behavior is undefined.
153+
The `assume_alignment` operation takes a memref and an integer of alignment
154+
value. It returns a new SSA value of the same memref type, but associated
155+
with the assertion that the underlying buffer is aligned to the given
156+
alignment. If the buffer isn't aligned to the given alignment, the
157+
behavior is undefined.
152158

153-
This operation doesn't affect the semantics of a correct program. It's for
154-
optimization only, and the optimization is best-effort.
159+
This operation doesn't affect the semantics of a correct program. It's for
160+
optimization only, and the optimization is best-effort.
155161
}];
156162
let arguments = (ins AnyMemRef:$memref,
157163
ConfinedAttr<I32Attr, [IntPositive]>:$alignment);
158-
let results = (outs);
164+
let results = (outs AnyMemRef:$result);
159165

160166
let assemblyFormat = "$memref `,` $alignment attr-dict `:` type($memref)";
167+
let extraClassDeclaration = [{
168+
MemRefType getType() { return ::llvm::cast<MemRefType>(getResult().getType()); }
169+
170+
Value getViewSource() { return getMemref(); }
171+
}];
172+
161173
let hasVerifier = 1;
162174
}
163175

mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -432,8 +432,7 @@ struct AssumeAlignmentOpLowering
432432
createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
433433
rewriter.create<LLVM::AssumeOp>(loc, trueCond, LLVM::AssumeAlignTag(), ptr,
434434
alignmentConst);
435-
436-
rewriter.eraseOp(op);
435+
rewriter.replaceOp(op, memref);
437436
return success();
438437
}
439438
};

mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,7 @@ using namespace mlir::gpu;
4444
// The functions below provide interface-like verification, but are too specific
4545
// to barrier elimination to become interfaces.
4646

47-
/// Implement the MemoryEffectsOpInterface in the suitable way.
48-
static bool isKnownNoEffectsOpWithoutInterface(Operation *op) {
49-
// memref::AssumeAlignment is conceptually pure, but marking it as such would
50-
// make DCE immediately remove it.
51-
return isa<memref::AssumeAlignmentOp>(op);
52-
}
47+
5348

5449
/// Returns `true` if the op defines the parallel region that is subject to
5550
/// barrier synchronization.
@@ -101,9 +96,6 @@ collectEffects(Operation *op,
10196
if (ignoreBarriers && isa<BarrierOp>(op))
10297
return true;
10398

104-
// Skip over ops that we know have no effects.
105-
if (isKnownNoEffectsOpWithoutInterface(op))
106-
return true;
10799

108100
// Collect the effect instances of the operation. Note that the implementation of
109101
// getEffects erases all effect instances that have the type other than the

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -526,11 +526,6 @@ LogicalResult AssumeAlignmentOp::verify() {
526526
return emitOpError("alignment must be power of 2");
527527
return success();
528528
}
529-
void AssumeAlignmentOp::getEffects(
530-
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
531-
&effects) {
532-
effects.emplace_back(MemoryEffects::Write::get());
533-
}
534529

535530
//===----------------------------------------------------------------------===//
536531
// CastOp

mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ struct ConvertMemRefAssumeAlignment final
229229
}
230230

231231
rewriter.replaceOpWithNewOp<memref::AssumeAlignmentOp>(
232-
op, adaptor.getMemref(), adaptor.getAlignmentAttr());
232+
op, newTy, adaptor.getMemref(), adaptor.getAlignmentAttr());
233233
return success();
234234
}
235235
};

mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -919,6 +919,34 @@ struct ExtractStridedMetadataOpGetGlobalFolder
919919
}
920920
};
921921

922+
/// Pattern to replace `extract_strided_metadata(assume_alignment)`
923+
///
924+
/// With
925+
/// \verbatim
926+
/// extract_strided_metadata(memref)
927+
/// \endverbatim
928+
///
929+
/// Since `assume_alignment` is a view-like op that does not modify the
930+
/// underlying buffer, offset, sizes, or strides, extracting strided metadata
931+
/// from its result is equivalent to extracting it from its source. This
932+
/// canonicalization removes the unnecessary indirection.
933+
struct ExtractStridedMetadataOpAssumeAlignmentFolder
934+
: public OpRewritePattern<memref::ExtractStridedMetadataOp> {
935+
public:
936+
using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;
937+
938+
LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
939+
PatternRewriter &rewriter) const override {
940+
auto assumeAlignmentOp =
941+
op.getSource().getDefiningOp<memref::AssumeAlignmentOp>();
942+
if (!assumeAlignmentOp)
943+
return failure();
944+
945+
rewriter.replaceOpWithNewOp<memref::ExtractStridedMetadataOp>(op, assumeAlignmentOp.getViewSource());
946+
return success();
947+
}
948+
};
949+
922950
/// Rewrite memref.extract_aligned_pointer_as_index of a ViewLikeOp to the
923951
/// source of the ViewLikeOp.
924952
class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
@@ -1185,6 +1213,7 @@ void memref::populateExpandStridedMetadataPatterns(
11851213
ExtractStridedMetadataOpSubviewFolder,
11861214
ExtractStridedMetadataOpCastFolder,
11871215
ExtractStridedMetadataOpMemorySpaceCastFolder,
1216+
ExtractStridedMetadataOpAssumeAlignmentFolder,
11881217
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
11891218
patterns.getContext());
11901219
}
@@ -1201,6 +1230,7 @@ void memref::populateResolveExtractStridedMetadataPatterns(
12011230
ExtractStridedMetadataOpReinterpretCastFolder,
12021231
ExtractStridedMetadataOpCastFolder,
12031232
ExtractStridedMetadataOpMemorySpaceCastFolder,
1233+
ExtractStridedMetadataOpAssumeAlignmentFolder,
12041234
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
12051235
patterns.getContext());
12061236
}

mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -683,7 +683,7 @@ func.func @load_and_assume(
683683
%arg0: memref<?x?xf32, strided<[?, ?], offset: ?>>,
684684
%i0: index, %i1: index)
685685
-> f32 {
686-
memref.assume_alignment %arg0, 16 : memref<?x?xf32, strided<[?, ?], offset: ?>>
687-
%2 = memref.load %arg0[%i0, %i1] : memref<?x?xf32, strided<[?, ?], offset: ?>>
686+
%arg0_align = memref.assume_alignment %arg0, 16 : memref<?x?xf32, strided<[?, ?], offset: ?>>
687+
%2 = memref.load %arg0_align[%i0, %i1] : memref<?x?xf32, strided<[?, ?], offset: ?>>
688688
func.return %2 : f32
689689
}

mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/misc-other.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ func.func @func_with_assert(%arg0: index, %arg1: index) {
1313
}
1414

1515
// CHECK-LABEL: func @func_with_assume_alignment(
16-
// CHECK: memref.assume_alignment %arg0, 64 : memref<128xi8>
16+
// CHECK: %0 = memref.assume_alignment %arg0, 64 : memref<128xi8>
1717
func.func @func_with_assume_alignment(%arg0: memref<128xi8>) {
18-
memref.assume_alignment %arg0, 64 : memref<128xi8>
18+
%0 = memref.assume_alignment %arg0, 64 : memref<128xi8>
1919
return
2020
}

mlir/test/Dialect/MemRef/emulate-narrow-type.mlir

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,8 @@ func.func @memref_load_i4(%arg0: index) -> i4 {
6363

6464
func.func @memref_load_i4_rank2(%arg0: index, %arg1: index) -> i4 {
6565
%0 = memref.alloc() : memref<3x125xi4>
66-
memref.assume_alignment %0, 64 : memref<3x125xi4>
67-
%1 = memref.load %0[%arg0,%arg1] : memref<3x125xi4>
66+
%align0 =memref.assume_alignment %0, 64 : memref<3x125xi4>
67+
%1 = memref.load %align0[%arg0,%arg1] : memref<3x125xi4>
6868
return %1 : i4
6969
}
7070
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 2)>
@@ -73,9 +73,9 @@ func.func @memref_load_i4_rank2(%arg0: index, %arg1: index) -> i4 {
7373
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
7474
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
7575
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<188xi8>
76-
// CHECK: memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
76+
// CHECK: %[[ASSUME:.+]] = memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
7777
// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
78-
// CHECK: %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
78+
// CHECK: %[[LOAD:.+]] = memref.load %[[ASSUME]][%[[INDEX]]]
7979
// CHECK: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
8080
// CHECK: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
8181
// CHECK: %[[SHIFTRT:.+]] = arith.shrsi %[[LOAD]], %[[CAST]]
@@ -88,9 +88,9 @@ func.func @memref_load_i4_rank2(%arg0: index, %arg1: index) -> i4 {
8888
// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
8989
// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
9090
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<47xi32>
91-
// CHECK32: memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
91+
// CHECK32: %[[ASSUME:.+]] = memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
9292
// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
93-
// CHECK32: %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
93+
// CHECK32: %[[LOAD:.+]] = memref.load %[[ASSUME]][%[[INDEX]]]
9494
// CHECK32: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
9595
// CHECK32: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
9696
// CHECK32: %[[SHIFTRT:.+]] = arith.shrsi %[[LOAD]], %[[CAST]]
@@ -350,16 +350,16 @@ func.func @memref_store_i4(%arg0: index, %arg1: i4) -> () {
350350

351351
func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
352352
%0 = memref.alloc() : memref<3x125xi4>
353-
memref.assume_alignment %0, 64 : memref<3x125xi4>
354-
memref.store %arg2, %0[%arg0,%arg1] : memref<3x125xi4>
353+
%align0 = memref.assume_alignment %0, 64 : memref<3x125xi4>
354+
memref.store %arg2, %align0[%arg0,%arg1] : memref<3x125xi4>
355355
return
356356
}
357357
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 2)>
358358
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 500 + s1 * 4 - ((s0 * 125 + s1) floordiv 2) * 8)>
359359
// CHECK: func @memref_store_i4_rank2(
360360
// CHECK-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i4
361361
// CHECK-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<188xi8>
362-
// CHECK-DAG: memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
362+
// CHECK-DAG: %[[ASSUME:.+]] = memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
363363
// CHECK-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG2]] : i4 to i8
364364
// CHECK-DAG: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
365365
// CHECK-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
@@ -369,16 +369,16 @@ func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
369369
// CHECK-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
370370
// CHECK-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
371371
// CHECK-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
372-
// CHECK: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
373-
// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
372+
// CHECK: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ASSUME]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
373+
// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ASSUME]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
374374
// CHECK: return
375375

376376
// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 8)>
377377
// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 500 + s1 * 4 - ((s0 * 125 + s1) floordiv 8) * 32)>
378378
// CHECK32: func @memref_store_i4_rank2(
379379
// CHECK32-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i4
380380
// CHECK32-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<47xi32>
381-
// CHECK32-DAG: memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
381+
// CHECK32-DAG: %[[ASSUME:.+]] = memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
382382
// CHECK32-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG2]] : i4 to i32
383383
// CHECK32-DAG: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
384384
// CHECK32-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
@@ -388,8 +388,8 @@ func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
388388
// CHECK32-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i32
389389
// CHECK32-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i32
390390
// CHECK32-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I32]] : i32
391-
// CHECK32: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
392-
// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
391+
// CHECK32: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ASSUME]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
392+
// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ASSUME]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
393393
// CHECK32: return
394394

395395
// -----

mlir/test/Dialect/MemRef/invalid.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -878,7 +878,7 @@ func.func @invalid_memref_cast() {
878878
// alignment is not power of 2.
879879
func.func @assume_alignment(%0: memref<4x4xf16>) {
880880
// expected-error@+1 {{alignment must be power of 2}}
881-
memref.assume_alignment %0, 12 : memref<4x4xf16>
881+
%1 = memref.assume_alignment %0, 12 : memref<4x4xf16>
882882
return
883883
}
884884

@@ -887,7 +887,7 @@ func.func @assume_alignment(%0: memref<4x4xf16>) {
887887
// 0 alignment value.
888888
func.func @assume_alignment(%0: memref<4x4xf16>) {
889889
// expected-error@+1 {{attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}}
890-
memref.assume_alignment %0, 0 : memref<4x4xf16>
890+
%1 = memref.assume_alignment %0, 0 : memref<4x4xf16>
891891
return
892892
}
893893

mlir/test/Dialect/MemRef/ops.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ func.func @memref_view(%arg0 : index, %arg1 : index, %arg2 : index) {
284284
// CHECK-SAME: %[[MEMREF:.*]]: memref<4x4xf16>
285285
func.func @assume_alignment(%0: memref<4x4xf16>) {
286286
// CHECK: memref.assume_alignment %[[MEMREF]], 16 : memref<4x4xf16>
287-
memref.assume_alignment %0, 16 : memref<4x4xf16>
287+
%1 = memref.assume_alignment %0, 16 : memref<4x4xf16>
288288
return
289289
}
290290

0 commit comments

Comments
 (0)