Skip to content

Commit c12cb0c

Browse files
authored
[mlir][python] fix value-builder generation for snake_case ops (#135302)
Ops that are already snake case (like [`ROCDL_wmma_*` ops](https://github.com/makslevental/llvm-project/blob/66b0b0466bbd995146aadaf2cd18de5476c19941/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td#L411)) produce python "value-builders" that collide with the class names: ```python class wmma_bf16_16x16x16_bf16(_ods_ir.OpView): OPERATION_NAME = "rocdl.wmma.bf16.16x16x16.bf16" ... def wmma_bf16_16x16x16_bf16(res, args, *, loc=None, ip=None) -> _ods_ir.Value: return wmma_bf16_16x16x16_bf16(res=res, args=args, loc=loc, ip=ip).result ``` and thus cannot be emitted (because of recursive self-calls). This PR fixes that by affixing `_` to the value builder names. I would've preferred to just rename the ops but that would be a breaking change 🤷.
1 parent dda53be commit c12cb0c

File tree

3 files changed

+30
-3
lines changed

3 files changed

+30
-3
lines changed

mlir/test/mlir-tblgen/op-python-bindings.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -654,3 +654,9 @@ def WithSuccessorsOp : TestOp<"with_successors"> {
654654

655655
// CHECK: def with_successors(successor, successors, *, loc=None, ip=None)
656656
// CHECK: return WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip)
657+
658+
// CHECK: class snake_case(_ods_ir.OpView):
659+
// CHECK-LABEL: OPERATION_NAME = "test.snake_case"
660+
def already_snake_case : TestOp<"snake_case"> {}
661+
// CHECK: def snake_case_(*, loc=None, ip=None)
662+
// CHECK: return snake_case(loc=loc, ip=ip)

mlir/test/python/dialects/rocdl.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
# RUN: %PYTHON %s | FileCheck %s
22
# This is just a smoke test that the dialect is functional.
3+
from array import array
34

45
from mlir.ir import *
5-
from mlir.dialects import rocdl
6+
from mlir.dialects import rocdl, arith
7+
from mlir.extras import types as T
68

79

810
def constructAndPrintInModule(f):
@@ -18,5 +20,22 @@ def constructAndPrintInModule(f):
1820
# CHECK-LABEL: testSmoke
1921
@constructAndPrintInModule
2022
def testSmoke():
21-
# CHECK: rocdl.barrier
22-
rocdl.BarrierOp()
23+
v_len = 16
24+
f32 = F32Type.get()
25+
# Note: this isn't actually the right type for the intrinsic (should be f16)
26+
# but array doesn't support f16.
27+
v16f32 = T.vector(v_len, f32)
28+
f32_array = array("f", [0.0] * v_len)
29+
a_frag = arith.constant(v16f32, f32_array)
30+
b_frag = arith.constant(v16f32, f32_array)
31+
c_frag = arith.constant(v16f32, f32_array)
32+
false = arith.constant(T.bool(), False)
33+
34+
c_frag = rocdl.wmma_f16_16x16x16_f16(v16f32, [a_frag, b_frag, c_frag, false])
35+
# CHECK: %{{.*}} = rocdl.wmma.f16.16x16x16.f16
36+
print(c_frag)
37+
assert isinstance(c_frag, OpView)
38+
# CHECK: Value(%{{.*}} = rocdl.wmma.f16.16x16x16.f16
39+
c_frag = rocdl.wmma_f16_16x16x16_f16_(v16f32, [a_frag, b_frag, c_frag, false])
40+
print(c_frag)
41+
assert isinstance(c_frag, Value)

mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1000,6 +1000,8 @@ static void emitValueBuilder(const Operator &op,
10001000
});
10011001
std::string nameWithoutDialect = sanitizeName(
10021002
op.getOperationName().substr(op.getOperationName().find('.') + 1));
1003+
if (nameWithoutDialect == op.getCppClassName())
1004+
nameWithoutDialect += "_";
10031005
std::string params = llvm::join(valueBuilderParams, ", ");
10041006
std::string args = llvm::join(opBuilderArgs, ", ");
10051007
const char *type =

0 commit comments

Comments
 (0)