Skip to content

Commit 7b3ac03

Browse files
whitneywhtsangetiotto
authored andcommitted
Add type converters for sycl::item and sycl::item_base (#52)
The runtime class of `sycl::item_base`: ``` template <int Dims> struct ItemBase<Dims, true> { ... range<Dims> MExtent; id<Dims> MIndex; id<Dims> MOffset; ... } template <int Dims> struct ItemBase<Dims, false> { ... range<Dims> MExtent; id<Dims> MIndex; ... } ``` The runtime class of `sycl::item`: ``` template <int dimensions = 1, bool with_offset = true> class item { ... detail::ItemBase<dimensions, with_offset> MImpl; ... } ``` Example of LLVM IR generated directly from clang: ``` %"class.cl::sycl::item" = type { %"struct.cl::sycl::detail::ItemBase" } %"struct.cl::sycl::detail::ItemBase" = type { %"class.cl::sycl::range", %"class.cl::sycl::id", %"class.cl::sycl::id" } ``` Signed-off-by: Tsang, Whitney <[email protected]>
1 parent c244dcf commit 7b3ac03

File tree

2 files changed

+73
-26
lines changed

2 files changed

+73
-26
lines changed

mlir-sycl/lib/Conversion/SYCLToLLVM/SYCLToLLVM.cpp

Lines changed: 61 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,7 @@ template <typename SYCLType> static bool isMemRefOf(const Type &type) {
4848
}
4949

5050
// Returns the element type of 'memref<?xSYCLType>'.
51-
template <typename SYCLType>
52-
static SYCLType getElementType(const Type &type) {
51+
template <typename SYCLType> static SYCLType getElementType(const Type &type) {
5352
assert(isMemRefOf<SYCLType>(type) && "Expecting memref<?xsycl::<type>>");
5453
Type elemType = type.cast<MemRefType>().getElementType();
5554
return elemType.cast<SYCLType>();
@@ -121,36 +120,74 @@ static Optional<Type> convertRangeType(sycl::RangeType type,
121120
converter);
122121
}
123122

123+
/// Create a LLVM struct type with name \p name, and the converted \p body as
124+
/// the body.
125+
static Optional<Type> convertBodyType(StringRef name,
126+
llvm::ArrayRef<mlir::Type> body,
127+
LLVMTypeConverter &converter) {
128+
auto convertedTy =
129+
LLVM::LLVMStructType::getIdentified(&converter.getContext(), name);
130+
if (!convertedTy.isInitialized()) {
131+
SmallVector<Type> convertedElemTypes;
132+
convertedElemTypes.reserve(body.size());
133+
if (failed(converter.convertTypes(body, convertedElemTypes)))
134+
return llvm::None;
135+
if (failed(convertedTy.setBody(convertedElemTypes, /*isPacked=*/false)))
136+
return llvm::None;
137+
}
138+
139+
return convertedTy;
140+
}
141+
124142
/// Converts SYCL accessor implement device type to LLVM type.
125143
static Optional<Type>
126144
convertAccessorImplDeviceType(sycl::AccessorImplDeviceType type,
127145
LLVMTypeConverter &converter) {
128-
SmallVector<Type> convertedElemTypes;
129-
convertedElemTypes.reserve(type.getBody().size());
130-
if (failed(converter.convertTypes(type.getBody(), convertedElemTypes)))
131-
return llvm::None;
132-
133-
return LLVM::LLVMStructType::getNewIdentified(
134-
&converter.getContext(), "class.cl::sycl::detail::AccessorImplDevice",
135-
convertedElemTypes, /*isPacked=*/false);
146+
return convertBodyType("class.cl::sycl::detail::AccessorImplDevice" +
147+
std::to_string(type.getDimension()),
148+
type.getBody(), converter);
136149
}
137150

138151
/// Converts SYCL accessor type to LLVM type.
139152
static Optional<Type> convertAccessorType(sycl::AccessorType type,
140153
LLVMTypeConverter &converter) {
141-
SmallVector<Type> convertedElemTypes;
142-
convertedElemTypes.reserve(type.getBody().size());
143-
if (failed(converter.convertTypes(type.getBody(), convertedElemTypes)))
144-
return llvm::None;
154+
auto convertedTy = LLVM::LLVMStructType::getIdentified(
155+
&converter.getContext(),
156+
"class.cl::sycl::accessor" + std::to_string(type.getDimension()));
157+
if (!convertedTy.isInitialized()) {
158+
SmallVector<Type> convertedElemTypes;
159+
convertedElemTypes.reserve(type.getBody().size());
160+
if (failed(converter.convertTypes(type.getBody(), convertedElemTypes)))
161+
return llvm::None;
162+
163+
auto ptrTy = LLVM::LLVMPointerType::get(type.getType(), /*addressSpace=*/1);
164+
auto structTy =
165+
LLVM::LLVMStructType::getLiteral(&converter.getContext(), ptrTy);
166+
convertedElemTypes.push_back(structTy);
167+
168+
if (failed(convertedTy.setBody(convertedElemTypes, /*isPacked=*/false)))
169+
return llvm::None;
170+
}
171+
172+
return convertedTy;
173+
}
145174

146-
auto ptrTy = LLVM::LLVMPointerType::get(type.getType(), /*addressSpace=*/1);
147-
auto structTy =
148-
LLVM::LLVMStructType::getLiteral(&converter.getContext(), ptrTy);
149-
convertedElemTypes.push_back(structTy);
175+
/// Converts SYCL item base type to LLVM type.
176+
static Optional<Type> convertItemBaseType(sycl::ItemBaseType type,
177+
LLVMTypeConverter &converter) {
178+
return convertBodyType("class.cl::sycl::detail::ItemBase." +
179+
std::to_string(type.getDimension()) +
180+
(type.getWithOffset() ? ".true" : ".false"),
181+
type.getBody(), converter);
182+
}
150183

151-
return LLVM::LLVMStructType::getNewIdentified(
152-
&converter.getContext(), "class.cl::sycl::accessor", convertedElemTypes,
153-
/*isPacked=*/false);
184+
/// Converts SYCL item type to LLVM type.
185+
static Optional<Type> convertItemType(sycl::ItemType type,
186+
LLVMTypeConverter &converter) {
187+
return convertBodyType("class.cl::sycl::item." +
188+
std::to_string(type.getDimension()) +
189+
(type.getWithOffset() ? ".true" : ".false"),
190+
type.getBody(), converter);
154191
}
155192

156193
//===----------------------------------------------------------------------===//
@@ -188,7 +225,7 @@ class ConstructorPattern final
188225
MLIRContext *context = module.getContext();
189226

190227
// Lookup the ctor function to use.
191-
const auto &registry = SYCLFuncRegistry::create(module, rewriter);
228+
const auto &registry = SYCLFuncRegistry::create(module, rewriter);
192229
auto voidTy = LLVM::LLVMVoidType::get(context);
193230
SYCLFuncDescriptor::FuncId funcId =
194231
registry.getFuncId(SYCLFuncDescriptor::FuncIdKind::IdCtor, voidTy,
@@ -235,12 +272,10 @@ void mlir::sycl::populateSYCLToLLVMTypeConversion(
235272
typeConverter.addConversion(
236273
[&](sycl::IDType type) { return convertIDType(type, typeConverter); });
237274
typeConverter.addConversion([&](sycl::ItemBaseType type) {
238-
llvm_unreachable("SYCLToLLVM - sycl::ItemBaseType not handle (yet)");
239-
return llvm::None;
275+
return convertItemBaseType(type, typeConverter);
240276
});
241277
typeConverter.addConversion([&](sycl::ItemType type) {
242-
llvm_unreachable("SYCLToLLVM - sycl::ItemType not handle (yet)");
243-
return llvm::None;
278+
return convertItemType(type, typeConverter);
244279
});
245280
typeConverter.addConversion([&](sycl::NdItemType type) {
246281
llvm_unreachable("SYCLToLLVM - sycl::NdItemType not handle (yet)");

mlir-sycl/test/Conversion/SYCLToLLVM/sycl-types-to-llvm.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
// CHECK: llvm.func @test_accessorImplDevice(%arg0: !llvm.[[ACCESSORIMPLDEVICE_1:struct<"class.cl::sycl::detail::AccessorImplDevice.*", \(]][[ID_1]][[ARRAY_1]][[SUFFIX]], [[RANGE_1]][[ARRAY_1]][[SUFFIX]], [[RANGE_1]][[ARRAY_1]][[SUFFIX]][[SUFFIX]])
99
// CHECK: llvm.func @test_accessor.1(%arg0: !llvm.[[ACCESSOR_1:struct<"class.cl::sycl::accessor.*", \(]][[ACCESSORIMPLDEVICE_1]][[ID_1]][[ARRAY_1]][[SUFFIX]], [[RANGE_1]][[ARRAY_1]][[SUFFIX]], [[RANGE_1]][[ARRAY_1]][[SUFFIX]][[SUFFIX]], struct<(ptr<i32, 1>)>)>)
1010
// CHECK: llvm.func @test_accessor.2(%arg0: !llvm.[[ACCESSOR_2:struct<"class.cl::sycl::accessor.*", \(]][[ACCESSORIMPLDEVICE_2:struct<"class.cl::sycl::detail::AccessorImplDevice.*", \(]][[ID_2:struct<"class.cl::sycl::id.*", \(]][[ARRAY_2]][[SUFFIX]], [[RANGE_2]][[ARRAY_2]][[SUFFIX]], [[RANGE_2]][[ARRAY_2]][[SUFFIX]][[SUFFIX]], struct<(ptr<i64, 1>)>)>)
11+
// CHECK: llvm.func @test_item_base.true(%arg0: !llvm.[[ITEM_BASE_1_TRUE:struct<"class.cl::sycl::detail::ItemBase.1.true", \(]][[RANGE_1]][[ARRAY_1]][[SUFFIX]], [[ID_1]][[ARRAY_1]][[SUFFIX]], [[ID_1]][[ARRAY_1]][[SUFFIX]][[SUFFIX]])
12+
// CHECK: llvm.func @test_item_base.false(%arg0: !llvm.[[ITEM_BASE_2_FALSE:struct<"class.cl::sycl::detail::ItemBase.2.false", \(]][[RANGE_2]][[ARRAY_2]][[SUFFIX]], [[ID_2]][[ARRAY_2]][[SUFFIX]][[SUFFIX]])
13+
// CHECK: llvm.func @test_item(%arg0: !llvm.[[ITEM_1_TRUE:struct<"class.cl::sycl::item.1.true", \(]][[ITEM_BASE_1_TRUE]][[RANGE_1]][[ARRAY_1]][[SUFFIX]], [[ID_1]][[ARRAY_1]][[SUFFIX]], [[ID_1]][[ARRAY_1]][[SUFFIX]][[SUFFIX]][[SUFFIX]])
1114

1215
module {
1316
func.func @test_array.1(%arg0: !sycl.array<[1], (memref<1xi64>)>) {
@@ -34,4 +37,13 @@ module {
3437
func.func @test_accessor.2(%arg0: !sycl.accessor<[2, i64, write, global_buffer], (!sycl.accessor_impl_device<[2], (!sycl.id<2>, !sycl.range<2>, !sycl.range<2>)>)>) {
3538
return
3639
}
40+
func.func @test_item_base.true(%arg0: !sycl.item_base<[1, true], (!sycl.range<1>, !sycl.id<1>, !sycl.id<1>)>) {
41+
return
42+
}
43+
func.func @test_item_base.false(%arg0: !sycl.item_base<[2, false], (!sycl.range<2>, !sycl.id<2>)>) {
44+
return
45+
}
46+
func.func @test_item(%arg0: !sycl.item<[1, true], (!sycl.item_base<[1, true], (!sycl.range<1>, !sycl.id<1>, !sycl.id<1>)>)>) {
47+
return
48+
}
3749
}

0 commit comments

Comments
 (0)