Skip to content

Commit 435bb50

Browse files
committed
Revert "[mlir][MemRef] Changed AssumeAlignment into a Pure ViewLikeOp (llvm#139521)"
This reverts commit ffb9bbf.
1 parent 20a9a98 commit 435bb50

File tree

11 files changed

+41
-87
lines changed

11 files changed

+41
-87
lines changed

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 8 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -142,37 +142,22 @@ class AllocLikeOp<string mnemonic,
142142
// AssumeAlignmentOp
143143
//===----------------------------------------------------------------------===//
144144

145-
def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
146-
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
147-
Pure,
148-
ViewLikeOpInterface,
149-
SameOperandsAndResultType
150-
]> {
145+
def AssumeAlignmentOp : MemRef_Op<"assume_alignment"> {
151146
let summary =
152-
"assumption that gives alignment information to the input memref";
147+
"assertion that gives alignment information to the input memref";
153148
let description = [{
154-
The `assume_alignment` operation takes a memref and an integer alignment
155-
value. It returns a new SSA value of the same memref type, but associated
156-
with the assumption that the underlying buffer is aligned to the given
157-
alignment.
149+
The `assume_alignment` operation takes a memref and an integer of alignment
150+
value, and internally annotates the buffer with the given alignment. If
151+
the buffer isn't aligned to the given alignment, the behavior is undefined.
158152

159-
If the buffer isn't aligned to the given alignment, its result is poison.
160-
This operation doesn't affect the semantics of a program where the
161-
alignment assumption holds true. It is intended for optimization purposes,
162-
allowing the compiler to generate more efficient code based on the
163-
alignment assumption. The optimization is best-effort.
153+
This operation doesn't affect the semantics of a correct program. It's for
154+
optimization only, and the optimization is best-effort.
164155
}];
165156
let arguments = (ins AnyMemRef:$memref,
166157
ConfinedAttr<I32Attr, [IntPositive]>:$alignment);
167-
let results = (outs AnyMemRef:$result);
158+
let results = (outs);
168159

169160
let assemblyFormat = "$memref `,` $alignment attr-dict `:` type($memref)";
170-
let extraClassDeclaration = [{
171-
MemRefType getType() { return ::llvm::cast<MemRefType>(getResult().getType()); }
172-
173-
Value getViewSource() { return getMemref(); }
174-
}];
175-
176161
let hasVerifier = 1;
177162
}
178163

mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,8 @@ struct AssumeAlignmentOpLowering
435435
createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
436436
rewriter.create<LLVM::AssumeOp>(loc, trueCond, LLVM::AssumeAlignTag(), ptr,
437437
alignmentConst);
438-
rewriter.replaceOp(op, memref);
438+
439+
rewriter.eraseOp(op);
439440
return success();
440441
}
441442
};

mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,13 @@ using namespace mlir::gpu;
4444
// The functions below provide interface-like verification, but are too specific
4545
// to barrier elimination to become interfaces.
4646

47+
/// Implement the MemoryEffectsOpInterface in the suitable way.
48+
static bool isKnownNoEffectsOpWithoutInterface(Operation *op) {
49+
// memref::AssumeAlignment is conceptually pure, but marking it as such would
50+
// make DCE immediately remove it.
51+
return isa<memref::AssumeAlignmentOp>(op);
52+
}
53+
4754
/// Returns `true` if the op is defines the parallel region that is subject to
4855
/// barrier synchronization.
4956
static bool isParallelRegionBoundary(Operation *op) {
@@ -94,6 +101,10 @@ collectEffects(Operation *op,
94101
if (ignoreBarriers && isa<BarrierOp>(op))
95102
return true;
96103

104+
// Skip over ops that we know have no effects.
105+
if (isKnownNoEffectsOpWithoutInterface(op))
106+
return true;
107+
97108
// Collect effect instances the operation. Note that the implementation of
98109
// getEffects erases all effect instances that have the type other than the
99110
// template parameter so we collect them first in a local buffer and then

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -527,11 +527,6 @@ LogicalResult AssumeAlignmentOp::verify() {
527527
return success();
528528
}
529529

530-
void AssumeAlignmentOp::getAsmResultNames(
531-
function_ref<void(Value, StringRef)> setNameFn) {
532-
setNameFn(getResult(), "assume_align");
533-
}
534-
535530
//===----------------------------------------------------------------------===//
536531
// CastOp
537532
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ struct ConvertMemRefAssumeAlignment final
229229
}
230230

231231
rewriter.replaceOpWithNewOp<memref::AssumeAlignmentOp>(
232-
op, newTy, adaptor.getMemref(), adaptor.getAlignmentAttr());
232+
op, adaptor.getMemref(), adaptor.getAlignmentAttr());
233233
return success();
234234
}
235235
};

mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -919,35 +919,6 @@ struct ExtractStridedMetadataOpGetGlobalFolder
919919
}
920920
};
921921

922-
/// Pattern to replace `extract_strided_metadata(assume_alignment)`
923-
///
924-
/// With
925-
/// \verbatim
926-
/// extract_strided_metadata(memref)
927-
/// \endverbatim
928-
///
929-
/// Since `assume_alignment` is a view-like op that does not modify the
930-
/// underlying buffer, offset, sizes, or strides, extracting strided metadata
931-
/// from its result is equivalent to extracting it from its source. This
932-
/// canonicalization removes the unnecessary indirection.
933-
struct ExtractStridedMetadataOpAssumeAlignmentFolder
934-
: public OpRewritePattern<memref::ExtractStridedMetadataOp> {
935-
public:
936-
using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;
937-
938-
LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
939-
PatternRewriter &rewriter) const override {
940-
auto assumeAlignmentOp =
941-
op.getSource().getDefiningOp<memref::AssumeAlignmentOp>();
942-
if (!assumeAlignmentOp)
943-
return failure();
944-
945-
rewriter.replaceOpWithNewOp<memref::ExtractStridedMetadataOp>(
946-
op, assumeAlignmentOp.getViewSource());
947-
return success();
948-
}
949-
};
950-
951922
/// Rewrite memref.extract_aligned_pointer_as_index of a ViewLikeOp to the
952923
/// source of the ViewLikeOp.
953924
class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
@@ -1214,7 +1185,6 @@ void memref::populateExpandStridedMetadataPatterns(
12141185
ExtractStridedMetadataOpSubviewFolder,
12151186
ExtractStridedMetadataOpCastFolder,
12161187
ExtractStridedMetadataOpMemorySpaceCastFolder,
1217-
ExtractStridedMetadataOpAssumeAlignmentFolder,
12181188
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
12191189
patterns.getContext());
12201190
}
@@ -1231,7 +1201,6 @@ void memref::populateResolveExtractStridedMetadataPatterns(
12311201
ExtractStridedMetadataOpReinterpretCastFolder,
12321202
ExtractStridedMetadataOpCastFolder,
12331203
ExtractStridedMetadataOpMemorySpaceCastFolder,
1234-
ExtractStridedMetadataOpAssumeAlignmentFolder,
12351204
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
12361205
patterns.getContext());
12371206
}

mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -683,7 +683,7 @@ func.func @load_and_assume(
683683
%arg0: memref<?x?xf32, strided<[?, ?], offset: ?>>,
684684
%i0: index, %i1: index)
685685
-> f32 {
686-
%arg0_align = memref.assume_alignment %arg0, 16 : memref<?x?xf32, strided<[?, ?], offset: ?>>
687-
%2 = memref.load %arg0_align[%i0, %i1] : memref<?x?xf32, strided<[?, ?], offset: ?>>
686+
memref.assume_alignment %arg0, 16 : memref<?x?xf32, strided<[?, ?], offset: ?>>
687+
%2 = memref.load %arg0[%i0, %i1] : memref<?x?xf32, strided<[?, ?], offset: ?>>
688688
func.return %2 : f32
689689
}

mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/misc-other.mlir

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,4 @@ func.func @func_with_assert(%arg0: index, %arg1: index) {
1010
%0 = arith.cmpi slt, %arg0, %arg1 : index
1111
cf.assert %0, "%arg0 must be less than %arg1"
1212
return
13-
}
14-
15-
// CHECK-LABEL: func @func_with_assume_alignment(
16-
// CHECK: %0 = memref.assume_alignment %arg0, 64 : memref<128xi8>
17-
func.func @func_with_assume_alignment(%arg0: memref<128xi8>) {
18-
%0 = memref.assume_alignment %arg0, 64 : memref<128xi8>
19-
return
2013
}

