Skip to content

Commit 58ea538

Browse files
authored
[mlir][memref] Add a folder for chained AssumeAlignmentOp ops. (#142425)
The chained ops can be folded away when they have the same alignment. Signed-off-by: hanhanW <[email protected]>
1 parent f393986 commit 58ea538

File tree

3 files changed

+23
-0
lines changed

3 files changed

+23
-0
lines changed

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
174174
}];
175175

176176
let hasVerifier = 1;
177+
let hasFolder = 1;
177178
}
178179

179180
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,15 @@ void AssumeAlignmentOp::getAsmResultNames(
533533
setNameFn(getResult(), "assume_align");
534534
}
535535

536+
OpFoldResult AssumeAlignmentOp::fold(FoldAdaptor adaptor) {
537+
auto source = getMemref().getDefiningOp<AssumeAlignmentOp>();
538+
if (!source)
539+
return {};
540+
if (source.getAlignment() != getAlignment())
541+
return {};
542+
return getMemref();
543+
}
544+
536545
//===----------------------------------------------------------------------===//
537546
// CastOp
538547
//===----------------------------------------------------------------------===//

mlir/test/Dialect/MemRef/canonicalize.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1177,3 +1177,16 @@ func.func @cannot_fold_transpose_cast(%arg0: memref<?x4xf32>) -> memref<?x?xf32,
11771177
// CHECK: return %[[TRANSPOSE]]
11781178
return %transpose : memref<?x?xf32, #transpose_map>
11791179
}
1180+
1181+
// -----
1182+
1183+
// CHECK-LABEL: func @fold_assume_alignment_chain
1184+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
1185+
func.func @fold_assume_alignment_chain(%0: memref<128xf32>) -> memref<128xf32> {
1186+
// CHECK: %[[ALIGN:.+]] = memref.assume_alignment %[[ARG0]], 16
1187+
%1 = memref.assume_alignment %0, 16 : memref<128xf32>
1188+
// CHECK-NOT: memref.assume_alignment
1189+
%2 = memref.assume_alignment %1, 16 : memref<128xf32>
1190+
// CHECK: return %[[ALIGN]]
1191+
return %2 : memref<128xf32>
1192+
}

0 commit comments

Comments
 (0)