Skip to content

Commit 2137738

Browse files
committed
[mlir][python] fix value binders
1 parent 255f826 commit 2137738

File tree

4 files changed

+71
-27
lines changed

4 files changed

+71
-27
lines changed

mlir/python/mlir/dialects/_ods_common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ class LocalOpView(mixin_cls, parent_opview_cls):
7171
) from e
7272
LocalOpView.__name__ = parent_opview_cls.__name__
7373
LocalOpView.__qualname__ = parent_opview_cls.__qualname__
74+
LocalOpView.__has_mixin__ = True
7475
return LocalOpView
7576

7677
return class_decorator

mlir/test/mlir-tblgen/op-python-bindings.td

Lines changed: 46 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,8 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
6161
}
6262

6363
// CHECK: def attr_sized_operands(variadic1, non_variadic, *, variadic2=None, loc=None, ip=None)
64-
// CHECK: return _get_op_result_or_op_results(AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
64+
// CHECK: op = AttrSizedOperandsOp.__base__ if getattr(AttrSizedOperandsOp, "__has_mixin__", False) else AttrSizedOperandsOp
65+
// CHECK: return op(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)
6566

6667
// CHECK: @_ods_cext.register_operation(_Dialect)
6768
// CHECK: class AttrSizedResultsOp(_ods_ir.OpView):
@@ -108,8 +109,8 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
108109
}
109110

110111
// CHECK: def attr_sized_results(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
111-
// CHECK: return _get_op_result_or_op_results(AttrSizedResultsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
112-
112+
// CHECK: op = AttrSizedResultsOp.__base__ if getattr(AttrSizedResultsOp, "__has_mixin__", False) else AttrSizedResultsOp
113+
// CHECK: return _get_op_result_or_op_results(op(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
113114

114115
// CHECK: @_ods_cext.register_operation(_Dialect)
115116
// CHECK: class AttributedOp(_ods_ir.OpView):
@@ -158,7 +159,8 @@ def AttributedOp : TestOp<"attributed_op"> {
158159
}
159160

160161
// CHECK: def attributed_op(i32attr, in_, *, optional_f32_attr=None, unit_attr=None, loc=None, ip=None)
161-
// CHECK: return _get_op_result_or_op_results(AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip))
162+
// CHECK: op = AttributedOp.__base__ if getattr(AttributedOp, "__has_mixin__", False) else AttributedOp
163+
// CHECK: return op(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip)
162164

163165
// CHECK: @_ods_cext.register_operation(_Dialect)
164166
// CHECK: class AttributedOpWithOperands(_ods_ir.OpView):
@@ -194,7 +196,8 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
194196
}
195197

196198
// CHECK: def attributed_op_with_operands(_gen_arg_0, _gen_arg_2, *, in_=None, is_=None, loc=None, ip=None)
197-
// CHECK: return _get_op_result_or_op_results(AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip))
199+
// CHECK: op = AttributedOpWithOperands.__base__ if getattr(AttributedOpWithOperands, "__has_mixin__", False) else AttributedOpWithOperands
200+
// CHECK: return op(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip)
198201

199202
// CHECK: @_ods_cext.register_operation(_Dialect)
200203
// CHECK: class DefaultValuedAttrsOp(_ods_ir.OpView):
@@ -218,7 +221,8 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
218221
}
219222

220223
// CHECK: def default_valued_attrs(*, arr=None, unsupported=None, loc=None, ip=None)
221-
// CHECK: return _get_op_result_or_op_results(DefaultValuedAttrsOp(arr=arr, unsupported=unsupported, loc=loc, ip=ip))
224+
// CHECK: op = DefaultValuedAttrsOp.__base__ if getattr(DefaultValuedAttrsOp, "__has_mixin__", False) else DefaultValuedAttrsOp
225+
// CHECK: return op(arr=arr, unsupported=unsupported, loc=loc, ip=ip)
222226

223227
// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op"
224228
def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> {
@@ -236,7 +240,8 @@ def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResu
236240
}
237241

238242
// CHECK: def derive_result_types_op(type_, *, loc=None, ip=None)
239-
// CHECK: return _get_op_result_or_op_results(DeriveResultTypesOp(type_=type_, loc=loc, ip=ip))
243+
// CHECK: op = DeriveResultTypesOp.__base__ if getattr(DeriveResultTypesOp, "__has_mixin__", False) else DeriveResultTypesOp
244+
// CHECK: return _get_op_result_or_op_results(op(type_=type_, loc=loc, ip=ip))
240245

