5
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
6
//
7
7
// ===----------------------------------------------------------------------===//
8
- #include < vector >
8
+ #include < unordered_set >
9
9
10
- #define GC_GPU_OCL_DEF_ONLY
10
+ #define GC_GPU_OCL_CONST_ONLY
11
11
#include " gc/ExecutionEngine/GPURuntime/GpuOclRuntime.h"
12
12
13
13
#include " mlir/Conversion/LLVMCommon/ConversionTarget.h"
17
17
#include " mlir/Dialect/GPU/Transforms/Passes.h"
18
18
19
19
using namespace mlir ;
20
+ using namespace mlir ::gc::gpu;
20
21
21
- namespace mlir {
22
- namespace gc {
22
+ namespace mlir ::gc {
23
23
#define GEN_PASS_DECL_GPUTOGPUOCL
24
24
#define GEN_PASS_DEF_GPUTOGPUOCL
25
25
#include " gc/Transforms/Passes.h.inc"
26
- } // namespace gc
27
- } // namespace mlir
26
+ } // namespace mlir::gc
28
27
29
28
namespace {
30
-
31
29
LLVM::CallOp funcCall (OpBuilder &builder, const StringRef name,
32
30
const Type returnType, const ArrayRef<Type> argTypes,
33
31
const Location loc, const ArrayRef<Value> arguments,
@@ -42,8 +40,10 @@ LLVM::CallOp funcCall(OpBuilder &builder, const StringRef name,
42
40
return builder.create <LLVM::CallOp>(loc, function, arguments);
43
41
}
44
42
45
- // Assuming that the pointer to GcGpuOclContext is passed as the last
46
- // memref<anyType> with zero dims argument of the current function.
43
+ // Assuming that the pointer to the context is passed as the last argument
44
+ // of the current function of type memref<anyType> with zero dims. When lowering
45
+ // to LLVM, the memref arg is replaced with 3 args of types ptr, ptr, i64.
46
+ // Returning the first one.
47
47
Value getCtxPtr (const OpBuilder &rewriter) {
48
48
auto func =
49
49
rewriter.getBlock ()->getParent ()->getParentOfType <LLVM::LLVMFuncOp>();
@@ -55,7 +55,7 @@ struct Helper final {
55
55
Type voidType;
56
56
Type ptrType;
57
57
Type idxType;
58
- mutable std::set<SmallString< 32 > > kernelNames;
58
+ mutable std::unordered_set<std::string > kernelNames;
59
59
60
60
explicit Helper (MLIRContext *ctx, LLVMTypeConverter &converter)
61
61
: converter(converter), voidType(LLVM::LLVMVoidType::get(ctx)),
@@ -81,7 +81,7 @@ struct Helper final {
81
81
rewriter.create <LLVM::StoreOp>(loc, kernelPtrs[i], elementPtr);
82
82
}
83
83
84
- funcCall (rewriter, GC_GPU_OCL_KERNEL_DESTROY , voidType, {idxType, ptrType},
84
+ funcCall (rewriter, GPU_OCL_KERNEL_DESTROY , voidType, {idxType, ptrType},
85
85
loc, {size, kernelPtrsArray});
86
86
}
87
87
};
@@ -117,7 +117,7 @@ struct ConvertAlloc final : ConvertOpPattern<gpu::AllocOp> {
117
117
}
118
118
}
119
119
auto size = helper.idxConstant (rewriter, loc, staticSize);
120
- auto ptr = funcCall (rewriter, GC_GPU_OCL_MALLOC , helper.ptrType ,
120
+ auto ptr = funcCall (rewriter, GPU_OCL_MALLOC , helper.ptrType ,
121
121
{helper.ptrType , helper.idxType }, loc,
122
122
{getCtxPtr (rewriter), size})
123
123
.getResult ();
@@ -158,7 +158,7 @@ struct ConvertAlloc final : ConvertOpPattern<gpu::AllocOp> {
158
158
}
159
159
160
160
size = idxMul (size, helper.idxConstant (rewriter, loc, staticSize));
161
- auto ptr = funcCall (rewriter, GC_GPU_OCL_MALLOC , helper.ptrType ,
161
+ auto ptr = funcCall (rewriter, GPU_OCL_MALLOC , helper.ptrType ,
162
162
{helper.ptrType , helper.idxType }, loc,
163
163
{getCtxPtr (rewriter), size})
164
164
.getResult ();
@@ -194,7 +194,7 @@ struct ConvertDealloc final : ConvertOpPattern<gpu::DeallocOp> {
194
194
auto loc = gpuDealloc.getLoc ();
195
195
MemRefDescriptor dsc (adaptor.getMemref ());
196
196
auto ptr = dsc.allocatedPtr (rewriter, loc);
197
- auto oclDealloc = funcCall (rewriter, GC_GPU_OCL_DEALLOC , helper.voidType ,
197
+ auto oclDealloc = funcCall (rewriter, GPU_OCL_DEALLOC , helper.voidType ,
198
198
{helper.ptrType , helper.ptrType }, loc,
199
199
{getCtxPtr (rewriter), ptr});
200
200
rewriter.replaceOp (gpuDealloc, oclDealloc);
@@ -227,7 +227,7 @@ struct ConvertMemcpy final : ConvertOpPattern<gpu::MemcpyOp> {
227
227
auto dstPtr = dstDsc.alignedPtr (rewriter, loc);
228
228
auto size = helper.idxConstant (rewriter, loc, elementSize * numElements);
229
229
auto oclMemcpy = funcCall (
230
- rewriter, GC_GPU_OCL_MEMCPY , helper.voidType ,
230
+ rewriter, GPU_OCL_MEMCPY , helper.voidType ,
231
231
{helper.ptrType , helper.ptrType , helper.ptrType , helper.idxType }, loc,
232
232
{getCtxPtr (rewriter), srcPtr, dstPtr, size});
233
233
rewriter.replaceOp (gpuMemcpy, oclMemcpy);
@@ -249,7 +249,7 @@ struct ConvertLaunch final : ConvertOpPattern<gpu::LaunchFuncOp> {
249
249
250
250
const Location loc = gpuLaunch.getLoc ();
251
251
auto kernelArgs = adaptor.getKernelOperands ();
252
- std::vector <Value> args;
252
+ SmallVector <Value> args;
253
253
args.reserve (kernelArgs.size () + 2 );
254
254
args.emplace_back (getCtxPtr (rewriter));
255
255
args.emplace_back (kernelPtr.value ());
@@ -265,7 +265,7 @@ struct ConvertLaunch final : ConvertOpPattern<gpu::LaunchFuncOp> {
265
265
}
266
266
267
267
const auto gpuOclLaunch =
268
- funcCall (rewriter, GC_GPU_OCL_KERNEL_LAUNCH , helper.voidType ,
268
+ funcCall (rewriter, GPU_OCL_KERNEL_LAUNCH , helper.voidType ,
269
269
{helper.ptrType , helper.ptrType }, loc, args, true );
270
270
rewriter.replaceOp (gpuLaunch, gpuOclLaunch);
271
271
return success ();
@@ -284,7 +284,9 @@ struct ConvertLaunch final : ConvertOpPattern<gpu::LaunchFuncOp> {
284
284
SmallString<128 > getFuncName (" getGcGpuOclKernel_" );
285
285
getFuncName.append (kernelModName);
286
286
287
- if (helper.kernelNames .insert (SmallString<32 >(kernelModName)).second ) {
287
+ if (helper.kernelNames
288
+ .insert (std::string (kernelModName.begin (), kernelModName.end ()))
289
+ .second ) {
288
290
auto insPoint = rewriter.saveInsertionPoint ();
289
291
SmallString<128 > strBuf (" gcGpuOclKernel_" );
290
292
strBuf.append (kernelModName);
@@ -391,10 +393,10 @@ struct ConvertLaunch final : ConvertOpPattern<gpu::LaunchFuncOp> {
391
393
auto spirv = LLVM::createGlobalString (loc, rewriter, str (" SPIRV" ),
392
394
binaryAttr.getValue (),
393
395
LLVM::Linkage::Internal);
394
- auto spirvSize = rewriter.create <mlir:: LLVM::ConstantOp>(
396
+ auto spirvSize = rewriter.create <LLVM::ConstantOp>(
395
397
loc, helper.idxType ,
396
- mlir:: IntegerAttr::get (helper.idxType ,
397
- static_cast <int64_t >(binaryAttr.size ())));
398
+ IntegerAttr::get (helper.idxType ,
399
+ static_cast <int64_t >(binaryAttr.size ())));
398
400
399
401
SmallVector<int32_t > globalSize;
400
402
SmallVector<int32_t > localSize;
@@ -436,7 +438,7 @@ struct ConvertLaunch final : ConvertOpPattern<gpu::LaunchFuncOp> {
436
438
auto argNum =
437
439
helper.idxConstant (rewriter, loc, adaptor.getKernelOperands ().size ());
438
440
auto createKernelCall = funcCall (
439
- rewriter, GC_GPU_OCL_KERNEL_CREATE , helper.ptrType ,
441
+ rewriter, GPU_OCL_KERNEL_CREATE , helper.ptrType ,
440
442
{helper.ptrType , helper.idxType , helper.ptrType , helper.ptrType ,
441
443
helper.ptrType , helper.ptrType , helper.idxType , helper.ptrType },
442
444
loc,
@@ -501,7 +503,7 @@ struct GpuToGpuOcl final : gc::impl::GpuToGpuOclBase<GpuToGpuOcl> {
501
503
assert (mod);
502
504
OpBuilder rewriter (mod.getBody (), mod.getBody ()->end ());
503
505
auto destruct = rewriter.create <LLVM::LLVMFuncOp>(
504
- mod.getLoc (), GC_GPU_OCL_MOD_DESTRUCTOR ,
506
+ mod.getLoc (), GPU_OCL_MOD_DESTRUCTOR ,
505
507
LLVM::LLVMFunctionType::get (helper.voidType , {}),
506
508
LLVM::Linkage::External);
507
509
auto loc = destruct.getLoc ();
0 commit comments