Skip to content

[mlir][python] fix value builders #68764

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/python/mlir/dialects/_ods_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ class LocalOpView(mixin_cls, parent_opview_cls):
) from e
LocalOpView.__name__ = parent_opview_cls.__name__
LocalOpView.__qualname__ = parent_opview_cls.__qualname__
LocalOpView.__has_mixin__ = True
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about something like LocalOpView.__ods_base__ = parent_opview_cls + generating getattr(OpViewCls, "__ods_base__", OpViewCls)? then you don't have to use __base__ (which I had to look up, and apparently is slightly different from __bases__[0])

return LocalOpView

return class_decorator
Expand Down
70 changes: 46 additions & 24 deletions mlir/test/mlir-tblgen/op-python-bindings.td
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
}

// CHECK: def attr_sized_operands(variadic1, non_variadic, *, variadic2=None, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
// CHECK: op = AttrSizedOperandsOp.__base__ if getattr(AttrSizedOperandsOp, "__has_mixin__", False) else AttrSizedOperandsOp
// CHECK: return op(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)

// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class AttrSizedResultsOp(_ods_ir.OpView):
Expand Down Expand Up @@ -108,8 +109,8 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
}

// CHECK: def attr_sized_results(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(AttrSizedResultsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))

// CHECK: op = AttrSizedResultsOp.__base__ if getattr(AttrSizedResultsOp, "__has_mixin__", False) else AttrSizedResultsOp
// CHECK: return _get_op_result_or_op_results(op(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))

// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class AttributedOp(_ods_ir.OpView):
Expand Down Expand Up @@ -158,7 +159,8 @@ def AttributedOp : TestOp<"attributed_op"> {
}

// CHECK: def attributed_op(i32attr, in_, *, optional_f32_attr=None, unit_attr=None, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip))
// CHECK: op = AttributedOp.__base__ if getattr(AttributedOp, "__has_mixin__", False) else AttributedOp
// CHECK: return op(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip)

// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class AttributedOpWithOperands(_ods_ir.OpView):
Expand Down Expand Up @@ -194,7 +196,8 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
}

// CHECK: def attributed_op_with_operands(_gen_arg_0, _gen_arg_2, *, in_=None, is_=None, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip))
// CHECK: op = AttributedOpWithOperands.__base__ if getattr(AttributedOpWithOperands, "__has_mixin__", False) else AttributedOpWithOperands
// CHECK: return op(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip)

// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class DefaultValuedAttrsOp(_ods_ir.OpView):
Expand All @@ -218,7 +221,8 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
}

// CHECK: def default_valued_attrs(*, arr=None, unsupported=None, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(DefaultValuedAttrsOp(arr=arr, unsupported=unsupported, loc=loc, ip=ip))
// CHECK: op = DefaultValuedAttrsOp.__base__ if getattr(DefaultValuedAttrsOp, "__has_mixin__", False) else DefaultValuedAttrsOp
// CHECK: return op(arr=arr, unsupported=unsupported, loc=loc, ip=ip)

// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op"
def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> {
Expand All @@ -236,7 +240,8 @@ def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResu
}

// CHECK: def derive_result_types_op(type_, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(DeriveResultTypesOp(type_=type_, loc=loc, ip=ip))
// CHECK: op = DeriveResultTypesOp.__base__ if getattr(DeriveResultTypesOp, "__has_mixin__", False) else DeriveResultTypesOp
// CHECK: return _get_op_result_or_op_results(op(type_=type_, loc=loc, ip=ip))

// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op"
def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> {
Expand All @@ -246,7 +251,8 @@ def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [Fir
}

// CHECK: def derive_result_types_variadic_op(res, _gen_res_1, type_, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(DeriveResultTypesVariadicOp(res=res, _gen_res_1=_gen_res_1, type_=type_, loc=loc, ip=ip))
// CHECK: op = DeriveResultTypesVariadicOp.__base__ if getattr(DeriveResultTypesVariadicOp, "__has_mixin__", False) else DeriveResultTypesVariadicOp
// CHECK: return _get_op_result_or_op_results(op(res=res, _gen_res_1=_gen_res_1, type_=type_, loc=loc, ip=ip))

// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class EmptyOp(_ods_ir.OpView):
Expand All @@ -263,7 +269,8 @@ def EmptyOp : TestOp<"empty">;
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip))

// CHECK: def empty(*, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(EmptyOp(loc=loc, ip=ip))
// CHECK: op = EmptyOp.__base__ if getattr(EmptyOp, "__has_mixin__", False) else EmptyOp
// CHECK: return op(loc=loc, ip=ip)

// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op"
def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
Expand All @@ -276,7 +283,8 @@ def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
}