241246
// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op"
242247
def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> {
@@ -246,7 +251,8 @@ def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [Fir
246251
}
247252

248253
// CHECK: def derive_result_types_variadic_op(res, _gen_res_1, type_, *, loc=None, ip=None)
249-
// CHECK: return _get_op_result_or_op_results(DeriveResultTypesVariadicOp(res=res, _gen_res_1=_gen_res_1, type_=type_, loc=loc, ip=ip))
254+
// CHECK: op = DeriveResultTypesVariadicOp.__base__ if getattr(DeriveResultTypesVariadicOp, "__has_mixin__", False) else DeriveResultTypesVariadicOp
255+
// CHECK: return _get_op_result_or_op_results(op(res=res, _gen_res_1=_gen_res_1, type_=type_, loc=loc, ip=ip))
250256

251257
// CHECK: @_ods_cext.register_operation(_Dialect)
252258
// CHECK: class EmptyOp(_ods_ir.OpView):
@@ -263,7 +269,8 @@ def EmptyOp : TestOp<"empty">;
263269
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip))
264270

265271
// CHECK: def empty(*, loc=None, ip=None)
266-
// CHECK: return _get_op_result_or_op_results(EmptyOp(loc=loc, ip=ip))
272+
// CHECK: op = EmptyOp.__base__ if getattr(EmptyOp, "__has_mixin__", False) else EmptyOp
273+
// CHECK: return op(loc=loc, ip=ip)
267274

268275
// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op"
269276
def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
@@ -276,7 +283,8 @@ def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
276283
}
277284

278285
// CHECK: def infer_result_types_implied_op(*, loc=None, ip=None)
279-
// CHECK: return _get_op_result_or_op_results(InferResultTypesImpliedOp(loc=loc, ip=ip))
286+
// CHECK: op = InferResultTypesImpliedOp.__base__ if getattr(InferResultTypesImpliedOp, "__has_mixin__", False) else InferResultTypesImpliedOp
287+
// CHECK: return _get_op_result_or_op_results(op(loc=loc, ip=ip))
280288

281289
// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_op"
282290
def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> {
@@ -289,7 +297,8 @@ def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]>
289297
}
290298

291299
// CHECK: def infer_result_types_op(*, loc=None, ip=None)
292-
// CHECK: return _get_op_result_or_op_results(InferResultTypesOp(loc=loc, ip=ip))
300+
// CHECK: op = InferResultTypesOp.__base__ if getattr(InferResultTypesOp, "__has_mixin__", False) else InferResultTypesOp
301+
// CHECK: return _get_op_result_or_op_results(op(loc=loc, ip=ip))
293302

294303
// CHECK: @_ods_cext.register_operation(_Dialect)
295304
// CHECK: class MissingNamesOp(_ods_ir.OpView):
@@ -327,7 +336,8 @@ def MissingNamesOp : TestOp<"missing_names"> {
327336
}
328337

329338
// CHECK: def missing_names(i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None)
330-
// CHECK: return _get_op_result_or_op_results(MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip))
339+
// CHECK: op = MissingNamesOp.__base__ if getattr(MissingNamesOp, "__has_mixin__", False) else MissingNamesOp
340+
// CHECK: return _get_op_result_or_op_results(op(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip))
331341

332342
// CHECK: @_ods_cext.register_operation(_Dialect)
333343
// CHECK: class OneOptionalOperandOp(_ods_ir.OpView):
@@ -358,7 +368,8 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
358368
}
359369

360370
// CHECK: def one_optional_operand(non_optional, *, optional=None, loc=None, ip=None)
361-
// CHECK: return _get_op_result_or_op_results(OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip))
371+
// CHECK: op = OneOptionalOperandOp.__base__ if getattr(OneOptionalOperandOp, "__has_mixin__", False) else OneOptionalOperandOp
372+
// CHECK: return op(non_optional=non_optional, optional=optional, loc=loc, ip=ip)
362373

363374
// CHECK: @_ods_cext.register_operation(_Dialect)
364375
// CHECK: class OneVariadicOperandOp(_ods_ir.OpView):
@@ -390,7 +401,8 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
390401
}
391402

392403
// CHECK: def one_variadic_operand(non_variadic, variadic, *, loc=None, ip=None)
393-
// CHECK: return _get_op_result_or_op_results(OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip))
404+
// CHECK: op = OneVariadicOperandOp.__base__ if getattr(OneVariadicOperandOp, "__has_mixin__", False) else OneVariadicOperandOp
405+
// CHECK: return op(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip)
394406

