Skip to content

Commit c244dcf

Browse files
committed
Conversion of 'sycl.constructor(%0, %1) {type = @id}' to a copy ctor. (#50)
This PR adds support for converting a `sycl.constructor(%0, %1) {type = @id}` operation (representing a copy construction) to a call to the `sycl::id<n>` copy constructor (`sycl::id<n>::id(const sycl::id<n> const&)`). Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent 87e10ab commit c244dcf

File tree

4 files changed

+148
-33
lines changed

4 files changed

+148
-33
lines changed

mlir-sycl/include/mlir/Conversion/SYCLToLLVM/SYCLFuncRegistry.h

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ class SYCLFuncDescriptor {
3535
/// Enumerates SYCL functions.
3636
// clang-format off
3737
enum class FuncId {
38+
Unknown,
39+
3840
// Member functions for the sycl:id<n> class.
3941
Id1CtorDefault, // sycl::id<1>::id()
4042
Id2CtorDefault, // sycl::id<2>::id()
@@ -48,11 +50,20 @@ class SYCLFuncDescriptor {
4850
Id1CtorItem, // sycl::id<1>::id<1>(std::enable_if<(1)==(1), unsigned long>::type, unsigned long, unsigned long)
4951
Id2CtorItem, // sycl::id<2>::id<2>(std::enable_if<(2)==(2), unsigned long>::type, unsigned long, unsigned long)
5052
Id3CtorItem, // sycl::id<3>::id<3>(std::enable_if<(3)==(3), unsigned long>::type, unsigned long, unsigned long)
51-
53+
Id1CopyCtor, // sycl::id<1>::id(sycl::id<1> const&)
54+
Id2CopyCtor, // sycl::id<2>::id(sycl::id<2> const&)
55+
Id3CopyCtor, // sycl::id<3>::id(sycl::id<3> const&)
56+
5257
// Member functions for ..TODO..
5358
};
5459
// clang-format on
5560

61+
/// Enumerates the kind of FuncId.
62+
enum class FuncIdKind {
63+
Unknown,
64+
IdCtor, // any sycl::id<n> constructors
65+
};
66+
5667
// Call the SYCL constructor identified by \p id with the given \p args.
5768
static Value call(FuncId id, ValueRange args,
5869
const SYCLFuncRegistry &registry, OpBuilder &b,
@@ -65,9 +76,12 @@ class SYCLFuncDescriptor {
6576
: id(id), name(name), outputTy(outputTy),
6677
argTys(argTys.begin(), argTys.end()) {}
6778

68-
// Inject the declaration for this function into the module.
79+
/// Inject the declaration for this function into the module.
6980
void declareFunction(ModuleOp &module, OpBuilder &b);
7081

82+
/// Returns true if the given \p funcId is for a sycl::id<n> constructor.
83+
static bool isIdCtor(FuncId funcId);
84+
7185
private:
7286
FuncId id; // unique identifier for a SYCL function
7387
StringRef name; // SYCL function name
@@ -86,12 +100,18 @@ class SYCLFuncRegistry {
86100
static const SYCLFuncRegistry create(ModuleOp &module, OpBuilder &builder);
87101

88102
const SYCLFuncDescriptor &getFuncDesc(SYCLFuncDescriptor::FuncId id) const {
89-
assert(
90-
(registry.find(id) != registry.end()) &&
91-
"function identified by 'id' not found in the SYCL function registry");
103+
assert((registry.find(id) != registry.end()) &&
104+
"function identified by 'id' not found in the SYCL function "
105+
"registry");
92106
return registry.at(id);
93107
}
94108

109+
// Returns the SYCLFuncDescriptor::FuncId corresponding to the function
110+
// descriptor that matches the given signature and funcIdKind.
111+
SYCLFuncDescriptor::FuncId
112+
getFuncId(SYCLFuncDescriptor::FuncIdKind funcIdKind, Type retType,
113+
TypeRange argTypes) const;
114+
95115
private:
96116
SYCLFuncRegistry(ModuleOp &module, OpBuilder &builder);
97117

mlir-sycl/lib/Conversion/SYCLToLLVM/SYCLFuncRegistry.cpp

Lines changed: 86 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,34 @@ void SYCLFuncDescriptor::declareFunction(ModuleOp &module, OpBuilder &b) {
3232
funcRef = builder.getOrInsertFuncDecl(name, outputTy, argTys, module);
3333
}
3434

35-
Value SYCLFuncDescriptor::call(FuncId id, ValueRange args,
35+
bool SYCLFuncDescriptor::isIdCtor(FuncId funcId) {
36+
switch (funcId) {
37+
case FuncId::Id1CtorDefault:
38+
case FuncId::Id2CtorDefault:
39+
case FuncId::Id3CtorDefault:
40+
case FuncId::Id1CtorSizeT:
41+
case FuncId::Id2CtorSizeT:
42+
case FuncId::Id3CtorSizeT:
43+
case FuncId::Id1CtorRange:
44+
case FuncId::Id2CtorRange:
45+
case FuncId::Id3CtorRange:
46+
case FuncId::Id1CtorItem:
47+
case FuncId::Id2CtorItem:
48+
case FuncId::Id3CtorItem:
49+
case FuncId::Id1CopyCtor:
50+
case FuncId::Id2CopyCtor:
51+
case FuncId::Id3CopyCtor:
52+
return true;
53+
default:;
54+
}
55+
56+
return false;
57+
}
58+
59+
Value SYCLFuncDescriptor::call(FuncId funcId, ValueRange args,
3660
const SYCLFuncRegistry &registry, OpBuilder &b,
3761
Location loc) {
38-
const SYCLFuncDescriptor &funcDesc = registry.getFuncDesc(id);
62+
const SYCLFuncDescriptor &funcDesc = registry.getFuncDesc(funcId);
3963
LLVM_DEBUG(
4064
llvm::dbgs() << "Creating SYCLFuncDescriptor::call to funcDesc.funcRef: "
4165
<< funcDesc.funcRef << "\n");
@@ -59,14 +83,56 @@ Value SYCLFuncDescriptor::call(FuncId id, ValueRange args,
5983

6084
SYCLFuncRegistry *SYCLFuncRegistry::instance = nullptr;
6185

62-
const SYCLFuncRegistry SYCLFuncRegistry::create(
63-
ModuleOp &module, OpBuilder &builder) {
86+
const SYCLFuncRegistry SYCLFuncRegistry::create(ModuleOp &module,
87+
OpBuilder &builder) {
6488
if (!instance)
6589
instance = new SYCLFuncRegistry(module, builder);
6690

6791
return *instance;
6892
}
6993

94+
SYCLFuncDescriptor::FuncId
95+
SYCLFuncRegistry::getFuncId(SYCLFuncDescriptor::FuncIdKind funcIdKind,
96+
Type retType, TypeRange argTypes) const {
97+
assert(funcIdKind != SYCLFuncDescriptor::FuncIdKind::Unknown &&
98+
"Invalid funcIdKind");
99+
100+
// Determines whether the given funcId has kind that matches the given
101+
// funcIdKind.
102+
auto kindMatches = [](SYCLFuncDescriptor::FuncId funcId,
103+
SYCLFuncDescriptor::FuncIdKind funcIdKind) {
104+
bool foundMatch = false;
105+
switch (funcIdKind) {
106+
case SYCLFuncDescriptor::FuncIdKind::IdCtor:
107+
foundMatch = SYCLFuncDescriptor::isIdCtor(funcId);
108+
break;
109+
default:
110+
foundMatch = false;
111+
}
112+
return foundMatch;
113+
};
114+
115+
for (const auto &entry : registry) {
116+
// Skip through entries that do not match the requested funcIdKind.
117+
if (!kindMatches(entry.second.id, funcIdKind))
118+
continue;
119+
120+
// Ensure that the entry has return and arguments type that match the one
121+
// provided.
122+
if (retType != entry.second.outputTy ||
123+
argTypes.size() != entry.second.argTys.size())
124+
continue;
125+
if (!std::equal(argTypes.begin(), argTypes.end(),
126+
entry.second.argTys.begin()))
127+
continue;
128+
129+
return entry.second.id;
130+
}
131+
132+
llvm_unreachable("Unimplemented descriptor");
133+
return SYCLFuncDescriptor::FuncId::Unknown;
134+
}
135+
70136
SYCLFuncRegistry::SYCLFuncRegistry(ModuleOp &module, OpBuilder &builder)
71137
: registry() {
72138
MLIRContext *context = module.getContext();
@@ -144,6 +210,22 @@ SYCLFuncRegistry::SYCLFuncRegistry(ModuleOp &module, OpBuilder &builder)
144210
SYCLFuncDescriptor::FuncId::Id3CtorItem,
145211
"_ZN2cl4sycl2idILi3EEC2ILi3EEENSt9enable_ifIXeqT_Li3EEmE4typeEmm",
146212
voidTy, {id3PtrTy, i64Ty, i64Ty, i64Ty}),
213+
214+
// cl::sycl::id<1>::id(cl::sycl::id<1> const&)
215+
SYCLFuncDescriptor(
216+
SYCLFuncDescriptor::FuncId::Id1CopyCtor,
217+
"_ZN2cl4sycl2idILi1EEC1ERKS2_",
218+
voidTy, {id1PtrTy, id1PtrTy}),
219+
// cl::sycl::id<2>::id(cl::sycl::id<2> const&)
220+
SYCLFuncDescriptor(
221+
SYCLFuncDescriptor::FuncId::Id2CopyCtor,
222+
"_ZN2cl4sycl2idILi2EEC1ERKS2_",
223+
voidTy, {id2PtrTy, id2PtrTy}),
224+
// cl::sycl::id<3>::id(cl::sycl::id<3> const&)
225+
SYCLFuncDescriptor(
226+
SYCLFuncDescriptor::FuncId::Id3CopyCtor,
227+
"_ZN2cl4sycl2idILi3EEC1ERKS2_",
228+
voidTy, {id3PtrTy, id3PtrTy}),
147229
};
148230
// clang-format on
149231

mlir-sycl/lib/Conversion/SYCLToLLVM/SYCLToLLVM.cpp

Lines changed: 6 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -185,32 +185,14 @@ class ConstructorPattern final
185185
llvm::dbgs() << "\n");
186186

187187
ModuleOp module = op.getOperation()->getParentOfType<ModuleOp>();
188-
const auto &registry = SYCLFuncRegistry::create(module, rewriter);
189-
ValueRange args(op.Args());
190-
assert(args.size() > 0 && "Expecting at least one argument (the this ptr)");
191-
192-
// Multikey map used to lookup the specific sycl::id's ctor to call.
193-
// Given 'sycl::id<dim>(args)' the key is the pair {size(args), dim}.
194-
const std::map<std::pair<int, int>, SYCLFuncDescriptor::FuncId>
195-
lookupCtorId = {
196-
{{1, 1}, SYCLFuncDescriptor::FuncId::Id1CtorDefault},
197-
{{1, 2}, SYCLFuncDescriptor::FuncId::Id2CtorDefault},
198-
{{1, 3}, SYCLFuncDescriptor::FuncId::Id3CtorDefault},
199-
{{2, 1}, SYCLFuncDescriptor::FuncId::Id1CtorSizeT},
200-
{{2, 2}, SYCLFuncDescriptor::FuncId::Id2CtorSizeT},
201-
{{2, 3}, SYCLFuncDescriptor::FuncId::Id3CtorSizeT},
202-
{{3, 1}, SYCLFuncDescriptor::FuncId::Id1CtorRange},
203-
{{3, 2}, SYCLFuncDescriptor::FuncId::Id2CtorRange},
204-
{{3, 3}, SYCLFuncDescriptor::FuncId::Id3CtorRange},
205-
{{4, 1}, SYCLFuncDescriptor::FuncId::Id1CtorItem},
206-
{{4, 2}, SYCLFuncDescriptor::FuncId::Id2CtorItem},
207-
{{4, 3}, SYCLFuncDescriptor::FuncId::Id3CtorItem},
208-
};
188+
MLIRContext *context = module.getContext();
209189

210190
// Lookup the ctor function to use.
211-
auto arg0ElemTy = getElementType<mlir::sycl::IDType>(args[0].getType());
212-
auto key = std::make_pair(args.size(), arg0ElemTy.getDimension());
213-
SYCLFuncDescriptor::FuncId funcId = lookupCtorId.at(key);
191+
const auto &registry = SYCLFuncRegistry::create(module, rewriter);
192+
auto voidTy = LLVM::LLVMVoidType::get(context);
193+
SYCLFuncDescriptor::FuncId funcId =
194+
registry.getFuncId(SYCLFuncDescriptor::FuncIdKind::IdCtor, voidTy,
195+
opAdaptor.Args().getTypes());
214196

215197
// Generate an LLVM call to the appropriate ctor.
216198
SYCLFuncDescriptor::call(funcId, opAdaptor.getOperands(), registry,

mlir-sycl/test/Conversion/SYCLToLLVM/func-ops-to-llvm.mlir

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,3 +124,34 @@ func.func @id3CtorItem(%arg0: memref<?x!sycl.id<3>>, %arg1: i64, %arg2: i64, %ar
124124
}
125125

126126
// -----
127+
128+
//===-------------------------------------------------------------------------------------------------===//
129+
// Constructors sycl::id<n>::id(sycl::id<n> const&)
130+
//===-------------------------------------------------------------------------------------------------===//
131+
132+
// CHECK: llvm.func @_ZN2cl4sycl2idILi1EEC1ERKS2_([[THIS_PTR_TYPE:!llvm.struct<\(ptr<struct<"class.cl::sycl::id.1",.*]], [[THIS_PTR_TYPE]])
133+
func.func @id1CopyCtor(%arg0: memref<?x!sycl.id<1>>, %arg1: memref<?x!sycl.id<1>>) {
134+
// CHECK: llvm.call @_ZN2cl4sycl2idILi1EEC1ERKS2_({{.*}}, {{.*}}) : ([[THIS_PTR_TYPE]], [[THIS_PTR_TYPE]]) -> ()
135+
"sycl.constructor"(%arg0, %arg1) {Type = @id} : (memref<?x!sycl.id<1>>, memref<?x!sycl.id<1>>) -> ()
136+
return
137+
}
138+
139+
// -----
140+
141+
// CHECK: llvm.func @_ZN2cl4sycl2idILi2EEC1ERKS2_([[THIS_PTR_TYPE:!llvm.struct<\(ptr<struct<"class.cl::sycl::id.2",.*]], [[THIS_PTR_TYPE]])
142+
func.func @id2CopyCtor(%arg0: memref<?x!sycl.id<2>>, %arg1: memref<?x!sycl.id<2>>) {
143+
// CHECK: llvm.call @_ZN2cl4sycl2idILi2EEC1ERKS2_({{.*}}, {{.*}}) : ([[THIS_PTR_TYPE]], [[THIS_PTR_TYPE]]) -> ()
144+
"sycl.constructor"(%arg0, %arg1) {Type = @id} : (memref<?x!sycl.id<2>>, memref<?x!sycl.id<2>>) -> ()
145+
return
146+
}
147+
148+
// -----
149+
150+
// CHECK: llvm.func @_ZN2cl4sycl2idILi3EEC1ERKS2_([[THIS_PTR_TYPE:!llvm.struct<\(ptr<struct<"class.cl::sycl::id.3",.*]], [[THIS_PTR_TYPE]])
151+
func.func @id3CopyCtor(%arg0: memref<?x!sycl.id<3>>, %arg1: memref<?x!sycl.id<3>>) {
152+
// CHECK: llvm.call @_ZN2cl4sycl2idILi3EEC1ERKS2_({{.*}}, {{.*}}) : ([[THIS_PTR_TYPE]], [[THIS_PTR_TYPE]]) -> ()
153+
"sycl.constructor"(%arg0, %arg1) {Type = @id} : (memref<?x!sycl.id<3>>, memref<?x!sycl.id<3>>) -> ()
154+
return
155+
}
156+
157+
// -----

0 commit comments

Comments
 (0)