|
29 | 29 | #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
|
30 | 30 | #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
|
31 | 31 | #include "mlir/Dialect/SCF/IR/SCF.h"
|
| 32 | +#include "mlir/Dialect/SYCL/IR/SYCLOpsTypes.h" |
32 | 33 | #include "mlir/IR/BlockAndValueMapping.h"
|
33 | 34 | #include "mlir/IR/ImplicitLocOpBuilder.h"
|
34 | 35 | #include "mlir/Transforms/RegionUtils.h"
|
@@ -111,6 +112,38 @@ struct SubIndexOpLowering : public ConvertOpToLLVMPattern<SubIndexOp> {
|
111 | 112 | targetMemRef.setOffset(rewriter, loc, baseOffset);
|
112 | 113 | }
|
113 | 114 |
|
| 115 | + if (auto ST = sourceMemRefType.getElementType() |
| 116 | + .dyn_cast<mlir::LLVM::LLVMStructType>()) { |
| 117 | + // According to MLIRASTConsumer::getMLIRType() in clang-mlir.cc, memref of |
| 118 | + // struct type is only generated for struct that has at least one entry of |
| 119 | + // SYCL type, otherwise a llvm pointer type is generated instead of a |
| 120 | + // memref. |
| 121 | + assert(any_of(ST.getBody(), |
| 122 | + [](Type Element) { |
| 123 | + return Element |
| 124 | + .isa<mlir::sycl::IDType, mlir::sycl::AccessorType, |
| 125 | + mlir::sycl::RangeType, |
| 126 | + mlir::sycl::AccessorImplDeviceType, |
| 127 | + mlir::sycl::ArrayType, mlir::sycl::ItemType, |
| 128 | + mlir::sycl::ItemBaseType, mlir::sycl::NdItemType, |
| 129 | + mlir::sycl::GroupType>(); |
| 130 | + }) && |
| 131 | + "Expecting at least one element type of the struct to be a SYCL " |
| 132 | + "type"); |
| 133 | + // According to MLIRScanner::InitializeValueByInitListExpr() in |
| 134 | + // clang-mlir.cc, when a memref element type is a struct type, the return |
| 135 | + // type of a polygeist.subindex should be a memref of the element type of |
| 136 | + // the struct. |
| 137 | + auto elemPtrTy = LLVM::LLVMPointerType::get( |
| 138 | + getTypeConverter()->convertType(viewMemRefType.getElementType())); |
| 139 | + auto gep = rewriter.create<LLVM::GEPOp>(loc, elemPtrTy, prev, idxs); |
| 140 | + MemRefDescriptor nexRef = createMemRefDescriptor( |
| 141 | + loc, viewMemRefType, gep, gep, sizes, strides, rewriter); |
| 142 | + |
| 143 | + rewriter.replaceOp(subViewOp, {nexRef}); |
| 144 | + return success(); |
| 145 | + } |
| 146 | + |
114 | 147 | assert(getTypeConverter()->convertType(viewMemRefType.getElementType()) ==
|
115 | 148 | prev.getType().cast<LLVM::LLVMPointerType>().getElementType() &&
|
116 | 149 | "Expecting the element types to match");
|
|
0 commit comments