395407
// CHECK: @_ods_cext.register_operation(_Dialect)
396408
// CHECK: class OneVariadicResultOp(_ods_ir.OpView):
@@ -423,7 +435,8 @@ def OneVariadicResultOp : TestOp<"one_variadic_result"> {
423435
}
424436

425437
// CHECK: def one_variadic_result(variadic, non_variadic, *, loc=None, ip=None)
426-
// CHECK: return _get_op_result_or_op_results(OneVariadicResultOp(variadic=variadic, non_variadic=non_variadic, loc=loc, ip=ip))
438+
// CHECK: op = OneVariadicResultOp.__base__ if getattr(OneVariadicResultOp, "__has_mixin__", False) else OneVariadicResultOp
439+
// CHECK: return _get_op_result_or_op_results(op(variadic=variadic, non_variadic=non_variadic, loc=loc, ip=ip))
427440

428441
// CHECK: @_ods_cext.register_operation(_Dialect)
429442
// CHECK: class PythonKeywordOp(_ods_ir.OpView):
@@ -447,7 +460,8 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
447460
}
448461

449462
// CHECK: def python_keyword(in_, *, loc=None, ip=None)
450-
// CHECK: return _get_op_result_or_op_results(PythonKeywordOp(in_=in_, loc=loc, ip=ip))
463+
// CHECK: op = PythonKeywordOp.__base__ if getattr(PythonKeywordOp, "__has_mixin__", False) else PythonKeywordOp
464+
// CHECK: return op(in_=in_, loc=loc, ip=ip)
451465

452466
// CHECK-LABEL: OPERATION_NAME = "test.same_results"
453467
def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
@@ -461,7 +475,8 @@ def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
461475
}
462476

463477
// CHECK: def same_results(in1, in2, *, loc=None, ip=None)
464-
// CHECK: return _get_op_result_or_op_results(SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip))
478+
// CHECK: op = SameResultsOp.__base__ if getattr(SameResultsOp, "__has_mixin__", False) else SameResultsOp
479+
// CHECK: return _get_op_result_or_op_results(op(in1=in1, in2=in2, loc=loc, ip=ip))
465480

466481
// CHECK-LABEL: OPERATION_NAME = "test.same_results_variadic"
467482
def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResultType]> {
@@ -471,7 +486,8 @@ def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResu
471486
}
472487

473488
// CHECK: def same_results_variadic(res, in1, in2, *, loc=None, ip=None)
474-
// CHECK: return _get_op_result_or_op_results(SameResultsVariadicOp(res=res, in1=in1, in2=in2, loc=loc, ip=ip))
489+
// CHECK: op = SameResultsVariadicOp.__base__ if getattr(SameResultsVariadicOp, "__has_mixin__", False) else SameResultsVariadicOp
490+
// CHECK: return _get_op_result_or_op_results(op(res=res, in1=in1, in2=in2, loc=loc, ip=ip))
475491

476492

477493
// CHECK: @_ods_cext.register_operation(_Dialect)
@@ -498,7 +514,8 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
498514
}
499515

500516
// CHECK: def same_variadic_operand(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
501-
// CHECK: return _get_op_result_or_op_results(SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
517+
// CHECK: op = SameVariadicOperandSizeOp.__base__ if getattr(SameVariadicOperandSizeOp, "__has_mixin__", False) else SameVariadicOperandSizeOp
518+
// CHECK: return op(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)
502519

503520
// CHECK: @_ods_cext.register_operation(_Dialect)
504521
// CHECK: class SameVariadicResultSizeOp(_ods_ir.OpView):
@@ -524,7 +541,8 @@ def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
524541
}
525542

526543
// CHECK: def same_variadic_result(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
527-
// CHECK: return _get_op_result_or_op_results(SameVariadicResultSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
544+
// CHECK: op = SameVariadicResultSizeOp.__base__ if getattr(SameVariadicResultSizeOp, "__has_mixin__", False) else SameVariadicResultSizeOp
545+
// CHECK: return _get_op_result_or_op_results(op(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
528546

529547
// CHECK: @_ods_cext.register_operation(_Dialect)
530548
// CHECK: class SimpleOp(_ods_ir.OpView):
@@ -564,7 +582,8 @@ def SimpleOp : TestOp<"simple"> {
564582
}
565583

566584
// CHECK: def simple(i64, f64, i32, f32, *, loc=None, ip=None)
567-
// CHECK: return _get_op_result_or_op_results(SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip))
585+
// CHECK: op = SimpleOp.__base__ if getattr(SimpleOp, "__has_mixin__", False) else SimpleOp
586+
// CHECK: return _get_op_result_or_op_results(op(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip))
568587

