Skip to content

Commit 663e9ce

Browse files
akroviakovArtem Kroviakov
andauthored
[Func][GPU] Use SymbolUserOpInterface in func::ConstantOp (#107748)
This PR enables `func::ConstantOp` creation and usage for device functions inside GPU modules. The current main returns error for referencing device functions via `func::ConstantOp`, because during the `ConstantOp` verification it only checks symbols in `ModuleOp` symbol table, which, of course, does not contain device functions that are defined in `GPUModuleOp`. This PR proposes a more general solution. Co-authored-by: Artem Kroviakov <[email protected]>
1 parent aa21ce4 commit 663e9ce

File tree

3 files changed

+25
-3
lines changed

3 files changed

+25
-3
lines changed

mlir/include/mlir/Dialect/Func/IR/FuncOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ def CallIndirectOp : Func_Op<"call_indirect", [
183183

184184
def ConstantOp : Func_Op<"constant",
185185
[ConstantLike, Pure,
186+
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
186187
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
187188
let summary = "constant";
188189
let description = [{
@@ -216,7 +217,6 @@ def ConstantOp : Func_Op<"constant",
216217
}];
217218

218219
let hasFolder = 1;
219-
let hasVerifier = 1;
220220
}
221221

222222
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Func/IR/FuncOps.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,12 +123,13 @@ LogicalResult CallIndirectOp::canonicalize(CallIndirectOp indirectCall,
123123
// ConstantOp
124124
//===----------------------------------------------------------------------===//
125125

126-
LogicalResult ConstantOp::verify() {
126+
LogicalResult ConstantOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
127127
StringRef fnName = getValue();
128128
Type type = getType();
129129

130130
// Try to find the referenced function.
131-
auto fn = (*this)->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnName);
131+
auto fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(
132+
this->getOperation(), StringAttr::get(getContext(), fnName));
132133
if (!fn)
133134
return emitOpError() << "reference to undefined function '" << fnName
134135
<< "'";
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// RUN: mlir-opt -test-gpu-rewrite -convert-func-to-llvm %s | FileCheck %s
2+
3+
gpu.module @kernels {
4+
// CHECK-LABEL: @hello
5+
// CHECK-SAME: %[[ARG0:.*]]: f32
6+
func.func @hello(%arg0 : f32) {
7+
%tid_x = gpu.thread_id x
8+
%csti8 = arith.constant 2 : i8
9+
gpu.printf "Hello from %lld, %d, %f\n" %tid_x, %csti8, %arg0 : index, i8, f32
10+
return
11+
}
12+
// CHECK-LABEL: @hello_indirect
13+
gpu.func @hello_indirect() kernel {
14+
%cstf32 = arith.constant 3.0 : f32
15+
// CHECK: %[[DEVICE_FUNC_ADDR:.*]] = llvm.mlir.addressof @hello : !llvm.ptr
16+
%func_ref = func.constant @hello : (f32) -> ()
17+
// CHECK: llvm.call %[[DEVICE_FUNC_ADDR]](%{{.*}}) : !llvm.ptr, (f32) -> ()
18+
func.call_indirect %func_ref(%cstf32) : (f32) -> ()
19+
gpu.return
20+
}
21+
}

0 commit comments

Comments
 (0)