Skip to content

Commit d9651f1

Browse files
Added test, fixes
1 parent c4d48f9 commit d9651f1

File tree

8 files changed

+131
-4
lines changed

8 files changed

+131
-4
lines changed

include/gc/Transforms/Passes.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,13 @@ def LinalgToXeGPU : Pass<"linalg-to-xegpu", "func::FuncOp"> {
9494
];
9595
}
9696

97+
def AddContextArg : Pass<"add-ctx-arg", "func::FuncOp"> {
  let summary = "Add a context argument.";
  let description = [{
    Add a new memref argument to the function, that could be used to pass some context.
  }];
}
103+
97104
def GpuToGpuOcl : Pass<"gpu-to-gpuocl", "ModuleOp"> {
98105
let summary = "Convert the GPU operations to GpuOclRuntime calls.";
99106
let description = [{
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
//===-- AddContextArg.cpp - Add context argument ----------------*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
#include "mlir/Conversion/Passes.h"
9+
#include "mlir/Dialect/Func/IR/FuncOps.h"
10+
11+
namespace mlir::gc {
12+
#define GEN_PASS_DECL_ADDCONTEXTARG
13+
#define GEN_PASS_DEF_ADDCONTEXTARG
14+
#include "gc/Transforms/Passes.h.inc"
15+
} // namespace mlir::gc
16+
17+
using namespace mlir;
18+
19+
namespace {
20+
struct AddContextArg final : gc::impl::AddContextArgBase<AddContextArg> {
21+
void runOnOperation() override {
22+
auto func = getOperation();
23+
auto funcType = func.getFunctionType();
24+
auto argTypes = llvm::to_vector<8>(funcType.getInputs());
25+
auto resultTypes = llvm::to_vector<1>(funcType.getResults());
26+
auto ctx = func->getContext();
27+
auto newArgType = MemRefType::get({}, IntegerType::get(ctx, 8));
28+
argTypes.emplace_back(newArgType);
29+
auto newFuncType = FunctionType::get(ctx, argTypes, resultTypes);
30+
func.setType(newFuncType);
31+
32+
if (func.getBody().hasOneBlock()) {
33+
func.getBody().front().addArgument(newArgType, func.getLoc());
34+
}
35+
36+
// Find all function calls and append the last argument of the current
37+
// function to the call.
38+
func.walk([&](func::CallOp call) {
39+
auto args = llvm::to_vector<8>(call.getOperands());
40+
args.emplace_back(func.getArgument(func.getNumArguments() - 1));
41+
call->setOperands(args);
42+
});
43+
}
44+
};
45+
} // namespace

lib/gc/Transforms/GPU/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
gc_add_mlir_library(GcGpuPasses
2+
AddContextArg.cpp
23
GpuToGpuOcl.cpp
34
LinalgToXeGPU.cpp
45
Pipeline.cpp

lib/gc/Transforms/GPU/GpuToGpuOcl.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,7 @@ struct ConvertLaunch final : ConvertOpPattern<gpu::LaunchFuncOp> {
316316
loc, getFuncName,
317317
LLVM::LLVMFunctionType::get(helper.ptrType, {helper.ptrType}),
318318
LLVM::Linkage::Internal);
319+
function.setAlwaysInline(true);
319320
rewriter.setInsertionPointToStart(function.addEntryBlock(rewriter));
320321

321322
auto ptr = mod.lookupSymbol<LLVM::GlobalOp>(str("Ptr"));

lib/gc/Transforms/GPU/Pipeline.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@
3939
namespace mlir::gc {
4040

4141
void populateGPUPipeline(mlir::OpPassManager &pm) {
42+
// Add an argument for the GPU context
43+
pm.addNestedPass<func::FuncOp>(createAddContextArg());
44+
4245
pm.addNestedPass<func::FuncOp>(createIterativeTilingAndFusion());
4346

4447
pm.addPass(bufferization::createEmptyTensorEliminationPass());
@@ -76,7 +79,8 @@ void populateGPUPipeline(mlir::OpPassManager &pm) {
7679
pm.addNestedPass<func::FuncOp>(createGpuMapParallelLoopsPass());
7780
pm.addNestedPass<func::FuncOp>(createParallelLoopToGpuPass());
7881

79-
pm.addNestedPass<func::FuncOp>(imex::createInsertGPUAllocsPass("opencl"));
82+
// Temporarily disabled until #344 is implemented
83+
// pm.addNestedPass<func::FuncOp>(imex::createInsertGPUAllocsPass("opencl"));
8084
pm.addPass(createGpuKernelOutliningPass());
8185
pm.addPass(createCanonicalizerPass());
8286
pm.addPass(imex::createSetSPIRVCapabilitiesPass());
@@ -95,15 +99,16 @@ void populateGPUPipeline(mlir::OpPassManager &pm) {
9599
pm.addNestedPass<func::FuncOp>(LLVM::createRequestCWrappersPass());
96100
pm.addPass(imex::createSerializeSPIRVPass());
97101
pm.addPass(createConvertVectorToSCFPass());
98-
pm.addPass(imex::createConvertGPUToGPUXPass());
102+
// pm.addPass(imex::createConvertGPUToGPUXPass());
99103
pm.addPass(createConvertSCFToCFPass());
100104
pm.addPass(createConvertControlFlowToLLVMPass());
101105
pm.addPass(createConvertVectorToLLVMPass());
102106
pm.addPass(createConvertIndexToLLVMPass());
103107
pm.addPass(createArithToLLVMConversionPass());
104108
pm.addPass(createConvertFuncToLLVMPass());
105109
pm.addPass(createConvertMathToLLVMPass());
106-
pm.addPass(imex::createConvertGPUXToLLVMPass());
110+
// pm.addPass(imex::createConvertGPUXToLLVMPass());
111+
pm.addPass(createGpuToGpuOcl());
107112
pm.addPass(createConvertIndexToLLVMPass());
108113
pm.addPass(memref::createExpandStridedMetadataPass());
109114
pm.addPass(createLowerAffinePass());
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# GPUX is currently disabled, so mark the tests in this directory as
# unsupported; lit will report them as skipped rather than failing.
config.unsupported = True
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
// RUN: gc-opt %s --gc-gpu-pipeline | FileCheck %s

// Runs the full GPU pipeline on a simple elementwise-add function and checks
// the generated GpuOclRuntime plumbing: the cached kernel pointer with its
// create/get helpers, the launch call in @entry, and the module destructor.

module @test {
  func.func @entry(%arg0: memref<32x32xf32>, %arg1: memref<32x32xf32>, %arg2: memref<32x32xf32>) {
    %0 = bufferization.to_tensor %arg0 restrict : memref<32x32xf32>
    %1 = bufferization.to_tensor %arg1 restrict : memref<32x32xf32>
    %2 = tensor.empty() : tensor<32x32xf32>
    %3 = linalg.add ins(%1, %0 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%2 : tensor<32x32xf32>) -> tensor<32x32xf32>
    bufferization.materialize_in_destination %3 in restrict writable %arg2 : (tensor<32x32xf32>, memref<32x32xf32>) -> ()
    return
  }
}

// CHECK: llvm.mlir.global internal constant @gcGpuOclKernel_entry_kernel_SPIRV
// CHECK: llvm.mlir.global internal constant @gcGpuOclKernel_entry_kernel_Name
// CHECK: llvm.mlir.global internal @gcGpuOclKernel_entry_kernel_Ptr

// CHECK: llvm.func internal @createGcGpuOclKernel_entry_kernel([[CTX:%.+]]: !llvm.ptr) -> !llvm.ptr
// CHECK: [[PTR_ADDR:%.+]] = llvm.mlir.addressof @gcGpuOclKernel_entry_kernel_Ptr
// CHECK: [[ZERO:%.+]] = llvm.mlir.zero
// CHECK: [[ONE:%.+]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: [[NEW_PTR:%.+]] = llvm.call @gcGpuOclKernelCreate([[CTX]]
// CHECK: [[CMPXCHG:%.+]] = llvm.cmpxchg [[PTR_ADDR]], [[ZERO]], [[NEW_PTR]]
// CHECK: [[FLAG:%.+]] = llvm.extractvalue [[CMPXCHG]][1]
// CHECK: llvm.cond_br [[FLAG]], [[BB1:\^.+]], [[BB2:\^.+]]
// CHECK: [[BB1]]:
// CHECK: llvm.return [[NEW_PTR]]
// CHECK: [[BB2]]:
// CHECK: [[ARRAY:%.+]] = llvm.alloca [[ONE]]
// CHECK: llvm.store [[NEW_PTR]], [[ARRAY]]
// CHECK: llvm.call @gcGpuOclKernelDestroy([[ONE]], [[ARRAY]])
// CHECK: [[OLD_PTR:%.+]] = llvm.extractvalue [[CMPXCHG]][0]
// CHECK: llvm.return [[OLD_PTR]]

// CHECK: llvm.func internal @getGcGpuOclKernel_entry_kernel([[CTX:%.+]]: !llvm.ptr) -> !llvm.ptr attributes {always_inline}
// CHECK: [[ZERO:%.+]] = llvm.mlir.zero
// CHECK: [[PTR_ADDR:%.+]] = llvm.mlir.addressof @gcGpuOclKernel_entry_kernel_Ptr
// CHECK: [[PTR:%.+]] = llvm.load [[PTR_ADDR]]
// CHECK: [[ICMP:%.+]] = llvm.icmp "eq" [[PTR]], [[ZERO]]
// CHECK: llvm.cond_br [[ICMP]], [[BB1:\^.+]], [[BB2:\^.+]]
// CHECK: [[BB1]]:
// CHECK: [[NEW_PTR:%.+]] = llvm.call @createGcGpuOclKernel_entry_kernel([[CTX]])
// CHECK: llvm.return [[NEW_PTR]]
// CHECK: [[BB2]]:
// CHECK: llvm.return [[PTR]]

// CHECK: llvm.func @entry
// CHECK: [[KERNEL:%.+]] = llvm.call @getGcGpuOclKernel_entry_kernel([[CTX:%.+]]) : (!llvm.ptr) -> !llvm.ptr
// CHECK: llvm.call @gcGpuOclKernelLaunch([[CTX]], [[KERNEL]],

// CHECK: llvm.func @gcGpuOclKernelCreate
// CHECK: llvm.func @gcGpuOclKernelDestroy
// CHECK: llvm.func @gcGpuOclKernelLaunch

// CHECK: llvm.func @gcGpuOclModuleDestructor()
// CHECK: [[ONE:%.+]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: [[PTR_ADDR:%.+]] = llvm.mlir.addressof @gcGpuOclKernel_entry_kernel_Ptr
// CHECK: llvm.fence acquire
// CHECK: [[PTR:%.+]] = llvm.load [[PTR_ADDR]]
// CHECK: [[ARRAY:%.+]] = llvm.alloca [[ONE]]
// CHECK: llvm.store [[PTR]], [[ARRAY]]
// CHECK: llvm.call @gcGpuOclKernelDestroy([[ONE]], [[ARRAY]])
Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,5 @@
11
# Tests here require IMEX; without it they cannot run at all.
if config.gc_use_imex:
    # FIXME: Enable when the GPU runner is implemented.
    config.excludes = ['mlp.mlir']
else:
    config.unsupported = True

0 commit comments

Comments
 (0)