-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir:python] Avoid calls to get_op_result_or_results in generated value wrappers #114491
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: Peter Hawkins (hawkinsp) Changeswrappers. If we know the output arity at tablegen time, we can often just call .result or .results directly. This saves almost 1s in a JAX-based LLM benchmark building a mixture of upstream dialects and StableHLO. Full diff: https://github.com/llvm/llvm-project/pull/114491.diff 2 Files Affected:
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index ba85cb8406b31a..632046389e12cf 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -60,7 +60,7 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
}
// CHECK: def attr_sized_operands(variadic1, non_variadic, *, variadic2=None, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+// CHECK: return AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class AttrSizedResultsOp(_ods_ir.OpView):
@@ -157,7 +157,7 @@ def AttributedOp : TestOp<"attributed_op"> {
}
// CHECK: def attributed_op(i32attr, in_, *, optional_f32_attr=None, unit_attr=None, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip))
+// CHECK: return AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip)
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class AttributedOpWithOperands(_ods_ir.OpView):
@@ -193,7 +193,7 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
}
// CHECK: def attributed_op_with_operands(_gen_arg_0, _gen_arg_2, *, in_=None, is_=None, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip))
+// CHECK: return AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip)
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class DefaultValuedAttrsOp(_ods_ir.OpView):
@@ -217,7 +217,7 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
}
// CHECK: def default_valued_attrs(*, arr=None, unsupported=None, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(DefaultValuedAttrsOp(arr=arr, unsupported=unsupported, loc=loc, ip=ip))
+// CHECK: return DefaultValuedAttrsOp(arr=arr, unsupported=unsupported, loc=loc, ip=ip)
// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op"
def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> {
@@ -235,7 +235,7 @@ def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResu
}
// CHECK: def derive_result_types_op(type_, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(DeriveResultTypesOp(type_=type_, loc=loc, ip=ip))
+// CHECK: return DeriveResultTypesOp(type_=type_, loc=loc, ip=ip).results
// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op"
def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> {
@@ -262,7 +262,7 @@ def EmptyOp : TestOp<"empty">;
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip))
// CHECK: def empty(*, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(EmptyOp(loc=loc, ip=ip))
+// CHECK: return EmptyOp(loc=loc, ip=ip)
// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op"
def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
@@ -275,7 +275,7 @@ def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
}
// CHECK: def infer_result_types_implied_op(*, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(InferResultTypesImpliedOp(loc=loc, ip=ip))
+// CHECK: return InferResultTypesImpliedOp(loc=loc, ip=ip).results
// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_op"
def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> {
@@ -288,7 +288,7 @@ def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]>
}
// CHECK: def infer_result_types_op(*, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(InferResultTypesOp(loc=loc, ip=ip))
+// CHECK: return InferResultTypesOp(loc=loc, ip=ip).results
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class MissingNamesOp(_ods_ir.OpView):
@@ -326,7 +326,7 @@ def MissingNamesOp : TestOp<"missing_names"> {
}
// CHECK: def missing_names(i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip))
+// CHECK: return MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip).results
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class OneOptionalOperandOp(_ods_ir.OpView):
@@ -357,7 +357,7 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
}
// CHECK: def one_optional_operand(non_optional, *, optional=None, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip))
+// CHECK: return OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip)
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class OneVariadicOperandOp(_ods_ir.OpView):
@@ -389,7 +389,7 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
}
// CHECK: def one_variadic_operand(non_variadic, variadic, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip))
+// CHECK: return OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip)
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class OneVariadicResultOp(_ods_ir.OpView):
@@ -446,7 +446,7 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
}
// CHECK: def python_keyword(in_, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(PythonKeywordOp(in_=in_, loc=loc, ip=ip))
+// CHECK: return PythonKeywordOp(in_=in_, loc=loc, ip=ip)
// CHECK-LABEL: OPERATION_NAME = "test.same_results"
def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
@@ -460,7 +460,7 @@ def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
}
// CHECK: def same_results(in1, in2, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip))
+// CHECK: return SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip)
// CHECK-LABEL: OPERATION_NAME = "test.same_results_variadic"
def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResultType]> {
@@ -497,7 +497,7 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
}
// CHECK: def same_variadic_operand(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+// CHECK: return SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class SameVariadicResultSizeOp(_ods_ir.OpView):
@@ -563,7 +563,7 @@ def SimpleOp : TestOp<"simple"> {
}
// CHECK: def simple(i64, f64, i32, f32, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip))
+// CHECK: return SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip).results
// CHECK: class VariadicAndNormalRegionOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.variadic_and_normal_region"
@@ -590,7 +590,7 @@ def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
}
// CHECK: def variadic_and_normal_region(num_variadic, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))
+// CHECK: return VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip)
// CHECK: class VariadicRegionOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.variadic_region"
@@ -613,7 +613,7 @@ def VariadicRegionOp : TestOp<"variadic_region"> {
}
// CHECK: def variadic_region(num_variadic, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))
+// CHECK: return VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip)
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class WithSpecialCharactersOp(_ods_ir.OpView):
@@ -622,7 +622,7 @@ def WithSpecialCharactersOp : TestOp<"123with--special.characters"> {
}
// CHECK: def _123with__special_characters(*, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(WithSpecialCharactersOp(loc=loc, ip=ip))
+// CHECK: return WithSpecialCharactersOp(loc=loc, ip=ip)
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class WithSuccessorsOp(_ods_ir.OpView):
@@ -637,4 +637,4 @@ def WithSuccessorsOp : TestOp<"with_successors"> {
}
// CHECK: def with_successors(successor, successors, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip))
+// CHECK: return WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip)
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 0c5c936f5addee..d243239433b538 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -271,6 +271,11 @@ constexpr const char *regionAccessorTemplate = R"Py(
)Py";
constexpr const char *valueBuilderTemplate = R"Py(
+def {0}({2}) -> {4}:
+ return {1}({3}){5}
+)Py";
+
+constexpr const char *valueBuilderVariadicTemplate = R"Py(
def {0}({2}) -> {4}:
return _get_op_result_or_op_results({1}({3}))
)Py";
@@ -992,15 +997,29 @@ static void emitValueBuilder(const Operator &op,
auto lhs = *llvm::split(arg, "=").begin();
return (lhs + "=" + llvm::convertToSnakeFromCamelCase(lhs)).str();
});
- std::string nameWithoutDialect =
- op.getOperationName().substr(op.getOperationName().find('.') + 1);
- os << formatv(
- valueBuilderTemplate, sanitizeName(nameWithoutDialect),
- op.getCppClassName(), llvm::join(valueBuilderParams, ", "),
- llvm::join(opBuilderArgs, ", "),
- (op.getNumResults() > 1
- ? "_Sequence[_ods_ir.Value]"
- : (op.getNumResults() > 0 ? "_ods_ir.Value" : "_ods_ir.Operation")));
+ std::string nameWithoutDialect = sanitizeName(
+ op.getOperationName().substr(op.getOperationName().find('.') + 1));
+ std::string params = llvm::join(valueBuilderParams, ", ");
+ std::string args = llvm::join(opBuilderArgs, ", ");
+ const char *type =
+ (op.getNumResults() > 1
+ ? "_Sequence[_ods_ir.Value]"
+ : (op.getNumResults() > 0 ? "_ods_ir.Value" : "_ods_ir.Operation"));
+ if (op.getNumVariableLengthResults() > 0) {
+ os << formatv(valueBuilderVariadicTemplate, nameWithoutDialect,
+ op.getCppClassName(), params, args, type);
+ } else {
+ const char* results;
+ if (op.getNumResults() == 0) {
+ results = "";
+ } else if (op.getNumResults() == 1) {
+ results = ".result";
+ } else {
+ results = ".results";
+ }
+ os << formatv(valueBuilderTemplate, nameWithoutDialect,
+ op.getCppClassName(), params, args, type, results);
+ }
}
/// Emits bindings for a specific Op to the given output stream.
|
@llvm/pr-subscribers-mlir-core Author: Peter Hawkins (hawkinsp) Changeswrappers. If we know the output arity at tablegen time, we can often just call .result or .results directly. This saves almost 1s in a JAX-based LLM benchmark building a mixture of upstream dialects and StableHLO. Full diff: https://github.com/llvm/llvm-project/pull/114491.diff 2 Files Affected:
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index ba85cb8406b31a..632046389e12cf 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -60,7 +60,7 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
}
// CHECK: def attr_sized_operands(variadic1, non_variadic, *, variadic2=None, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+// CHECK: return AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class AttrSizedResultsOp(_ods_ir.OpView):
@@ -157,7 +157,7 @@ def AttributedOp : TestOp<"attributed_op"> {
}
// CHECK: def attributed_op(i32attr, in_, *, optional_f32_attr=None, unit_attr=None, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip))
+// CHECK: return AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip)
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class AttributedOpWithOperands(_ods_ir.OpView):
@@ -193,7 +193,7 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
}
// CHECK: def attributed_op_with_operands(_gen_arg_0, _gen_arg_2, *, in_=None, is_=None, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip))
+// CHECK: return AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip)
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class DefaultValuedAttrsOp(_ods_ir.OpView):
@@ -217,7 +217,7 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
}
// CHECK: def default_valued_attrs(*, arr=None, unsupported=None, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(DefaultValuedAttrsOp(arr=arr, unsupported=unsupported, loc=loc, ip=ip))
+// CHECK: return DefaultValuedAttrsOp(arr=arr, unsupported=unsupported, loc=loc, ip=ip)
// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op"
def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> {
@@ -235,7 +235,7 @@ def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResu
}
// CHECK: def derive_result_types_op(type_, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(DeriveResultTypesOp(type_=type_, loc=loc, ip=ip))
+// CHECK: return DeriveResultTypesOp(type_=type_, loc=loc, ip=ip).results
// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op"
def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> {
@@ -262,7 +262,7 @@ def EmptyOp : TestOp<"empty">;
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip))
// CHECK: def empty(*, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(EmptyOp(loc=loc, ip=ip))
+// CHECK: return EmptyOp(loc=loc, ip=ip)
// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op"
def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
@@ -275,7 +275,7 @@ def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
}
// CHECK: def infer_result_types_implied_op(*, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(InferResultTypesImpliedOp(loc=loc, ip=ip))
+// CHECK: return InferResultTypesImpliedOp(loc=loc, ip=ip).results
// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_op"
def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> {
@@ -288,7 +288,7 @@ def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]>
}
// CHECK: def infer_result_types_op(*, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(InferResultTypesOp(loc=loc, ip=ip))
+// CHECK: return InferResultTypesOp(loc=loc, ip=ip).results
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class MissingNamesOp(_ods_ir.OpView):
@@ -326,7 +326,7 @@ def MissingNamesOp : TestOp<"missing_names"> {
}
// CHECK: def missing_names(i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip))
+// CHECK: return MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip).results
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class OneOptionalOperandOp(_ods_ir.OpView):
@@ -357,7 +357,7 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
}
// CHECK: def one_optional_operand(non_optional, *, optional=None, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip))
+// CHECK: return OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip)
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class OneVariadicOperandOp(_ods_ir.OpView):
@@ -389,7 +389,7 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
}
// CHECK: def one_variadic_operand(non_variadic, variadic, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip))
+// CHECK: return OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip)
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class OneVariadicResultOp(_ods_ir.OpView):
@@ -446,7 +446,7 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
}
// CHECK: def python_keyword(in_, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(PythonKeywordOp(in_=in_, loc=loc, ip=ip))
+// CHECK: return PythonKeywordOp(in_=in_, loc=loc, ip=ip)
// CHECK-LABEL: OPERATION_NAME = "test.same_results"
def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
@@ -460,7 +460,7 @@ def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
}
// CHECK: def same_results(in1, in2, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip))
+// CHECK: return SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip)
// CHECK-LABEL: OPERATION_NAME = "test.same_results_variadic"
def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResultType]> {
@@ -497,7 +497,7 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
}
// CHECK: def same_variadic_operand(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+// CHECK: return SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class SameVariadicResultSizeOp(_ods_ir.OpView):
@@ -563,7 +563,7 @@ def SimpleOp : TestOp<"simple"> {
}
// CHECK: def simple(i64, f64, i32, f32, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip))
+// CHECK: return SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip).results
// CHECK: class VariadicAndNormalRegionOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.variadic_and_normal_region"
@@ -590,7 +590,7 @@ def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
}
// CHECK: def variadic_and_normal_region(num_variadic, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))
+// CHECK: return VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip)
// CHECK: class VariadicRegionOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.variadic_region"
@@ -613,7 +613,7 @@ def VariadicRegionOp : TestOp<"variadic_region"> {
}
// CHECK: def variadic_region(num_variadic, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))
+// CHECK: return VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip)
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class WithSpecialCharactersOp(_ods_ir.OpView):
@@ -622,7 +622,7 @@ def WithSpecialCharactersOp : TestOp<"123with--special.characters"> {
}
// CHECK: def _123with__special_characters(*, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(WithSpecialCharactersOp(loc=loc, ip=ip))
+// CHECK: return WithSpecialCharactersOp(loc=loc, ip=ip)
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class WithSuccessorsOp(_ods_ir.OpView):
@@ -637,4 +637,4 @@ def WithSuccessorsOp : TestOp<"with_successors"> {
}
// CHECK: def with_successors(successor, successors, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip))
+// CHECK: return WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip)
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 0c5c936f5addee..d243239433b538 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -271,6 +271,11 @@ constexpr const char *regionAccessorTemplate = R"Py(
)Py";
constexpr const char *valueBuilderTemplate = R"Py(
+def {0}({2}) -> {4}:
+ return {1}({3}){5}
+)Py";
+
+constexpr const char *valueBuilderVariadicTemplate = R"Py(
def {0}({2}) -> {4}:
return _get_op_result_or_op_results({1}({3}))
)Py";
@@ -992,15 +997,29 @@ static void emitValueBuilder(const Operator &op,
auto lhs = *llvm::split(arg, "=").begin();
return (lhs + "=" + llvm::convertToSnakeFromCamelCase(lhs)).str();
});
- std::string nameWithoutDialect =
- op.getOperationName().substr(op.getOperationName().find('.') + 1);
- os << formatv(
- valueBuilderTemplate, sanitizeName(nameWithoutDialect),
- op.getCppClassName(), llvm::join(valueBuilderParams, ", "),
- llvm::join(opBuilderArgs, ", "),
- (op.getNumResults() > 1
- ? "_Sequence[_ods_ir.Value]"
- : (op.getNumResults() > 0 ? "_ods_ir.Value" : "_ods_ir.Operation")));
+ std::string nameWithoutDialect = sanitizeName(
+ op.getOperationName().substr(op.getOperationName().find('.') + 1));
+ std::string params = llvm::join(valueBuilderParams, ", ");
+ std::string args = llvm::join(opBuilderArgs, ", ");
+ const char *type =
+ (op.getNumResults() > 1
+ ? "_Sequence[_ods_ir.Value]"
+ : (op.getNumResults() > 0 ? "_ods_ir.Value" : "_ods_ir.Operation"));
+ if (op.getNumVariableLengthResults() > 0) {
+ os << formatv(valueBuilderVariadicTemplate, nameWithoutDialect,
+ op.getCppClassName(), params, args, type);
+ } else {
+ const char* results;
+ if (op.getNumResults() == 0) {
+ results = "";
+ } else if (op.getNumResults() == 1) {
+ results = ".result";
+ } else {
+ results = ".results";
+ }
+ os << formatv(valueBuilderTemplate, nameWithoutDialect,
+ op.getCppClassName(), params, args, type, results);
+ }
}
/// Emits bindings for a specific Op to the given output stream.
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
wrappers. If we know the output arity at tablegen time, we can often just call .result or .results directly. This saves almost 1s in a JAX-based LLM benchmark building a mixture of upstream dialects and StableHLO..
Sweet!
The 1s was over how much? (seconds seems arbitrary without a denominator somehow?) |
The premerge CI does not check for the python bindings I believe, so no point in waiting for it to complete here. |
It was about 1% of the Python time in this particular benchmark. |
…lue wrappers (llvm#114491) If we know the output arity at tablegen time, we can often just call .result or .results directly. This saves almost 1s in a JAX-based LLM benchmark building a mixture of upstream dialects and StableHLO.
…lue wrappers (llvm#114491) If we know the output arity at tablegen time, we can often just call .result or .results directly. This saves almost 1s in a JAX-based LLM benchmark building a mixture of upstream dialects and StableHLO.
If we know the output arity at tablegen time, we can often just call .result or .results directly.
This saves almost 1s in a JAX-based LLM benchmark building a mixture of upstream dialects and StableHLO.