Skip to content

Commit 1d867c2

Browse files
committed
convert sycl.constructor(...) { Type = @id } (#45)
This PR implements the SYCLToLLVM conversion for lowering sycl::id<n>(....). It introduces the class ConstructorPattern which generates an llvm.call to the appropriate ctor function of the sycl::id templated class when given a sycl.constructor(%1) {Type = @id} operation. This PR also introduces a new too names sycl-mlir-opt which can be used to drive unit testing. Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent 4a303d6 commit 1d867c2

File tree

23 files changed

+564
-193
lines changed

23 files changed

+564
-193
lines changed

mlir-sycl/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake/modules")
7070
add_subdirectory(include/mlir)
7171
add_subdirectory(lib)
7272
add_subdirectory(test)
73+
add_subdirectory(tools)
7374

7475
install(DIRECTORY include/
7576
DESTINATION include

mlir-sycl/include/mlir/Conversion/SYCLPasses.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===- SYCLPasses.h - Conversion Pass Construction and Registration -----------===//
1+
//===- SYCLPasses.h - Conversion Pass Construction and Registration -------===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -9,14 +9,16 @@
99
#ifndef MLIR_CONVERSION_SYCLPASSES_H
1010
#define MLIR_CONVERSION_SYCLPASSES_H
1111

12-
#include "mlir/Conversion/SYCLToLLVM/SYCLToLLVM.h"
12+
#include "mlir/Conversion/SYCLToLLVM/SYCLToLLVMPass.h"
1313

1414
namespace mlir {
15+
namespace sycl {
1516

1617
/// Generate the code for registering conversion passes.
1718
#define GEN_PASS_REGISTRATION
1819
#include "mlir/Conversion/SYCLPasses.h.inc"
1920

21+
} // namespace sycl
2022
} // namespace mlir
2123

2224
#endif // MLIR_CONVERSION_SYCLPASSES_H

mlir-sycl/include/mlir/Conversion/SYCLPasses.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def ConvertSYCLToLLVM : Pass<"convert-sycl-to-llvm", "ModuleOp"> {
2121
See docs/SYCLToLLVMDialectConversion/ for more details.
2222
TODO: add docs referenced above.
2323
}];
24-
let constructor = "mlir::createConvertSYCLToLLVMPass()";
24+
let constructor = "mlir::sycl::createConvertSYCLToLLVMPass()";
2525
let dependentDialects = ["LLVM::LLVMDialect"];
2626
}
2727

mlir-sycl/include/mlir/Conversion/SYCLToLLVM/DialectBuilder.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ class LLVMBuilder : public DialectBuilder {
7575
LLVM::BitcastOp genBitcast(Type type, Value val) const;
7676
LLVM::ExtractValueOp genExtractValue(Type type, Value container,
7777
ArrayRef<int64_t> pos) const;
78-
LLVM::CallOp genCall(FlatSymbolRefAttr funcSym, ArrayRef<Type> resTypes,
79-
ArrayRef<Value> operands) const;
78+
LLVM::CallOp genCall(FlatSymbolRefAttr funcSym, TypeRange resTypes,
79+
ValueRange operands) const;
8080
LLVM::ConstantOp genConstant(Type type, double val) const;
8181
LLVM::SExtOp genSignExtend(Type type, Value val) const;
8282
};

