@@ -48,8 +48,7 @@ template <typename SYCLType> static bool isMemRefOf(const Type &type) {
48
48
}
49
49
50
50
// Returns the element type of 'memref<?xSYCLType>'.
51
- template <typename SYCLType>
52
- static SYCLType getElementType (const Type &type) {
51
+ template <typename SYCLType> static SYCLType getElementType (const Type &type) {
53
52
assert (isMemRefOf<SYCLType>(type) && " Expecting memref<?xsycl::<type>>" );
54
53
Type elemType = type.cast <MemRefType>().getElementType ();
55
54
return elemType.cast <SYCLType>();
@@ -121,36 +120,74 @@ static Optional<Type> convertRangeType(sycl::RangeType type,
121
120
converter);
122
121
}
123
122
123
+ // / Create a LLVM struct type with name \p name, and the converted \p body as
124
+ // / the body.
125
+ static Optional<Type> convertBodyType (StringRef name,
126
+ llvm::ArrayRef<mlir::Type> body,
127
+ LLVMTypeConverter &converter) {
128
+ auto convertedTy =
129
+ LLVM::LLVMStructType::getIdentified (&converter.getContext (), name);
130
+ if (!convertedTy.isInitialized ()) {
131
+ SmallVector<Type> convertedElemTypes;
132
+ convertedElemTypes.reserve (body.size ());
133
+ if (failed (converter.convertTypes (body, convertedElemTypes)))
134
+ return llvm::None;
135
+ if (failed (convertedTy.setBody (convertedElemTypes, /* isPacked=*/ false )))
136
+ return llvm::None;
137
+ }
138
+
139
+ return convertedTy;
140
+ }
141
+
124
142
// / Converts SYCL accessor implement device type to LLVM type.
125
143
static Optional<Type>
126
144
convertAccessorImplDeviceType (sycl::AccessorImplDeviceType type,
127
145
LLVMTypeConverter &converter) {
128
- SmallVector<Type> convertedElemTypes;
129
- convertedElemTypes.reserve (type.getBody ().size ());
130
- if (failed (converter.convertTypes (type.getBody (), convertedElemTypes)))
131
- return llvm::None;
132
-
133
- return LLVM::LLVMStructType::getNewIdentified (
134
- &converter.getContext (), " class.cl::sycl::detail::AccessorImplDevice" ,
135
- convertedElemTypes, /* isPacked=*/ false );
146
+ return convertBodyType (" class.cl::sycl::detail::AccessorImplDevice" +
147
+ std::to_string (type.getDimension ()),
148
+ type.getBody (), converter);
136
149
}
137
150
138
151
// / Converts SYCL accessor type to LLVM type.
139
152
static Optional<Type> convertAccessorType (sycl::AccessorType type,
140
153
LLVMTypeConverter &converter) {
141
- SmallVector<Type> convertedElemTypes;
142
- convertedElemTypes.reserve (type.getBody ().size ());
143
- if (failed (converter.convertTypes (type.getBody (), convertedElemTypes)))
144
- return llvm::None;
154
+ auto convertedTy = LLVM::LLVMStructType::getIdentified (
155
+ &converter.getContext (),
156
+ " class.cl::sycl::accessor" + std::to_string (type.getDimension ()));
157
+ if (!convertedTy.isInitialized ()) {
158
+ SmallVector<Type> convertedElemTypes;
159
+ convertedElemTypes.reserve (type.getBody ().size ());
160
+ if (failed (converter.convertTypes (type.getBody (), convertedElemTypes)))
161
+ return llvm::None;
162
+
163
+ auto ptrTy = LLVM::LLVMPointerType::get (type.getType (), /* addressSpace=*/ 1 );
164
+ auto structTy =
165
+ LLVM::LLVMStructType::getLiteral (&converter.getContext (), ptrTy);
166
+ convertedElemTypes.push_back (structTy);
167
+
168
+ if (failed (convertedTy.setBody (convertedElemTypes, /* isPacked=*/ false )))
169
+ return llvm::None;
170
+ }
171
+
172
+ return convertedTy;
173
+ }
145
174
146
- auto ptrTy = LLVM::LLVMPointerType::get (type.getType (), /* addressSpace=*/ 1 );
147
- auto structTy =
148
- LLVM::LLVMStructType::getLiteral (&converter.getContext (), ptrTy);
149
- convertedElemTypes.push_back (structTy);
175
+ // / Converts SYCL item base type to LLVM type.
176
+ static Optional<Type> convertItemBaseType (sycl::ItemBaseType type,
177
+ LLVMTypeConverter &converter) {
178
+ return convertBodyType (" class.cl::sycl::detail::ItemBase." +
179
+ std::to_string (type.getDimension ()) +
180
+ (type.getWithOffset () ? " .true" : " .false" ),
181
+ type.getBody (), converter);
182
+ }
150
183
151
- return LLVM::LLVMStructType::getNewIdentified (
152
- &converter.getContext (), " class.cl::sycl::accessor" , convertedElemTypes,
153
- /* isPacked=*/ false );
184
+ // / Converts SYCL item type to LLVM type.
185
+ static Optional<Type> convertItemType (sycl::ItemType type,
186
+ LLVMTypeConverter &converter) {
187
+ return convertBodyType (" class.cl::sycl::item." +
188
+ std::to_string (type.getDimension ()) +
189
+ (type.getWithOffset () ? " .true" : " .false" ),
190
+ type.getBody (), converter);
154
191
}
155
192
156
193
// ===----------------------------------------------------------------------===//
@@ -188,7 +225,7 @@ class ConstructorPattern final
188
225
MLIRContext *context = module .getContext ();
189
226
190
227
// Lookup the ctor function to use.
191
- const auto ®istry = SYCLFuncRegistry::create (module , rewriter);
228
+ const auto ®istry = SYCLFuncRegistry::create (module , rewriter);
192
229
auto voidTy = LLVM::LLVMVoidType::get (context);
193
230
SYCLFuncDescriptor::FuncId funcId =
194
231
registry.getFuncId (SYCLFuncDescriptor::FuncIdKind::IdCtor, voidTy,
@@ -235,12 +272,10 @@ void mlir::sycl::populateSYCLToLLVMTypeConversion(
235
272
typeConverter.addConversion (
236
273
[&](sycl::IDType type) { return convertIDType (type, typeConverter); });
237
274
typeConverter.addConversion ([&](sycl::ItemBaseType type) {
238
- llvm_unreachable (" SYCLToLLVM - sycl::ItemBaseType not handle (yet)" );
239
- return llvm::None;
275
+ return convertItemBaseType (type, typeConverter);
240
276
});
241
277
typeConverter.addConversion ([&](sycl::ItemType type) {
242
- llvm_unreachable (" SYCLToLLVM - sycl::ItemType not handle (yet)" );
243
- return llvm::None;
278
+ return convertItemType (type, typeConverter);
244
279
});
245
280
typeConverter.addConversion ([&](sycl::NdItemType type) {
246
281
llvm_unreachable (" SYCLToLLVM - sycl::NdItemType not handle (yet)" );
0 commit comments