Skip to content

Commit 178e944

Browse files
whitneywhtsangetiotto
authored andcommitted
Add support of memref of struct type (#44)
`SubIndexOpLowering` was unable to correctly handle memref of struct type. memref of struct type is only generated for struct that has at least one entry of SYCL type, otherwise a llvm pointer type is generated instead of a memref. (ref: https://github.com/InteonCo/Polygeist/blob/main/tools/cgeist/Lib/clang-mlir.cc#L5421) When a memref element type is a struct type, the return type of a `polygeist.subindex` should be a memref of the element type of the struct. (ref: https://github.com/InteonCo/Polygeist/blob/main/tools/cgeist/Lib/clang-mlir.cc#L703) Signed-off-by: Tsang, Whitney <[email protected]>
1 parent 0402ca1 commit 178e944

File tree

2 files changed

+46
-0
lines changed

2 files changed

+46
-0
lines changed

polygeist/lib/polygeist/Passes/ConvertPolygeistToLLVM.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
3030
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
3131
#include "mlir/Dialect/SCF/IR/SCF.h"
32+
#include "mlir/Dialect/SYCL/IR/SYCLOpsTypes.h"
3233
#include "mlir/IR/BlockAndValueMapping.h"
3334
#include "mlir/IR/ImplicitLocOpBuilder.h"
3435
#include "mlir/Transforms/RegionUtils.h"
@@ -111,6 +112,38 @@ struct SubIndexOpLowering : public ConvertOpToLLVMPattern<SubIndexOp> {
111112
targetMemRef.setOffset(rewriter, loc, baseOffset);
112113
}
113114

115+
if (auto ST = sourceMemRefType.getElementType()
116+
.dyn_cast<mlir::LLVM::LLVMStructType>()) {
117+
// According to MLIRASTConsumer::getMLIRType() in clang-mlir.cc, memref of
118+
// struct type is only generated for struct that has at least one entry of
119+
// SYCL type, otherwise a llvm pointer type is generated instead of a
120+
// memref.
121+
assert(any_of(ST.getBody(),
122+
[](Type Element) {
123+
return Element
124+
.isa<mlir::sycl::IDType, mlir::sycl::AccessorType,
125+
mlir::sycl::RangeType,
126+
mlir::sycl::AccessorImplDeviceType,
127+
mlir::sycl::ArrayType, mlir::sycl::ItemType,
128+
mlir::sycl::ItemBaseType, mlir::sycl::NdItemType,
129+
mlir::sycl::GroupType>();
130+
}) &&
131+
"Expecting at least one element type of the struct to be a SYCL "
132+
"type");
133+
// According to MLIRScanner::InitializeValueByInitListExpr() in
134+
// clang-mlir.cc, when a memref element type is a struct type, the return
135+
// type of a polygeist.subindex should be a memref of the element type of
136+
// the struct.
137+
auto elemPtrTy = LLVM::LLVMPointerType::get(
138+
getTypeConverter()->convertType(viewMemRefType.getElementType()));
139+
auto gep = rewriter.create<LLVM::GEPOp>(loc, elemPtrTy, prev, idxs);
140+
MemRefDescriptor nexRef = createMemRefDescriptor(
141+
loc, viewMemRefType, gep, gep, sizes, strides, rewriter);
142+
143+
rewriter.replaceOp(subViewOp, {nexRef});
144+
return success();
145+
}
146+
114147
assert(getTypeConverter()->convertType(viewMemRefType.getElementType()) ==
115148
prev.getType().cast<LLVM::LLVMPointerType>().getElementType() &&
116149
"Expecting the element types to match");
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// RUN: polygeist-opt --convert-polygeist-to-llvm %s | FileCheck %s
2+
3+
// CHECK: [[GEP:%.*]] = llvm.getelementptr {{.*}} : (!llvm.ptr<struct<([[SYCLIDSTRUCT:struct<"class.cl::sycl::id.1"]], {{.*}} -> !llvm.ptr<[[SYCLIDSTRUCT]], {{.*}}
4+
// CHECK: [[MEMREF:%.*]] = llvm.mlir.undef : !llvm.struct<(ptr<[[SYCLIDSTRUCT]], {{.*}}
5+
// CHECK: {{.*}} = llvm.insertvalue [[GEP]], [[MEMREF]][0] : !llvm.struct<(ptr<[[SYCLIDSTRUCT]], {{.*}}
6+
7+
module {
8+
func.func @test(%arg0: memref<?x!llvm.struct<(!sycl.id<1>)>>) -> memref<?x!sycl.id<1>> {
9+
%c0 = arith.constant 0 : index
10+
%0 = "polygeist.subindex"(%arg0, %c0) : (memref<?x!llvm.struct<(!sycl.id<1>)>>, index) -> memref<?x!sycl.id<1>>
11+
return %0 : memref<?x!sycl.id<1>>
12+
}
13+
}

0 commit comments

Comments
 (0)