Skip to content

Commit 9656ebf

Browse files
whitneywhtsangetiotto
authored andcommitted
Add type converters for sycl::group and sycl::nd_item (#54)
The runtime class of `sycl::group`: ``` template <int Dimensions = 1> class group { ... range<Dimensions> globalRange; range<Dimensions> localRange; range<Dimensions> groupRange; id<Dimensions> index; ... } ``` The runtime class of `sycl::nd_item`: ``` template <int dimensions = 1> class nd_item { ... item<dimensions, true> globalItem; item<dimensions, false> localItem; group<dimensions> Group; ... } ``` Example of LLVM IR generated directly from clang: ``` %"class.cl::sycl::group" = type { %"class.cl::sycl::range", %"class.cl::sycl::range", %"class.cl::sycl::range", %"class.cl::sycl::id" } %"class.cl::sycl::nd_item" = type { %"class.cl::sycl::item", %"class.cl::sycl::item.0", %"class.cl::sycl::group" } %"class.cl::sycl::item" = type { %"struct.cl::sycl::detail::ItemBase" } %"class.cl::sycl::item.0" = type { %"struct.cl::sycl::detail::ItemBase.1" } %"struct.cl::sycl::detail::ItemBase" = type { %"class.cl::sycl::range", %"class.cl::sycl::id", %"class.cl::sycl::id" } %"struct.cl::sycl::detail::ItemBase.1" = type { %"class.cl::sycl::range", %"class.cl::sycl::id" } ``` Signed-off-by: Tsang, Whitney <[email protected]>
1 parent 133be13 commit 9656ebf

File tree

2 files changed

+81
-55
lines changed

2 files changed

+81
-55
lines changed

mlir-sycl/lib/Conversion/SYCLToLLVM/SYCLToLLVM.cpp

Lines changed: 64 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -74,52 +74,6 @@ static Optional<Type> getArrayTy(MLIRContext &context, unsigned dimNum,
7474
// Type conversion
7575
//===----------------------------------------------------------------------===//
7676

77-
/// Converts SYCL array type to LLVM type.
78-
static Optional<Type> convertArrayType(sycl::ArrayType type,
79-
LLVMTypeConverter &converter) {
80-
assert(type.getBody().size() == 1 &&
81-
"Expecting SYCL array body to have size 1");
82-
assert(type.getBody()[0].isa<MemRefType>() &&
83-
"Expecting SYCL array body entry to be MemRefType");
84-
assert(type.getBody()[0].cast<MemRefType>().getElementType() ==
85-
converter.getIndexType() &&
86-
"Expecting SYCL array body entry element type to be the index type");
87-
return getArrayTy(converter.getContext(), type.getDimension(),
88-
converter.getIndexType());
89-
}
90-
91-
/// Converts SYCL range or id type to LLVM type, given \p dimNum - number of
92-
/// dimensions, \p name - the expected LLVM type name, \p converter - LLVM type
93-
/// converter.
94-
static Optional<Type> convertRangeOrIDTy(unsigned dimNum, StringRef name,
95-
LLVMTypeConverter &converter) {
96-
auto convertedTy = LLVM::LLVMStructType::getIdentified(
97-
&converter.getContext(), name.str() + "." + std::to_string(dimNum));
98-
if (!convertedTy.isInitialized()) {
99-
auto arrayTy =
100-
getArrayTy(converter.getContext(), dimNum, converter.getIndexType());
101-
if (!arrayTy.hasValue())
102-
return llvm::None;
103-
if (failed(convertedTy.setBody(arrayTy.getValue(), /*isPacked=*/false)))
104-
return llvm::None;
105-
}
106-
return convertedTy;
107-
}
108-
109-
/// Converts SYCL id type to LLVM type.
110-
static Optional<Type> convertIDType(sycl::IDType type,
111-
LLVMTypeConverter &converter) {
112-
return convertRangeOrIDTy(type.getDimension(), "class.cl::sycl::id",
113-
converter);
114-
}
115-
116-
/// Converts SYCL range type to LLVM type.
117-
static Optional<Type> convertRangeType(sycl::RangeType type,
118-
LLVMTypeConverter &converter) {
119-
return convertRangeOrIDTy(type.getDimension(), "class.cl::sycl::range",
120-
converter);
121-
}
122-
12377
/// Create a LLVM struct type with name \p name, and the converted \p body as
12478
/// the body.
12579
static Optional<Type> convertBodyType(StringRef name,
@@ -172,6 +126,53 @@ static Optional<Type> convertAccessorType(sycl::AccessorType type,
172126
return convertedTy;
173127
}
174128

129+
/// Converts SYCL array type to LLVM type.
130+
static Optional<Type> convertArrayType(sycl::ArrayType type,
131+
LLVMTypeConverter &converter) {
132+
assert(type.getBody().size() == 1 &&
133+
"Expecting SYCL array body to have size 1");
134+
assert(type.getBody()[0].isa<MemRefType>() &&
135+
"Expecting SYCL array body entry to be MemRefType");
136+
assert(type.getBody()[0].cast<MemRefType>().getElementType() ==
137+
converter.getIndexType() &&
138+
"Expecting SYCL array body entry element type to be the index type");
139+
return getArrayTy(converter.getContext(), type.getDimension(),
140+
converter.getIndexType());
141+
}
142+
143+
/// Converts SYCL group type to LLVM type.
144+
static Optional<Type> convertGroupType(sycl::GroupType type,
145+
LLVMTypeConverter &converter) {
146+
return convertBodyType("class.cl::sycl::group." +
147+
std::to_string(type.getDimension()),
148+
type.getBody(), converter);
149+
}
150+
151+
/// Converts SYCL range or id type to LLVM type, given \p dimNum - number of
152+
/// dimensions, \p name - the expected LLVM type name, \p converter - LLVM type
153+
/// converter.
154+
static Optional<Type> convertRangeOrIDTy(unsigned dimNum, StringRef name,
155+
LLVMTypeConverter &converter) {
156+
auto convertedTy = LLVM::LLVMStructType::getIdentified(
157+
&converter.getContext(), name.str() + "." + std::to_string(dimNum));
158+
if (!convertedTy.isInitialized()) {
159+
auto arrayTy =
160+
getArrayTy(converter.getContext(), dimNum, converter.getIndexType());
161+
if (!arrayTy.hasValue())
162+
return llvm::None;
163+
if (failed(convertedTy.setBody(arrayTy.getValue(), /*isPacked=*/false)))
164+
return llvm::None;
165+
}
166+
return convertedTy;
167+
}
168+
169+
/// Converts SYCL id type to LLVM type.
170+
static Optional<Type> convertIDType(sycl::IDType type,
171+
LLVMTypeConverter &converter) {
172+
return convertRangeOrIDTy(type.getDimension(), "class.cl::sycl::id",
173+
converter);
174+
}
175+
175176
/// Converts SYCL item base type to LLVM type.
176177
static Optional<Type> convertItemBaseType(sycl::ItemBaseType type,
177178
LLVMTypeConverter &converter) {
@@ -190,6 +191,21 @@ static Optional<Type> convertItemType(sycl::ItemType type,
190191
type.getBody(), converter);
191192
}
192193

194+
/// Converts SYCL nd item type to LLVM type.
195+
static Optional<Type> convertNdItemType(sycl::NdItemType type,
196+
LLVMTypeConverter &converter) {
197+
return convertBodyType("class.cl::sycl::nd_item." +
198+
std::to_string(type.getDimension()),
199+
type.getBody(), converter);
200+
}
201+
202+
/// Converts SYCL range type to LLVM type.
203+
static Optional<Type> convertRangeType(sycl::RangeType type,
204+
LLVMTypeConverter &converter) {
205+
return convertRangeOrIDTy(type.getDimension(), "class.cl::sycl::range",
206+
converter);
207+
}
208+
193209
//===----------------------------------------------------------------------===//
194210
// ConstructorPattern - Converts `sycl.constructor` to LLVM.
195211
//===----------------------------------------------------------------------===//
@@ -263,8 +279,7 @@ void mlir::sycl::populateSYCLToLLVMTypeConversion(
263279
return convertArrayType(type, typeConverter);
264280
});
265281
typeConverter.addConversion([&](sycl::GroupType type) {
266-
llvm_unreachable("SYCLToLLVM - sycl::GroupType not handle (yet)");
267-
return llvm::None;
282+
return convertGroupType(type, typeConverter);
268283
});
269284
typeConverter.addConversion(
270285
[&](sycl::IDType type) { return convertIDType(type, typeConverter); });
@@ -275,8 +290,7 @@ void mlir::sycl::populateSYCLToLLVMTypeConversion(
275290
return convertItemType(type, typeConverter);
276291
});
277292
typeConverter.addConversion([&](sycl::NdItemType type) {
278-
llvm_unreachable("SYCLToLLVM - sycl::NdItemType not handle (yet)");
279-
return llvm::None;
293+
return convertNdItemType(type, typeConverter);
280294
});
281295
typeConverter.addConversion([&](sycl::RangeType type) {
282296
return convertRangeType(type, typeConverter);

mlir-sycl/test/Conversion/SYCLToLLVM/sycl-types-to-llvm.mlir

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,12 @@
88
// CHECK: llvm.func @test_accessorImplDevice(%arg0: !llvm.[[ACCESSORIMPLDEVICE_1:struct<"class.cl::sycl::detail::AccessorImplDevice.*", \(]][[ID_1]][[ARRAY_1]][[SUFFIX]], [[RANGE_1]][[ARRAY_1]][[SUFFIX]], [[RANGE_1]][[ARRAY_1]][[SUFFIX]][[SUFFIX]])
99
// CHECK: llvm.func @test_accessor.1(%arg0: !llvm.[[ACCESSOR_1:struct<"class.cl::sycl::accessor.*", \(]][[ACCESSORIMPLDEVICE_1]][[ID_1]][[ARRAY_1]][[SUFFIX]], [[RANGE_1]][[ARRAY_1]][[SUFFIX]], [[RANGE_1]][[ARRAY_1]][[SUFFIX]][[SUFFIX]], struct<(ptr<i32, 1>)>)>)
1010
// CHECK: llvm.func @test_accessor.2(%arg0: !llvm.[[ACCESSOR_2:struct<"class.cl::sycl::accessor.*", \(]][[ACCESSORIMPLDEVICE_2:struct<"class.cl::sycl::detail::AccessorImplDevice.*", \(]][[ID_2:struct<"class.cl::sycl::id.*", \(]][[ARRAY_2]][[SUFFIX]], [[RANGE_2]][[ARRAY_2]][[SUFFIX]], [[RANGE_2]][[ARRAY_2]][[SUFFIX]][[SUFFIX]], struct<(ptr<i64, 1>)>)>)
11-
// CHECK: llvm.func @test_item_base.true(%arg0: !llvm.[[ITEM_BASE_1_TRUE:struct<"class.cl::sycl::detail::ItemBase.1.true", \(]][[RANGE_1]][[ARRAY_1]][[SUFFIX]], [[ID_1]][[ARRAY_1]][[SUFFIX]], [[ID_1]][[ARRAY_1]][[SUFFIX]][[SUFFIX]])
12-
// CHECK: llvm.func @test_item_base.false(%arg0: !llvm.[[ITEM_BASE_2_FALSE:struct<"class.cl::sycl::detail::ItemBase.2.false", \(]][[RANGE_2]][[ARRAY_2]][[SUFFIX]], [[ID_2]][[ARRAY_2]][[SUFFIX]][[SUFFIX]])
13-
// CHECK: llvm.func @test_item(%arg0: !llvm.[[ITEM_1_TRUE:struct<"class.cl::sycl::item.1.true", \(]][[ITEM_BASE_1_TRUE]][[RANGE_1]][[ARRAY_1]][[SUFFIX]], [[ID_1]][[ARRAY_1]][[SUFFIX]], [[ID_1]][[ARRAY_1]][[SUFFIX]][[SUFFIX]][[SUFFIX]])
11+
// CHECK: llvm.func @test_item_base.true(%arg0: !llvm.[[ITEM_BASE_1_TRUE:struct<"class.cl::sycl::detail::ItemBase.*", \(]][[RANGE_1]][[ARRAY_1]][[SUFFIX]], [[ID_1]][[ARRAY_1]][[SUFFIX]], [[ID_1]][[ARRAY_1]][[SUFFIX]][[SUFFIX]])
12+
// CHECK: llvm.func @test_item_base.false(%arg0: !llvm.[[ITEM_BASE_1_FALSE:struct<"class.cl::sycl::detail::ItemBase.*", \(]][[RANGE_1]][[ARRAY_1]][[SUFFIX]], [[ID_1]][[ARRAY_1]][[SUFFIX]][[SUFFIX]])
13+
// CHECK: llvm.func @test_item.true(%arg0: !llvm.[[ITEM_1_TRUE:struct<"class.cl::sycl::item.*", \(]][[ITEM_BASE_1_TRUE]][[RANGE_1]][[ARRAY_1]][[SUFFIX]], [[ID_1]][[ARRAY_1]][[SUFFIX]], [[ID_1]][[ARRAY_1]][[SUFFIX]][[SUFFIX]][[SUFFIX]])
14+
// CHECK: llvm.func @test_item.false(%arg0: !llvm.[[ITEM_1_FALSE:struct<"class.cl::sycl::item.*", \(]][[ITEM_BASE_1_FALSE]][[RANGE_1]][[ARRAY_1]][[SUFFIX]], [[ID_1]][[ARRAY_1]][[SUFFIX]][[SUFFIX]][[SUFFIX]])
15+
// CHECK: llvm.func @test_group(%arg0: !llvm.[[GROUP_1:struct<"class.cl::sycl::group.*", \(]][[RANGE_1]][[ARRAY_1]][[SUFFIX]], [[RANGE_1]][[ARRAY_1]][[SUFFIX]], [[RANGE_1]][[ARRAY_1]][[SUFFIX]], [[ID_1]][[ARRAY_1]][[SUFFIX]][[SUFFIX]])
16+
// CHECK: llvm.func @test_nd_item(%arg0: !llvm.[[ND_ITEM_1:struct<"class.cl::sycl::nd_item.*", \(]][[ITEM_1_TRUE]][[ITEM_BASE_1_TRUE]][[RANGE_1]][[ARRAY_1]][[SUFFIX]], [[ID_1]][[ARRAY_1]][[SUFFIX]], [[ID_1]][[ARRAY_1]][[SUFFIX]][[SUFFIX]][[SUFFIX]], [[ITEM_1_FALSE]][[ITEM_BASE_1_FALSE]][[RANGE_1]][[ARRAY_1]][[SUFFIX]], [[ID_1]][[ARRAY_1]][[SUFFIX]][[SUFFIX]][[SUFFIX]], [[GROUP_1]][[RANGE_1]][[ARRAY_1]][[SUFFIX]], [[RANGE_1]][[ARRAY_1]][[SUFFIX]], [[RANGE_1]][[ARRAY_1]][[SUFFIX]], [[ID_1]][[ARRAY_1]][[SUFFIX]][[SUFFIX]][[SUFFIX]])
1417

1518
module {
1619
func.func @test_array.1(%arg0: !sycl.array<[1], (memref<1xi64>)>) {
@@ -40,10 +43,19 @@ module {
4043
func.func @test_item_base.true(%arg0: !sycl.item_base<[1, true], (!sycl.range<1>, !sycl.id<1>, !sycl.id<1>)>) {
4144
return
4245
}
43-
func.func @test_item_base.false(%arg0: !sycl.item_base<[2, false], (!sycl.range<2>, !sycl.id<2>)>) {
46+
func.func @test_item_base.false(%arg0: !sycl.item_base<[1, false], (!sycl.range<1>, !sycl.id<1>)>) {
4447
return
4548
}
46-
func.func @test_item(%arg0: !sycl.item<[1, true], (!sycl.item_base<[1, true], (!sycl.range<1>, !sycl.id<1>, !sycl.id<1>)>)>) {
49+
func.func @test_item.true(%arg0: !sycl.item<[1, true], (!sycl.item_base<[1, true], (!sycl.range<1>, !sycl.id<1>, !sycl.id<1>)>)>) {
50+
return
51+
}
52+
func.func @test_item.false(%arg0: !sycl.item<[1, false], (!sycl.item_base<[1, false], (!sycl.range<1>, !sycl.id<1>, !sycl.id<1>)>)>) {
53+
return
54+
}
55+
func.func @test_group(%arg0: !sycl.group<[1], (!sycl.range<1>, !sycl.range<1>, !sycl.range<1>, !sycl.id<1>)>) {
56+
return
57+
}
58+
func.func @test_nd_item(%arg0: !sycl.nd_item<[1], (!sycl.item<[1, true], (!sycl.item_base<[1, true], (!sycl.range<1>, !sycl.id<1>, !sycl.id<1>)>)>, !sycl.item<[1, false], (!sycl.item_base<[1, false], (!sycl.range<1>, !sycl.id<1>)>)>, !sycl.group<[1], (!sycl.range<1>, !sycl.range<1>, !sycl.range<1>, !sycl.id<1>)>)>) {
4759
return
4860
}
4961
}

0 commit comments

Comments
 (0)