Skip to content

Commit f3676c3

Browse files
committed
[mlir][memref] memref.reinterpret_cast folding
* reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x) * reinterpret_cast(cast(x)) -> reinterpret_cast(x) * reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets are 0 Differential Revision: https://reviews.llvm.org/D120242
1 parent dbc32e2 commit f3676c3

File tree

5 files changed

+77
-0
lines changed

5 files changed

+77
-0
lines changed

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1090,6 +1090,8 @@ def MemRef_ReinterpretCastOp
10901090
/// and `strides` operands.
10911091
static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; }
10921092
}];
1093+
1094+
let hasFolder = 1;
10931095
}
10941096

10951097
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Utils/StaticValueUtils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@ SmallVector<OpFoldResult> getAsOpFoldResult(ArrayRef<Value> values);
5353
/// If ofr is a constant integer or an IntegerAttr, return the integer.
5454
Optional<int64_t> getConstantIntValue(OpFoldResult ofr);
5555

56+
/// Return true if `ofr` is constant integer equal to `value`.
57+
bool isConstantIntValue(OpFoldResult ofr, int64_t value);
58+
5659
/// Return true if ofr1 and ofr2 are the same integer constant attribute values
5760
/// or the same SSA value.
5861
/// Ignore integer bitwitdh and type mismatch that come from the fact there is

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1508,6 +1508,36 @@ LogicalResult ReinterpretCastOp::verify() {
15081508
return success();
15091509
}
15101510

1511+
OpFoldResult ReinterpretCastOp::fold(ArrayRef<Attribute> /*operands*/) {
1512+
Value src = source();
1513+
auto getPrevSrc = [&]() -> Value {
1514+
// reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
1515+
if (auto prev = src.getDefiningOp<ReinterpretCastOp>())
1516+
return prev.source();
1517+
1518+
// reinterpret_cast(cast(x)) -> reinterpret_cast(x).
1519+
if (auto prev = src.getDefiningOp<CastOp>())
1520+
return prev.source();
1521+
1522+
// reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
1523+
// are 0.
1524+
if (auto prev = src.getDefiningOp<SubViewOp>())
1525+
if (llvm::all_of(prev.getMixedOffsets(), [](OpFoldResult val) {
1526+
return isConstantIntValue(val, 0);
1527+
}))
1528+
return prev.source();
1529+
1530+
return nullptr;
1531+
};
1532+
1533+
if (auto prevSrc = getPrevSrc()) {
1534+
sourceMutable().assign(prevSrc);
1535+
return getResult();
1536+
}
1537+
1538+
return nullptr;
1539+
}
1540+
15111541
//===----------------------------------------------------------------------===//
15121542
// Reassociative reshape ops
15131543
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Utils/StaticValueUtils.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,12 @@ Optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
8181
return llvm::None;
8282
}
8383

84+
/// Return true if `ofr` is constant integer equal to `value`.
85+
bool isConstantIntValue(OpFoldResult ofr, int64_t value) {
86+
auto val = getConstantIntValue(ofr);
87+
return val && *val == value;
88+
}
89+
8490
/// Return true if ofr1 and ofr2 are the same integer constant attribute values
8591
/// or the same SSA value.
8692
/// Ignore integer bitwidth and type mismatch that come from the fact there is

mlir/test/Dialect/MemRef/canonicalize.mlir

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -657,3 +657,39 @@ func @scopeInline(%arg : memref<index>) {
657657

658658
// CHECK: func @scopeInline
659659
// CHECK-NOT: memref.alloca_scope
660+
661+
// -----
662+
663+
// CHECK-LABEL: func @reinterpret_of_reinterpret
664+
// CHECK-SAME: (%[[ARG:.*]]: memref<?xi8>, %[[SIZE1:.*]]: index, %[[SIZE2:.*]]: index)
665+
// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [%[[SIZE2]]], strides: [1]
666+
// CHECK: return %[[RES]]
667+
func @reinterpret_of_reinterpret(%arg : memref<?xi8>, %size1: index, %size2: index) -> memref<?xi8> {
668+
%0 = memref.reinterpret_cast %arg to offset: [0], sizes: [%size1], strides: [1] : memref<?xi8> to memref<?xi8>
669+
%1 = memref.reinterpret_cast %0 to offset: [0], sizes: [%size2], strides: [1] : memref<?xi8> to memref<?xi8>
670+
return %1 : memref<?xi8>
671+
}
672+
673+
// -----
674+
675+
// CHECK-LABEL: func @reinterpret_of_cast
676+
// CHECK-SAME: (%[[ARG:.*]]: memref<?xi8>, %[[SIZE:.*]]: index)
677+
// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [%[[SIZE]]], strides: [1]
678+
// CHECK: return %[[RES]]
679+
func @reinterpret_of_cast(%arg : memref<?xi8>, %size: index) -> memref<?xi8> {
680+
%0 = memref.cast %arg : memref<?xi8> to memref<5xi8>
681+
%1 = memref.reinterpret_cast %0 to offset: [0], sizes: [%size], strides: [1] : memref<5xi8> to memref<?xi8>
682+
return %1 : memref<?xi8>
683+
}
684+
685+
// -----
686+
687+
// CHECK-LABEL: func @reinterpret_of_subview
688+
// CHECK-SAME: (%[[ARG:.*]]: memref<?xi8>, %[[SIZE1:.*]]: index, %[[SIZE2:.*]]: index)
689+
// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [%[[SIZE2]]], strides: [1]
690+
// CHECK: return %[[RES]]
691+
func @reinterpret_of_subview(%arg : memref<?xi8>, %size1: index, %size2: index) -> memref<?xi8> {
692+
%0 = memref.subview %arg[0] [%size1] [1] : memref<?xi8> to memref<?xi8>
693+
%1 = memref.reinterpret_cast %0 to offset: [0], sizes: [%size2], strides: [1] : memref<?xi8> to memref<?xi8>
694+
return %1 : memref<?xi8>
695+
}

0 commit comments

Comments
 (0)