569588
// CHECK: class VariadicAndNormalRegionOp(_ods_ir.OpView):
570589
// CHECK-LABEL: OPERATION_NAME = "test.variadic_and_normal_region"
@@ -591,7 +610,8 @@ def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
591610
}
592611

593612
// CHECK: def variadic_and_normal_region(num_variadic, *, loc=None, ip=None)
594-
// CHECK: return _get_op_result_or_op_results(VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))
613+
// CHECK: op = VariadicAndNormalRegionOp.__base__ if getattr(VariadicAndNormalRegionOp, "__has_mixin__", False) else VariadicAndNormalRegionOp
614+
// CHECK: return op(num_variadic=num_variadic, loc=loc, ip=ip)
595615

596616
// CHECK: class VariadicRegionOp(_ods_ir.OpView):
597617
// CHECK-LABEL: OPERATION_NAME = "test.variadic_region"
@@ -614,7 +634,8 @@ def VariadicRegionOp : TestOp<"variadic_region"> {
614634
}
615635

616636
// CHECK: def variadic_region(num_variadic, *, loc=None, ip=None)
617-
// CHECK: return _get_op_result_or_op_results(VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))
637+
// CHECK: op = VariadicRegionOp.__base__ if getattr(VariadicRegionOp, "__has_mixin__", False) else VariadicRegionOp
638+
// CHECK: return op(num_variadic=num_variadic, loc=loc, ip=ip)
618639

619640
// CHECK: @_ods_cext.register_operation(_Dialect)
620641
// CHECK: class WithSuccessorsOp(_ods_ir.OpView):
@@ -629,4 +650,5 @@ def WithSuccessorsOp : TestOp<"with_successors"> {
629650
}
630651

631652
// CHECK: def with_successors(successor, successors, *, loc=None, ip=None)
632-
// CHECK: return _get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip))
653+
// CHECK: op = WithSuccessorsOp.__base__ if getattr(WithSuccessorsOp, "__has_mixin__", False) else WithSuccessorsOp
654+
// CHECK: return op(successor=successor, successors=successors, loc=loc, ip=ip)

mlir/test/python/dialects/arith_dialect.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,16 @@ def testFastMathFlags():
3333
)
3434
# CHECK: %0 = arith.addf %cst, %cst fastmath<nnan,ninf> : f32
3535
print(r)
36+
37+
38+
# CHECK-LABEL: TEST: testArithValueBuilder
39+
@run
40+
def testArithValueBuilder():
41+
with Context() as ctx, Location.unknown():
42+
module = Module.create()
43+
f32_t = F32Type.get()
44+
45+
with InsertionPoint(module.body):
46+
a = arith.constant(value=FloatAttr.get(f32_t, 42.42))
47+
# CHECK: %cst = arith.constant 4.242000e+01 : f32
48+
print(a)

mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ except ImportError:
3939
_ods_ext_module = None
4040
4141
import builtins
42-
from typing import Sequence as _Sequence, Union as _Union
42+
from typing import Sequence as _Sequence
4343
4444
)Py";
4545

@@ -269,7 +269,14 @@ constexpr const char *regionAccessorTemplate = R"Py(
269269

270270
constexpr const char *valueBuilderTemplate = R"Py(
271271
def {0}({2}) -> {4}:
272-
return _get_op_result_or_op_results({1}({3}))
272+
op = {1}.__base__ if getattr({1}, "__has_mixin__", False) else {1}
273+
return _get_op_result_or_op_results(op({3}))
274+
)Py";
275+
276+
constexpr const char *valueBuilderNoResultsTemplate = R"Py(
277+
def {0}({2}) -> {4}:
278+
op = {1}.__base__ if getattr({1}, "__has_mixin__", False) else {1}
279+
return op({3})
273280
)Py";
274281

275282
static llvm::cl::OptionCategory
@@ -1009,7 +1016,8 @@ static void emitValueBuilder(const Operator &op,
10091016
return (lhs + "=" + llvm::convertToSnakeFromCamelCase(lhs)).str();
10101017
});
10111018
os << llvm::formatv(
1012-
valueBuilderTemplate,
1019+
op.getNumResults() > 0 ? valueBuilderTemplate
1020+
: valueBuilderNoResultsTemplate,
10131021
// Drop dialect name and then sanitize again (to catch e.g. func.return).
10141022
sanitizeName(llvm::join(++splitName.begin(), splitName.end(), "_")),
10151023
op.getCppClassName(), llvm::join(valueBuilderParams, ", "),

0 commit comments

Comments
 (0)