Skip to content

Commit 2759ead

Browse files
[SYCL-MLIR] Fix memref2pointer(subindex) optimization (#7408)
Optimization `Memref2PointerIndex` is only valid when the element type of the source of subindex is not a struct type. After this PR, `stream-copy.cpp` and `stream-triad.cpp` can both run successfully. They are added in intel/llvm-test-suite#1384. Signed-off-by: Tsang, Whitney <[email protected]>
1 parent 29755dc commit 2759ead

File tree

2 files changed

+22
-3
lines changed

2 files changed

+22
-3
lines changed

polygeist/lib/polygeist/Ops.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1159,7 +1159,7 @@ class Memref2Pointer2MemrefCast final
11591159
return success();
11601160
}
11611161
};
1162-
/// Simplify pointer2memref(memref2pointer(x)) to cast(x)
1162+
/// Simplify memref2pointer(subindex(x)) to getelementptr(memref2pointer(x))
11631163
class Memref2PointerIndex final : public OpRewritePattern<Memref2PointerOp> {
11641164
public:
11651165
using OpRewritePattern<Memref2PointerOp>::OpRewritePattern;
@@ -1173,9 +1173,12 @@ class Memref2PointerIndex final : public OpRewritePattern<Memref2PointerOp> {
11731173
if (src.getSource().getType().cast<MemRefType>().getShape().size() != 1)
11741174
return failure();
11751175

1176+
auto MET = src.getSource().getType().cast<MemRefType>().getElementType();
1177+
if (MET.isa<LLVM::LLVMStructType>())
1178+
return failure();
1179+
11761180
Value idx[] = {src.getIndex()};
11771181
auto PET = op.getType().cast<LLVM::LLVMPointerType>().getElementType();
1178-
auto MET = src.getSource().getType().cast<MemRefType>().getElementType();
11791182
if (PET != MET) {
11801183
auto ps = rewriter.create<polygeist::TypeSizeOp>(
11811184
op.getLoc(), rewriter.getIndexType(), mlir::TypeAttr::get(PET));

polygeist/test/polygeist-opt/invalid_canonicalization.mlir

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ func.func @SubToCast(%arg0: memref<?x!llvm.struct<(i32)>>) -> memref<?xi32> {
3535
// CHECK-NEXT: [[C1:%.*]] = arith.constant 0 : i32
3636
// CHECK-NEXT: [[T0:%.*]] = "polygeist.subindex"([[A0]], [[C0]]) : (memref<?x!llvm.struct<(i32)>>, index) -> memref<?xi32>
3737
// CHECK-NEXT: memref.store [[C1]], [[T0]][[[C0]]] : memref<?xi32>
38-
// CHECK-NEXT: return %0 : memref<?xi32>
38+
// CHECK-NEXT: return [[T0]] : memref<?xi32>
3939
// CHECK-NEXT: }
4040
func.func @SimplifySubIndexUsers(%arg0: memref<?x!llvm.struct<(i32)>>) -> memref<?xi32> {
4141
%c0 = arith.constant 0 : index
@@ -46,3 +46,19 @@ func.func @SimplifySubIndexUsers(%arg0: memref<?x!llvm.struct<(i32)>>) -> memref
4646
}
4747

4848
// -----
49+
50+
// CHECK: func.func @Memref2PointerIndex([[A0:%.*]]: memref<?x!llvm.struct<(i32, i32)>>) -> !llvm.ptr<i32> {
51+
// CHECK-NEXT: [[C1:%.*]] = arith.constant 1 : index
52+
// CHECK-NEXT: [[T0:%.*]] = "polygeist.subindex"([[A0]], [[C1]]) : (memref<?x!llvm.struct<(i32, i32)>>, index) -> memref<?xi32>
53+
// CHECK-NEXT: [[T1:%.*]] = "polygeist.memref2pointer"([[T0]]) : (memref<?xi32>) -> !llvm.ptr<i32>
54+
// CHECK-NEXT: return [[T1]] : !llvm.ptr<i32>
55+
// CHECK-NEXT: }
56+
57+
func.func @Memref2PointerIndex(%arg0: memref<?x!llvm.struct<(i32, i32)>>) -> !llvm.ptr<i32> {
58+
%c1 = arith.constant 1 : index
59+
%0 = "polygeist.subindex"(%arg0, %c1) : (memref<?x!llvm.struct<(i32, i32)>>, index) -> memref<?xi32>
60+
%1 = "polygeist.memref2pointer"(%0) : (memref<?xi32>) -> !llvm.ptr<i32>
61+
return %1 : !llvm.ptr<i32>
62+
}
63+
64+
// -----

0 commit comments

Comments
 (0)