mlir/test/Dialect/MemRef/emulate-narrow-type.mlir

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,8 @@ func.func @memref_load_i4(%arg0: index) -> i4 {
6363

6464
func.func @memref_load_i4_rank2(%arg0: index, %arg1: index) -> i4 {
6565
%0 = memref.alloc() : memref<3x125xi4>
66-
%align0 =memref.assume_alignment %0, 64 : memref<3x125xi4>
67-
%1 = memref.load %align0[%arg0,%arg1] : memref<3x125xi4>
66+
memref.assume_alignment %0, 64 : memref<3x125xi4>
67+
%1 = memref.load %0[%arg0,%arg1] : memref<3x125xi4>
6868
return %1 : i4
6969
}
7070
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 2)>
@@ -73,9 +73,9 @@ func.func @memref_load_i4_rank2(%arg0: index, %arg1: index) -> i4 {
7373
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
7474
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
7575
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<188xi8>
76-
// CHECK: %[[ASSUME:.+]] = memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
76+
// CHECK: memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
7777
// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
78-
// CHECK: %[[LOAD:.+]] = memref.load %[[ASSUME]][%[[INDEX]]]
78+
// CHECK: %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
7979
// CHECK: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
8080
// CHECK: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
8181
// CHECK: %[[SHIFTRT:.+]] = arith.shrsi %[[LOAD]], %[[CAST]]
@@ -88,9 +88,9 @@ func.func @memref_load_i4_rank2(%arg0: index, %arg1: index) -> i4 {
8888
// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
8989
// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
9090
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<47xi32>
91-
// CHECK32: %[[ASSUME:.+]] = memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
91+
// CHECK32: memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
9292
// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
93-
// CHECK32: %[[LOAD:.+]] = memref.load %[[ASSUME]][%[[INDEX]]]
93+
// CHECK32: %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
9494
// CHECK32: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
9595
// CHECK32: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
9696
// CHECK32: %[[SHIFTRT:.+]] = arith.shrsi %[[LOAD]], %[[CAST]]
@@ -350,16 +350,16 @@ func.func @memref_store_i4(%arg0: index, %arg1: i4) -> () {
350350

351351
func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
352352
%0 = memref.alloc() : memref<3x125xi4>
353-
%align0 = memref.assume_alignment %0, 64 : memref<3x125xi4>
354-
memref.store %arg2, %align0[%arg0,%arg1] : memref<3x125xi4>
353+
memref.assume_alignment %0, 64 : memref<3x125xi4>
354+
memref.store %arg2, %0[%arg0,%arg1] : memref<3x125xi4>
355355
return
356356
}
357357
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 2)>
358358
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 500 + s1 * 4 - ((s0 * 125 + s1) floordiv 2) * 8)>
359359
// CHECK: func @memref_store_i4_rank2(
360360
// CHECK-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i4
361361
// CHECK-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<188xi8>
362-
// CHECK-DAG: %[[ASSUME:.+]] = memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
362+
// CHECK-DAG: memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
363363
// CHECK-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG2]] : i4 to i8
364364
// CHECK-DAG: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
365365
// CHECK-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
@@ -369,16 +369,16 @@ func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
369369
// CHECK-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
370370
// CHECK-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
371371
// CHECK-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
372-
// CHECK: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ASSUME]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
373-
// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ASSUME]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
372+
// CHECK: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
373+
// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
374374
// CHECK: return
375375

376376
// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 8)>
377377
// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 500 + s1 * 4 - ((s0 * 125 + s1) floordiv 8) * 32)>
378378
// CHECK32: func @memref_store_i4_rank2(
379379
// CHECK32-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i4
380380
// CHECK32-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<47xi32>
381-
// CHECK32-DAG: %[[ASSUME:.+]] = memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
381+
// CHECK32-DAG: memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
382382
// CHECK32-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG2]] : i4 to i32
383383
// CHECK32-DAG: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
384384
// CHECK32-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
@@ -388,8 +388,8 @@ func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
388388
// CHECK32-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i32
389389
// CHECK32-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i32
390390
// CHECK32-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I32]] : i32
391-
// CHECK32: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ASSUME]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
392-
// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ASSUME]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
391+
// CHECK32: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
392+
// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
393393
// CHECK32: return
394394

395395
// -----

mlir/test/Dialect/MemRef/invalid.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -878,7 +878,7 @@ func.func @invalid_memref_cast() {
878878
// alignment is not power of 2.
879879
func.func @assume_alignment(%0: memref<4x4xf16>) {
880880
// expected-error@+1 {{alignment must be power of 2}}
881-
%1 = memref.assume_alignment %0, 12 : memref<4x4xf16>
881+
memref.assume_alignment %0, 12 : memref<4x4xf16>
882882
return
883883
}
884884

@@ -887,7 +887,7 @@ func.func @assume_alignment(%0: memref<4x4xf16>) {
887887
// 0 alignment value.
888888
func.func @assume_alignment(%0: memref<4x4xf16>) {
889889
// expected-error@+1 {{attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}}
890-
%1 = memref.assume_alignment %0, 0 : memref<4x4xf16>
890+
memref.assume_alignment %0, 0 : memref<4x4xf16>
891891
return
892892
}
893893

mlir/test/Dialect/MemRef/ops.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ func.func @memref_view(%arg0 : index, %arg1 : index, %arg2 : index) {
284284
// CHECK-SAME: %[[MEMREF:.*]]: memref<4x4xf16>
285285
func.func @assume_alignment(%0: memref<4x4xf16>) {
286286
// CHECK: memref.assume_alignment %[[MEMREF]], 16 : memref<4x4xf16>
287-
%1 = memref.assume_alignment %0, 16 : memref<4x4xf16>
287+
memref.assume_alignment %0, 16 : memref<4x4xf16>
288288
return
289289
}
290290

0 commit comments

Comments
 (0)