mlir-sycl/include/mlir/Conversion/SYCLToLLVM/SYCLFuncRegistry.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ class SYCLFuncDescriptor {
5454
// clang-format on
5555

5656
// Call the SYCL constructor identified by \p id with the given \p args.
57-
static Value call(FuncId id, ArrayRef<Value> args,
57+
static Value call(FuncId id, ValueRange args,
5858
const SYCLFuncRegistry &registry, OpBuilder &b,
5959
Location loc);
6060

mlir-sycl/include/mlir/Conversion/SYCLToLLVM/SYCLToLLVM.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#define MLIR_CONVERSION_SYCLTOLLVM_SYCLTOLLVM_H
1515

1616
#include "mlir/Transforms/DialectConversion.h"
17+
#include "mlir/Dialect/Func/IR/FuncOps.h"
1718

1819
namespace mlir {
1920
class LLVMTypeConverter;

mlir-sycl/include/mlir/Conversion/SYCLToLLVM/SYCLToLLVMPass.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,19 @@
1313
#ifndef MLIR_CONVERSION_SYCLTOLLVM_SYCLTOLLVMPASS_H
1414
#define MLIR_CONVERSION_SYCLTOLLVM_SYCLTOLLVMPASS_H
1515

16+
#include "mlir/Pass/Pass.h"
1617
#include <memory>
1718

1819
namespace mlir {
1920
class ModuleOp;
2021
template <typename T> class OperationPass;
2122

23+
namespace sycl {
24+
2225
/// Creates a pass to convert SYCL operations to the LLVMIR dialect.
2326
std::unique_ptr<OperationPass<ModuleOp>> createConvertSYCLToLLVMPass();
2427

28+
} // namespace sycl
2529
} // namespace mlir
2630

2731
#endif // MLIR_CONVERSION_SYCLTOLLVM_SYCLTOLLVMPASS_H
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
add_subdirectory(IR)
2+
add_subdirectory(Transforms)
Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,2 @@
1-
# Copyright (C) Codeplay Software Limited
2-
3-
#===--- CMakeLists.txt -----------------------------------------------------===#
4-
#
5-
# MLIR-SYCL is under the Apache License v2.0 with LLVM Exceptions.
6-
# See https://llvm.org/LICENSE.txt for license information.
7-
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8-
#
9-
#===------------------------------------------------------------------------===#
10-
111
add_mlir_dialect(SYCLOps sycl)
122
add_mlir_doc(SYCLOps SYCLOps Dialects/ -gen-op-doc)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
set(LLVM_TARGET_DEFINITIONS Passes.td)
2+
mlir_tablegen(Passes.h.inc -gen-pass-decls -name SYCL)
3+
add_public_tablegen_target(MLIRSYCLTransformsIncGen)
4+
add_dependencies(mlir-headers MLIRSYCLTransformsIncGen)
5+
6+
add_mlir_doc(Passes SYCLPasses ./ -gen-pass-doc)
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
//===- Passes.h - SYCL Patterns and Passes ---------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This header declares patterns and passes on SYCL operations.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_DIALECT_SYCL_TRANSFORMS_PASSES_H
14+
#define MLIR_DIALECT_SYCL_TRANSFORMS_PASSES_H
15+
16+
#include "mlir/Pass/Pass.h"
17+
18+
namespace mlir {
19+
namespace sycl {
20+
21+
//===----------------------------------------------------------------------===//
22+
// Passes
23+
//===----------------------------------------------------------------------===//
24+
25+
26+
//===----------------------------------------------------------------------===//
27+
// Registration
28+
//===----------------------------------------------------------------------===//
29+
30+
#define GEN_PASS_REGISTRATION
31+
#include "mlir/Dialect/SYCL/Transforms/Passes.h.inc"
32+
33+
} // namespace sycl
34+
} // namespace mlir
35+
36+
#endif // MLIR_DIALECT_SYCL_TRANSFORMS_PASSES_H
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
//===-- Passes.td - SYCL pass definition file --------------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_SYCL_TRANSFORMS_PASSES
10+
#define MLIR_DIALECT_SYCL_TRANSFORMS_PASSES
11+
12+
include "mlir/Pass/PassBase.td"
13+
14+
#endif // MLIR_DIALECT_SYCL_TRANSFORMS_PASSES

mlir-sycl/lib/Conversion/SYCLToLLVM/DialectBuilder.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,9 @@ LLVM::ExtractValueOp LLVMBuilder::genExtractValue(Type type, Value container,
115115
getI64ArrayAttr(position));
116116
}
117117

118-
LLVM::CallOp LLVMBuilder::genCall(FlatSymbolRefAttr funcSym, ArrayRef<Type> resTypes,
119-
ArrayRef<Value> operands) const {
118+
LLVM::CallOp LLVMBuilder::genCall(FlatSymbolRefAttr funcSym, TypeRange resTypes,
119+
ValueRange operands) const {
120+
assert(funcSym && "Expecting a valid function symbol");
120121
return create<LLVM::CallOp>(resTypes, funcSym, operands);
121122
}
122123

mlir-sycl/lib/Conversion/SYCLToLLVM/SYCLFuncRegistry.cpp

Lines changed: 38 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "mlir/Conversion/SYCLToLLVM/SYCLFuncRegistry.h"
14+
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
15+
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
1416
#include "mlir/Conversion/SYCLToLLVM/DialectBuilder.h"
17+
#include "mlir/Conversion/SYCLToLLVM/SYCLToLLVM.h"
18+
#include "mlir/Dialect/SYCL/IR/SYCLOpsTypes.h"
1519
#include "llvm/Support/Debug.h"
1620

1721
#define DEBUG_TYPE "sycl-func-registry"
@@ -25,23 +29,27 @@ using namespace mlir::sycl;
2529

2630
void SYCLFuncDescriptor::declareFunction(ModuleOp &module, OpBuilder &b) {
2731
LLVMBuilder builder(b, module.getLoc());
28-
builder.getOrInsertFuncDecl(name, outputTy, argTys, module);
32+
funcRef = builder.getOrInsertFuncDecl(name, outputTy, argTys, module);
2933
}
3034

31-
Value SYCLFuncDescriptor::call(FuncId id, ArrayRef<Value> args,
35+
Value SYCLFuncDescriptor::call(FuncId id, ValueRange args,
3236
const SYCLFuncRegistry &registry, OpBuilder &b,
3337
Location loc) {
34-
SmallVector<Type, 1> funcOutputTys;
3538
const SYCLFuncDescriptor &funcDesc = registry.getFuncDesc(id);
39+
LLVM_DEBUG(
40+
llvm::dbgs() << "Creating SYCLFuncDescriptor::call to funcDesc.funcRef: "
41+
<< funcDesc.funcRef << "\n");
42+
43+
SmallVector<Type, 4> funcOutputTys;
3644
if (!funcDesc.outputTy.isa<LLVM::LLVMVoidType>())
3745
funcOutputTys.emplace_back(funcDesc.outputTy);
3846

3947
LLVMBuilder builder(b, loc);
4048
LLVM::CallOp callOp = builder.genCall(funcDesc.funcRef, funcOutputTys, args);
41-
// TODO: we could check here the arguments against the function signature and
49+
50+
// TODO: we could check here the arguments against the function signature and
4251
// assert if there is a mismatch.
43-
assert(callOp.getNumResults() == 1 && "expecting a single result");
44-
52+
assert(callOp.getNumResults() <= 1 && "expecting a single result");
4553
return callOp.getResult(0);
4654
}
4755

@@ -62,73 +70,80 @@ const SYCLFuncRegistry SYCLFuncRegistry::create(
6270
SYCLFuncRegistry::SYCLFuncRegistry(ModuleOp &module, OpBuilder &builder)
6371
: registry() {
6472
MLIRContext *context = module.getContext();
73+
LowerToLLVMOptions options(context);
74+
LLVMTypeConverter converter(context, options);
75+
populateSYCLToLLVMTypeConversion(converter);
76+
77+
Type id1PtrTy =
78+
converter.convertType(MemRefType::get(-1, IDType::get(context, 1)));
79+
Type id2PtrTy =
80+
converter.convertType(MemRefType::get(-1, IDType::get(context, 2)));
81+
Type id3PtrTy =
82+
converter.convertType(MemRefType::get(-1, IDType::get(context, 3)));
6583
auto voidTy = LLVM::LLVMVoidType::get(context);
66-
auto i8Ty = IntegerType::get(context, 8);
67-
auto i8PtrTy = LLVM::LLVMPointerType::get(i8Ty);
68-
auto i8PtrPtrTy = LLVM::LLVMPointerType::get(i8PtrTy);
6984
auto i64Ty = IntegerType::get(context, 64);
70-
auto i64PtrTy = LLVM::LLVMPointerType::get(i64Ty);
7185

72-
// Construct the SYCL functions descriptors (enum, function name, signature).
86+
// Construct the SYCL functions descriptors (enum,
87+
// function name, signature).
7388
// clang-format off
7489
std::vector<SYCLFuncDescriptor> descriptors = {
7590
// cl::sycl::id<1>::id()
7691
SYCLFuncDescriptor(SYCLFuncDescriptor::FuncId::Id1CtorDefault,
77-
"_ZN2cl4sycl2idILi1EEC2Ev", voidTy, {i8PtrTy}),
92+
"_ZN2cl4sycl2idILi1EEC2Ev", voidTy, {id1PtrTy}),
7893
// cl::sycl::id<2>::id()
7994
SYCLFuncDescriptor(SYCLFuncDescriptor::FuncId::Id2CtorDefault,
80-
"_ZN2cl4sycl2idILi2EEC2Ev", voidTy, {i8PtrTy}),
95+
"_ZN2cl4sycl2idILi2EEC2Ev", voidTy, {id2PtrTy}),
8196
// cl::sycl::id<3>::id()
8297
SYCLFuncDescriptor(SYCLFuncDescriptor::FuncId::Id3CtorDefault,
83-
"_ZN2cl4sycl2idILi3EEC2Ev", voidTy, {i8PtrTy}),
98+
"_ZN2cl4sycl2idILi3EEC2Ev", voidTy, {id3PtrTy}),
8499

85100
// cl::sycl::id<1>::id<1>(std::enable_if<(1)==(1), unsigned long>::type)
86101
SYCLFuncDescriptor(
87102
SYCLFuncDescriptor::FuncId::Id1CtorSizeT,
88103
"_ZN2cl4sycl2idILi1EEC2ILi1EEENSt9enable_ifIXeqT_Li1EEmE4typeE",
89-
voidTy, {i8PtrTy, i64Ty}),
104+
voidTy, {id1PtrTy, i64Ty}),
90105
// cl::sycl::id<2>::id<2>(std::enable_if<(2)==(2), unsigned long>::type)
91106
SYCLFuncDescriptor(
92107
SYCLFuncDescriptor::FuncId::Id2CtorSizeT,
93108
"_ZN2cl4sycl2idILi2EEC2ILi2EEENSt9enable_ifIXeqT_Li2EEmE4typeE",
94-
voidTy, {i8PtrTy, i64Ty}),
109+
voidTy, {id2PtrTy, i64Ty}),
95110
// cl::sycl::id<3>::id<3>(std::enable_if<(3)==(3), unsigned long>::type)
96111
SYCLFuncDescriptor(
97112
SYCLFuncDescriptor::FuncId::Id3CtorSizeT,
98113
"_ZN2cl4sycl2idILi3EEC2ILi3EEENSt9enable_ifIXeqT_Li3EEmE4typeE",
99-
voidTy, {i8PtrTy, i64Ty}),
114+
voidTy, {id3PtrTy, i64Ty}),
100115

101116
// cl::sycl::id<1>::id<1>(std::enable_if<(1)==(1), unsigned long>::type, unsigned long)
102117
SYCLFuncDescriptor(
103118
SYCLFuncDescriptor::FuncId::Id1CtorRange,
104119
"_ZN2cl4sycl2idILi1EEC2ILi1EEENSt9enable_ifIXeqT_Li1EEmE4typeEm",
105-
voidTy, {i8PtrTy, i64Ty, i64Ty}),
120+
voidTy, {id1PtrTy, i64Ty, i64Ty}),
106121
// cl::sycl::id<2>::id<2>(std::enable_if<(2)==(2), unsigned long>::type, unsigned long)
107122
SYCLFuncDescriptor(
108123
SYCLFuncDescriptor::FuncId::Id2CtorRange,
109124
"_ZN2cl4sycl2idILi2EEC2ILi2EEENSt9enable_ifIXeqT_Li2EEmE4typeEm",
110-
voidTy, {i8PtrTy, i64Ty, i64Ty}),
125+
voidTy, {id2PtrTy, i64Ty, i64Ty}),
111126
// cl::sycl::id<3>::id<3>(std::enable_if<(3)==(3), unsigned long>::type, unsigned long)
112127
SYCLFuncDescriptor(
113128
SYCLFuncDescriptor::FuncId::Id3CtorRange,
114129
"_ZN2cl4sycl2idILi3EEC2ILi3EEENSt9enable_ifIXeqT_Li3EEmE4typeEm",
115-
voidTy, {i8PtrTy, i64Ty, i64Ty}),
130+
voidTy, {id3PtrTy, i64Ty, i64Ty}),
116131