// CHECK: def infer_result_types_implied_op(*, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(InferResultTypesImpliedOp(loc=loc, ip=ip))
// CHECK: op = InferResultTypesImpliedOp.__base__ if getattr(InferResultTypesImpliedOp, "__has_mixin__", False) else InferResultTypesImpliedOp
// CHECK: return _get_op_result_or_op_results(op(loc=loc, ip=ip))

// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_op"
def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> {
Expand All @@ -289,7 +297,8 @@ def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]>
}

// CHECK: def infer_result_types_op(*, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(InferResultTypesOp(loc=loc, ip=ip))
// CHECK: op = InferResultTypesOp.__base__ if getattr(InferResultTypesOp, "__has_mixin__", False) else InferResultTypesOp
// CHECK: return _get_op_result_or_op_results(op(loc=loc, ip=ip))

// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class MissingNamesOp(_ods_ir.OpView):
Expand Down Expand Up @@ -327,7 +336,8 @@ def MissingNamesOp : TestOp<"missing_names"> {
}

// CHECK: def missing_names(i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip))
// CHECK: op = MissingNamesOp.__base__ if getattr(MissingNamesOp, "__has_mixin__", False) else MissingNamesOp
// CHECK: return _get_op_result_or_op_results(op(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip))

// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class OneOptionalOperandOp(_ods_ir.OpView):
Expand Down Expand Up @@ -358,7 +368,8 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
}

// CHECK: def one_optional_operand(non_optional, *, optional=None, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip))
// CHECK: op = OneOptionalOperandOp.__base__ if getattr(OneOptionalOperandOp, "__has_mixin__", False) else OneOptionalOperandOp
// CHECK: return op(non_optional=non_optional, optional=optional, loc=loc, ip=ip)

// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class OneVariadicOperandOp(_ods_ir.OpView):
Expand Down Expand Up @@ -390,7 +401,8 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
}

// CHECK: def one_variadic_operand(non_variadic, variadic, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip))
// CHECK: op = OneVariadicOperandOp.__base__ if getattr(OneVariadicOperandOp, "__has_mixin__", False) else OneVariadicOperandOp
// CHECK: return op(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip)

// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class OneVariadicResultOp(_ods_ir.OpView):
Expand Down Expand Up @@ -423,7 +435,8 @@ def OneVariadicResultOp : TestOp<"one_variadic_result"> {
}

// CHECK: def one_variadic_result(variadic, non_variadic, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(OneVariadicResultOp(variadic=variadic, non_variadic=non_variadic, loc=loc, ip=ip))
// CHECK: op = OneVariadicResultOp.__base__ if getattr(OneVariadicResultOp, "__has_mixin__", False) else OneVariadicResultOp
// CHECK: return _get_op_result_or_op_results(op(variadic=variadic, non_variadic=non_variadic, loc=loc, ip=ip))

// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class PythonKeywordOp(_ods_ir.OpView):
Expand All @@ -447,7 +460,8 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
}

// CHECK: def python_keyword(in_, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(PythonKeywordOp(in_=in_, loc=loc, ip=ip))
// CHECK: op = PythonKeywordOp.__base__ if getattr(PythonKeywordOp, "__has_mixin__", False) else PythonKeywordOp
// CHECK: return op(in_=in_, loc=loc, ip=ip)

// CHECK-LABEL: OPERATION_NAME = "test.same_results"
def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
Expand All @@ -461,7 +475,8 @@ def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
}

// CHECK: def same_results(in1, in2, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip))
// CHECK: op = SameResultsOp.__base__ if getattr(SameResultsOp, "__has_mixin__", False) else SameResultsOp
// CHECK: return _get_op_result_or_op_results(op(in1=in1, in2=in2, loc=loc, ip=ip))

// CHECK-LABEL: OPERATION_NAME = "test.same_results_variadic"
def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResultType]> {
Expand All @@ -471,7 +486,8 @@ def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResu
}

// CHECK: def same_results_variadic(res, in1, in2, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(SameResultsVariadicOp(res=res, in1=in1, in2=in2, loc=loc, ip=ip))
// CHECK: op = SameResultsVariadicOp.__base__ if getattr(SameResultsVariadicOp, "__has_mixin__", False) else SameResultsVariadicOp
// CHECK: return _get_op_result_or_op_results(op(res=res, in1=in1, in2=in2, loc=loc, ip=ip))


// CHECK: @_ods_cext.register_operation(_Dialect)
Expand All @@ -498,7 +514,8 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
}

// CHECK: def same_variadic_operand(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
// CHECK: op = SameVariadicOperandSizeOp.__base__ if getattr(SameVariadicOperandSizeOp, "__has_mixin__", False) else SameVariadicOperandSizeOp
// CHECK: return op(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)

// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class SameVariadicResultSizeOp(_ods_ir.OpView):
Expand All @@ -524,7 +541,8 @@ def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
}

// CHECK: def same_variadic_result(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(SameVariadicResultSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
// CHECK: op = SameVariadicResultSizeOp.__base__ if getattr(SameVariadicResultSizeOp, "__has_mixin__", False) else SameVariadicResultSizeOp
// CHECK: return _get_op_result_or_op_results(op(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))

// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class SimpleOp(_ods_ir.OpView):
Expand Down Expand Up @@ -564,7 +582,8 @@ def SimpleOp : TestOp<"simple"> {
}

// CHECK: def simple(i64, f64, i32, f32, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip))
// CHECK: op = SimpleOp.__base__ if getattr(SimpleOp, "__has_mixin__", False) else SimpleOp
// CHECK: return _get_op_result_or_op_results(op(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip))

// CHECK: class VariadicAndNormalRegionOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.variadic_and_normal_region"
Expand All @@ -591,7 +610,8 @@ def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
}

// CHECK: def variadic_and_normal_region(num_variadic, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))
// CHECK: op = VariadicAndNormalRegionOp.__base__ if getattr(VariadicAndNormalRegionOp, "__has_mixin__", False) else VariadicAndNormalRegionOp
// CHECK: return op(num_variadic=num_variadic, loc=loc, ip=ip)

// CHECK: class VariadicRegionOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.variadic_region"
Expand All @@ -614,7 +634,8 @@ def VariadicRegionOp : TestOp<"variadic_region"> {
}

// CHECK: def variadic_region(num_variadic, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))
// CHECK: op = VariadicRegionOp.__base__ if getattr(VariadicRegionOp, "__has_mixin__", False) else VariadicRegionOp
// CHECK: return op(num_variadic=num_variadic, loc=loc, ip=ip)

// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class WithSuccessorsOp(_ods_ir.OpView):
Expand All @@ -629,4 +650,5 @@ def WithSuccessorsOp : TestOp<"with_successors"> {
}

// CHECK: def with_successors(successor, successors, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip))
// CHECK: op = WithSuccessorsOp.__base__ if getattr(WithSuccessorsOp, "__has_mixin__", False) else WithSuccessorsOp
// CHECK: return op(successor=successor, successors=successors, loc=loc, ip=ip)
13 changes: 13 additions & 0 deletions mlir/test/python/dialects/arith_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,16 @@ def testFastMathFlags():
)
# CHECK: %0 = arith.addf %cst, %cst fastmath<nnan,ninf> : f32
print(r)


# CHECK-LABEL: TEST: testArithValueBuilder
@run
def testArithValueBuilder():
with Context() as ctx, Location.unknown():
module = Module.create()
f32_t = F32Type.get()

with InsertionPoint(module.body):
a = arith.constant(value=FloatAttr.get(f32_t, 42.42))
Copy link
Contributor Author

@makslevental makslevental Oct 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

arith.ConstantOp has an "ext" so this will call (in the generated value builder) the base rather than the mixin builder.

# CHECK: %cst = arith.constant 4.242000e+01 : f32
print(a)
14 changes: 11 additions & 3 deletions mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ except ImportError:
_ods_ext_module = None

import builtins
from typing import Sequence as _Sequence, Union as _Union
from typing import Sequence as _Sequence

)Py";

Expand Down Expand Up @@ -269,7 +269,14 @@ constexpr const char *regionAccessorTemplate = R"Py(

constexpr const char *valueBuilderTemplate = R"Py(
def {0}({2}) -> {4}:
return _get_op_result_or_op_results({1}({3}))
op = {1}.__base__ if getattr({1}, "__has_mixin__", False) else {1}
return _get_op_result_or_op_results(op({3}))
)Py";

constexpr const char *valueBuilderNoResultsTemplate = R"Py(
def {0}({2}) -> {4}:
op = {1}.__base__ if getattr({1}, "__has_mixin__", False) else {1}
return op({3})
)Py";

static llvm::cl::OptionCategory
Expand Down Expand Up @@ -1009,7 +1016,8 @@ static void emitValueBuilder(const Operator &op,
return (lhs + "=" + llvm::convertToSnakeFromCamelCase(lhs)).str();
});
os << llvm::formatv(
valueBuilderTemplate,
op.getNumResults() > 0 ? valueBuilderTemplate
: valueBuilderNoResultsTemplate,
Comment on lines +1019 to +1020
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should the if len(op.results) > 0 case be removed from _get_op_result_or_op_results with this? it could have a slightly more specific return type

// Drop dialect name and then sanitize again (to catch e.g. func.return).
sanitizeName(llvm::join(++splitName.begin(), splitName.end(), "_")),
op.getCppClassName(), llvm::join(valueBuilderParams, ", "),
Expand Down