Skip to content

Commit c50573a

Browse files
[SYCL-MLIR] Fix memref2pointer(subindex) canonicalization for SYCL types (#8320)
SYCL types are lowered to LLVM struct types. `memref2pointer(subindex)` canonicalization is invalid when `subindex` element type is LLVM struct type, so it is also invalid for SYCL types. `KernelFusion/internalize_multi_ptr.cpp` fixed by this PR. Signed-off-by: Tsang, Whitney <[email protected]>
1 parent 704451a commit c50573a

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

polygeist/lib/polygeist/Ops.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1173,7 +1173,9 @@ class Memref2PointerIndex final : public OpRewritePattern<Memref2PointerOp> {
11731173
return failure();
11741174

11751175
auto MET = src.getSource().getType().cast<MemRefType>().getElementType();
1176-
if (MET.isa<LLVM::LLVMStructType>())
1176+
// SYCL types are lowered to LLVM struct type.
1177+
bool isSYCLTy = MET.getDialect().getNamespace().contains("sycl");
1178+
if (MET.isa<LLVM::LLVMStructType>() || isSYCLTy)
11771179
return failure();
11781180

11791181
Value idx[] = {src.getIndex()};

polygeist/test/polygeist-opt/invalid_canonicalization.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,20 @@ func.func @Memref2PointerIndex(%arg0: memref<?x!llvm.struct<(i32, i32)>>) -> !ll
6262
}
6363

6464
// -----
65+
66+
// CHECK: func.func @Memref2PointerIndexSycl([[A0:%.*]]: memref<?x!sycl_array_1_, 4>) -> !llvm.ptr<i64, 4> {
67+
// CHECK-NEXT: [[C1:%.*]] = arith.constant 1 : index
68+
// CHECK-NEXT: [[T0:%.*]] = "polygeist.subindex"([[A0]], [[C1]]) : (memref<?x!sycl_array_1_, 4>, index) -> memref<?xi64, 4>
69+
// CHECK-NEXT: [[T1:%.*]] = "polygeist.memref2pointer"([[T0]]) : (memref<?xi64, 4>) -> !llvm.ptr<i64, 4>
70+
// CHECK-NEXT: return [[T1]] : !llvm.ptr<i64, 4>
71+
// CHECK-NEXT: }
72+
73+
!sycl_array_1_ = !sycl.array<[1], (memref<1xi64, 4>)>
74+
func.func @Memref2PointerIndexSycl(%arg0: memref<?x!sycl_array_1_, 4>) -> !llvm.ptr<i64, 4> {
75+
%c1 = arith.constant 1 : index
76+
%0 = "polygeist.subindex"(%arg0, %c1) : (memref<?x!sycl.array<[1], (memref<1xi64, 4>)>, 4>, index) -> memref<?xi64, 4>
77+
%1 = "polygeist.memref2pointer"(%0) : (memref<?xi64, 4>) -> !llvm.ptr<i64, 4>
78+
return %1 : !llvm.ptr<i64, 4>
79+
}
80+
81+
// -----

0 commit comments

Comments
 (0)