117132
// cl::sycl::id<1>::id<1>(std::enable_if<(1)==(1), unsigned long>::type, unsigned long, unsigned long)
118133
SYCLFuncDescriptor(
119134
SYCLFuncDescriptor::FuncId::Id1CtorItem,
120135
"_ZN2cl4sycl2idILi1EEC2ILi1EEENSt9enable_ifIXeqT_Li1EEmE4typeEmm",
121-
voidTy, {i8PtrTy, i64Ty, i64Ty, i64Ty}),
136+
voidTy, {id1PtrTy, i64Ty, i64Ty, i64Ty}),
122137
// cl::sycl::id<2>::id<2>(std::enable_if<(2)==(2), unsigned long>::type, unsigned long, unsigned long)
123138
SYCLFuncDescriptor(
124139
SYCLFuncDescriptor::FuncId::Id2CtorItem,
125140
"_ZN2cl4sycl2idILi2EEC2ILi2EEENSt9enable_ifIXeqT_Li2EEmE4typeEmm",
126-
voidTy, {i8PtrTy, i64Ty, i64Ty, i64Ty}),
141+
voidTy, {id2PtrTy, i64Ty, i64Ty, i64Ty}),
127142
// cl::sycl::id<3>::id<3>(std::enable_if<(3)==(3), unsigned long>::type, unsigned long, unsigned long)
128143
SYCLFuncDescriptor(
129144
SYCLFuncDescriptor::FuncId::Id3CtorItem,
130145
"_ZN2cl4sycl2idILi3EEC2ILi3EEENSt9enable_ifIXeqT_Li3EEmE4typeEmm",
131-
voidTy, {i8PtrTy, i64Ty, i64Ty, i64Ty}),
146+
voidTy, {id3PtrTy, i64Ty, i64Ty, i64Ty}),
132147
};
133148
// clang-format on
134149

0 commit comments

Comments
 (0)