Skip to content

[mlir:python] Avoid calls to get_op_result_or_results in generated value wrappers #114491

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 31, 2024

Conversation

hawkinsp
Copy link
Contributor

@hawkinsp hawkinsp commented Oct 31, 2024

If we know the output arity at tablegen time, we can often just call .result or .results directly.

This saves almost 1s in a JAX-based LLM benchmark building a mixture of upstream dialects and StableHLO.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Oct 31, 2024
@llvmbot
Copy link
Member

llvmbot commented Oct 31, 2024

@llvm/pr-subscribers-mlir

Author: Peter Hawkins (hawkinsp)

Changes

wrappers.

If we know the output arity at tablegen time, we can often just call .result or .results directly.

This saves almost 1s in a JAX-based LLM benchmark building a mixture of upstream dialects and StableHLO.


Full diff: https://github.com/llvm/llvm-project/pull/114491.diff

2 Files Affected:

  • (modified) mlir/test/mlir-tblgen/op-python-bindings.td (+19-19)
  • (modified) mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp (+28-9)
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index ba85cb8406b31a..632046389e12cf 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -60,7 +60,7 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
 }
 
 // CHECK: def attr_sized_operands(variadic1, non_variadic, *, variadic2=None, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+// CHECK:   return AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class AttrSizedResultsOp(_ods_ir.OpView):
@@ -157,7 +157,7 @@ def AttributedOp : TestOp<"attributed_op"> {
 }
 
 // CHECK: def attributed_op(i32attr, in_, *, optional_f32_attr=None, unit_attr=None, loc=None, ip=None)
-// CHECK:     return _get_op_result_or_op_results(AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip))
+// CHECK:     return AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip)
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class AttributedOpWithOperands(_ods_ir.OpView):
@@ -193,7 +193,7 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
 }
 
 // CHECK: def attributed_op_with_operands(_gen_arg_0, _gen_arg_2, *, in_=None, is_=None, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip))
+// CHECK:   return AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip)
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class DefaultValuedAttrsOp(_ods_ir.OpView):
@@ -217,7 +217,7 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
 }
 
 // CHECK: def default_valued_attrs(*, arr=None, unsupported=None, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(DefaultValuedAttrsOp(arr=arr, unsupported=unsupported, loc=loc, ip=ip))
+// CHECK:   return DefaultValuedAttrsOp(arr=arr, unsupported=unsupported, loc=loc, ip=ip)
 
 // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op"
 def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> {
@@ -235,7 +235,7 @@ def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResu
 }
 
 // CHECK: def derive_result_types_op(type_, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(DeriveResultTypesOp(type_=type_, loc=loc, ip=ip))
+// CHECK:   return DeriveResultTypesOp(type_=type_, loc=loc, ip=ip).results
 
 // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op"
 def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> {
@@ -262,7 +262,7 @@ def EmptyOp : TestOp<"empty">;
   // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip))
 
 // CHECK: def empty(*, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(EmptyOp(loc=loc, ip=ip))
+// CHECK:   return EmptyOp(loc=loc, ip=ip)
 
 // CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op"
 def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
@@ -275,7 +275,7 @@ def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
 }
 
 // CHECK: def infer_result_types_implied_op(*, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(InferResultTypesImpliedOp(loc=loc, ip=ip))
+// CHECK:   return InferResultTypesImpliedOp(loc=loc, ip=ip).results
 
 // CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_op"
 def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> {
@@ -288,7 +288,7 @@ def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]>
 }
 
 // CHECK: def infer_result_types_op(*, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(InferResultTypesOp(loc=loc, ip=ip))
+// CHECK:   return InferResultTypesOp(loc=loc, ip=ip).results
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class MissingNamesOp(_ods_ir.OpView):
@@ -326,7 +326,7 @@ def MissingNamesOp : TestOp<"missing_names"> {
 }
 
 // CHECK: def missing_names(i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip))
+// CHECK:   return MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip).results
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class OneOptionalOperandOp(_ods_ir.OpView):
@@ -357,7 +357,7 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
 }
 
 // CHECK: def one_optional_operand(non_optional, *, optional=None, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip))
+// CHECK:   return OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip)
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class OneVariadicOperandOp(_ods_ir.OpView):
@@ -389,7 +389,7 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
 }
 
 // CHECK: def one_variadic_operand(non_variadic, variadic, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip))
+// CHECK:   return OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip)
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class OneVariadicResultOp(_ods_ir.OpView):
@@ -446,7 +446,7 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
 }
 
 // CHECK: def python_keyword(in_, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(PythonKeywordOp(in_=in_, loc=loc, ip=ip))
+// CHECK:   return PythonKeywordOp(in_=in_, loc=loc, ip=ip)
 
 // CHECK-LABEL: OPERATION_NAME = "test.same_results"
 def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
@@ -460,7 +460,7 @@ def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
 }
 
 // CHECK: def same_results(in1, in2, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip))
+// CHECK:   return SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip)
 
 // CHECK-LABEL: OPERATION_NAME = "test.same_results_variadic"
 def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResultType]> {
@@ -497,7 +497,7 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
 }
 
 // CHECK: def same_variadic_operand(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+// CHECK:   return SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class SameVariadicResultSizeOp(_ods_ir.OpView):
@@ -563,7 +563,7 @@ def SimpleOp : TestOp<"simple"> {
 }
 
 // CHECK: def simple(i64, f64, i32, f32, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip))
+// CHECK:   return SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip).results
 
 // CHECK: class VariadicAndNormalRegionOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.variadic_and_normal_region"
@@ -590,7 +590,7 @@ def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
 }
 
 // CHECK: def variadic_and_normal_region(num_variadic, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))
+// CHECK:   return VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip)
 
 // CHECK: class VariadicRegionOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.variadic_region"
@@ -613,7 +613,7 @@ def VariadicRegionOp : TestOp<"variadic_region"> {
 }
 
 // CHECK: def variadic_region(num_variadic, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))
+// CHECK:   return VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip)
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class WithSpecialCharactersOp(_ods_ir.OpView):
@@ -622,7 +622,7 @@ def WithSpecialCharactersOp : TestOp<"123with--special.characters"> {
 }
 
 // CHECK: def _123with__special_characters(*, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(WithSpecialCharactersOp(loc=loc, ip=ip))
+// CHECK:   return WithSpecialCharactersOp(loc=loc, ip=ip)
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class WithSuccessorsOp(_ods_ir.OpView):
@@ -637,4 +637,4 @@ def WithSuccessorsOp : TestOp<"with_successors"> {
 }
 
 // CHECK: def with_successors(successor, successors, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip))
+// CHECK:   return WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip)
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 0c5c936f5addee..d243239433b538 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -271,6 +271,11 @@ constexpr const char *regionAccessorTemplate = R"Py(
 )Py";
 
 constexpr const char *valueBuilderTemplate = R"Py(
+def {0}({2}) -> {4}:
+  return {1}({3}){5}
+)Py";
+
+constexpr const char *valueBuilderVariadicTemplate = R"Py(
 def {0}({2}) -> {4}:
   return _get_op_result_or_op_results({1}({3}))
 )Py";
@@ -992,15 +997,29 @@ static void emitValueBuilder(const Operator &op,
         auto lhs = *llvm::split(arg, "=").begin();
         return (lhs + "=" + llvm::convertToSnakeFromCamelCase(lhs)).str();
       });
-  std::string nameWithoutDialect =
-      op.getOperationName().substr(op.getOperationName().find('.') + 1);
-  os << formatv(
-      valueBuilderTemplate, sanitizeName(nameWithoutDialect),
-      op.getCppClassName(), llvm::join(valueBuilderParams, ", "),
-      llvm::join(opBuilderArgs, ", "),
-      (op.getNumResults() > 1
-           ? "_Sequence[_ods_ir.Value]"
-           : (op.getNumResults() > 0 ? "_ods_ir.Value" : "_ods_ir.Operation")));
+  std::string nameWithoutDialect = sanitizeName(
+      op.getOperationName().substr(op.getOperationName().find('.') + 1));
+  std::string params = llvm::join(valueBuilderParams, ", ");
+  std::string args = llvm::join(opBuilderArgs, ", ");
+  const char *type =
+       (op.getNumResults() > 1
+            ? "_Sequence[_ods_ir.Value]"
+           : (op.getNumResults() > 0 ? "_ods_ir.Value" : "_ods_ir.Operation"));
+  if (op.getNumVariableLengthResults() > 0) {
+    os << formatv(valueBuilderVariadicTemplate, nameWithoutDialect,
+                  op.getCppClassName(), params, args, type);
+  } else {
+    const char* results;
+    if (op.getNumResults() == 0) {
+      results = "";
+    } else if (op.getNumResults() == 1) {
+      results = ".result";
+    } else {
+      results = ".results";
+    }
+    os << formatv(valueBuilderTemplate, nameWithoutDialect,
+                  op.getCppClassName(), params, args, type, results);
+  }
 }
 
 /// Emits bindings for a specific Op to the given output stream.

@llvmbot
Copy link
Member

llvmbot commented Oct 31, 2024

@llvm/pr-subscribers-mlir-core

Author: Peter Hawkins (hawkinsp)

Changes

wrappers.

If we know the output arity at tablegen time, we can often just call .result or .results directly.

This saves almost 1s in a JAX-based LLM benchmark building a mixture of upstream dialects and StableHLO.


Full diff: https://github.com/llvm/llvm-project/pull/114491.diff

2 Files Affected:

  • (modified) mlir/test/mlir-tblgen/op-python-bindings.td (+19-19)
  • (modified) mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp (+28-9)
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index ba85cb8406b31a..632046389e12cf 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -60,7 +60,7 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
 }
 
 // CHECK: def attr_sized_operands(variadic1, non_variadic, *, variadic2=None, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+// CHECK:   return AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class AttrSizedResultsOp(_ods_ir.OpView):
@@ -157,7 +157,7 @@ def AttributedOp : TestOp<"attributed_op"> {
 }
 
 // CHECK: def attributed_op(i32attr, in_, *, optional_f32_attr=None, unit_attr=None, loc=None, ip=None)
-// CHECK:     return _get_op_result_or_op_results(AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip))
+// CHECK:     return AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip)
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class AttributedOpWithOperands(_ods_ir.OpView):
@@ -193,7 +193,7 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
 }
 
 // CHECK: def attributed_op_with_operands(_gen_arg_0, _gen_arg_2, *, in_=None, is_=None, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip))
+// CHECK:   return AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip)
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class DefaultValuedAttrsOp(_ods_ir.OpView):
@@ -217,7 +217,7 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
 }
 
 // CHECK: def default_valued_attrs(*, arr=None, unsupported=None, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(DefaultValuedAttrsOp(arr=arr, unsupported=unsupported, loc=loc, ip=ip))
+// CHECK:   return DefaultValuedAttrsOp(arr=arr, unsupported=unsupported, loc=loc, ip=ip)
 
 // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op"
 def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> {
@@ -235,7 +235,7 @@ def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResu
 }
 
 // CHECK: def derive_result_types_op(type_, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(DeriveResultTypesOp(type_=type_, loc=loc, ip=ip))
+// CHECK:   return DeriveResultTypesOp(type_=type_, loc=loc, ip=ip).results
 
 // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op"
 def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> {
@@ -262,7 +262,7 @@ def EmptyOp : TestOp<"empty">;
   // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip))
 
 // CHECK: def empty(*, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(EmptyOp(loc=loc, ip=ip))
+// CHECK:   return EmptyOp(loc=loc, ip=ip)
 
 // CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op"
 def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
@@ -275,7 +275,7 @@ def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
 }
 
 // CHECK: def infer_result_types_implied_op(*, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(InferResultTypesImpliedOp(loc=loc, ip=ip))
+// CHECK:   return InferResultTypesImpliedOp(loc=loc, ip=ip).results
 
 // CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_op"
 def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> {
@@ -288,7 +288,7 @@ def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]>
 }
 
 // CHECK: def infer_result_types_op(*, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(InferResultTypesOp(loc=loc, ip=ip))
+// CHECK:   return InferResultTypesOp(loc=loc, ip=ip).results
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class MissingNamesOp(_ods_ir.OpView):
@@ -326,7 +326,7 @@ def MissingNamesOp : TestOp<"missing_names"> {
 }
 
 // CHECK: def missing_names(i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip))
+// CHECK:   return MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip).results
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class OneOptionalOperandOp(_ods_ir.OpView):
@@ -357,7 +357,7 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
 }
 
 // CHECK: def one_optional_operand(non_optional, *, optional=None, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip))
+// CHECK:   return OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip)
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class OneVariadicOperandOp(_ods_ir.OpView):
@@ -389,7 +389,7 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
 }
 
 // CHECK: def one_variadic_operand(non_variadic, variadic, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip))
+// CHECK:   return OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip)
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class OneVariadicResultOp(_ods_ir.OpView):
@@ -446,7 +446,7 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
 }
 
 // CHECK: def python_keyword(in_, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(PythonKeywordOp(in_=in_, loc=loc, ip=ip))
+// CHECK:   return PythonKeywordOp(in_=in_, loc=loc, ip=ip)
 
 // CHECK-LABEL: OPERATION_NAME = "test.same_results"
 def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
@@ -460,7 +460,7 @@ def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
 }
 
 // CHECK: def same_results(in1, in2, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip))
+// CHECK:   return SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip)
 
 // CHECK-LABEL: OPERATION_NAME = "test.same_results_variadic"
 def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResultType]> {
@@ -497,7 +497,7 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
 }
 
 // CHECK: def same_variadic_operand(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+// CHECK:   return SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class SameVariadicResultSizeOp(_ods_ir.OpView):
@@ -563,7 +563,7 @@ def SimpleOp : TestOp<"simple"> {
 }
 
 // CHECK: def simple(i64, f64, i32, f32, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip))
+// CHECK:   return SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip).results
 
 // CHECK: class VariadicAndNormalRegionOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.variadic_and_normal_region"
@@ -590,7 +590,7 @@ def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
 }
 
 // CHECK: def variadic_and_normal_region(num_variadic, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))
+// CHECK:   return VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip)
 
 // CHECK: class VariadicRegionOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.variadic_region"
@@ -613,7 +613,7 @@ def VariadicRegionOp : TestOp<"variadic_region"> {
 }
 
 // CHECK: def variadic_region(num_variadic, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))
+// CHECK:   return VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip)
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class WithSpecialCharactersOp(_ods_ir.OpView):
@@ -622,7 +622,7 @@ def WithSpecialCharactersOp : TestOp<"123with--special.characters"> {
 }
 
 // CHECK: def _123with__special_characters(*, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(WithSpecialCharactersOp(loc=loc, ip=ip))
+// CHECK:   return WithSpecialCharactersOp(loc=loc, ip=ip)
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class WithSuccessorsOp(_ods_ir.OpView):
@@ -637,4 +637,4 @@ def WithSuccessorsOp : TestOp<"with_successors"> {
 }
 
 // CHECK: def with_successors(successor, successors, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip))
+// CHECK:   return WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip)
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 0c5c936f5addee..d243239433b538 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -271,6 +271,11 @@ constexpr const char *regionAccessorTemplate = R"Py(
 )Py";
 
 constexpr const char *valueBuilderTemplate = R"Py(
+def {0}({2}) -> {4}:
+  return {1}({3}){5}
+)Py";
+
+constexpr const char *valueBuilderVariadicTemplate = R"Py(
 def {0}({2}) -> {4}:
   return _get_op_result_or_op_results({1}({3}))
 )Py";
@@ -992,15 +997,29 @@ static void emitValueBuilder(const Operator &op,
         auto lhs = *llvm::split(arg, "=").begin();
         return (lhs + "=" + llvm::convertToSnakeFromCamelCase(lhs)).str();
       });
-  std::string nameWithoutDialect =
-      op.getOperationName().substr(op.getOperationName().find('.') + 1);
-  os << formatv(
-      valueBuilderTemplate, sanitizeName(nameWithoutDialect),
-      op.getCppClassName(), llvm::join(valueBuilderParams, ", "),
-      llvm::join(opBuilderArgs, ", "),
-      (op.getNumResults() > 1
-           ? "_Sequence[_ods_ir.Value]"
-           : (op.getNumResults() > 0 ? "_ods_ir.Value" : "_ods_ir.Operation")));
+  std::string nameWithoutDialect = sanitizeName(
+      op.getOperationName().substr(op.getOperationName().find('.') + 1));
+  std::string params = llvm::join(valueBuilderParams, ", ");
+  std::string args = llvm::join(opBuilderArgs, ", ");
+  const char *type =
+       (op.getNumResults() > 1
+            ? "_Sequence[_ods_ir.Value]"
+           : (op.getNumResults() > 0 ? "_ods_ir.Value" : "_ods_ir.Operation"));
+  if (op.getNumVariableLengthResults() > 0) {
+    os << formatv(valueBuilderVariadicTemplate, nameWithoutDialect,
+                  op.getCppClassName(), params, args, type);
+  } else {
+    const char* results;
+    if (op.getNumResults() == 0) {
+      results = "";
+    } else if (op.getNumResults() == 1) {
+      results = ".result";
+    } else {
+      results = ".results";
+    }
+    os << formatv(valueBuilderTemplate, nameWithoutDialect,
+                  op.getCppClassName(), params, args, type, results);
+  }
 }
 
 /// Emits bindings for a specific Op to the given output stream.

Copy link

github-actions bot commented Oct 31, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

wrappers.

If we know the output arity at tablegen time, we can often just call
.result or .results directly.

This saves almost 1s in a JAX-based LLM benchmark building a mixture of
upstream dialects and StableHLO..
@joker-eph joker-eph changed the title [mlir:python] Avoid calls to get_op_result_or_results in generated value [mlir:python] Avoid calls to get_op_result_or_results in generated value wrappers Oct 31, 2024
@joker-eph
Copy link
Collaborator

Sweet!

This saves almost 1s in a JAX-based LLM benchmark building a mixture of upstream dialects and StableHLO.

The 1s was over how much? (seconds seems arbitrary without a denominator somehow?)

@joker-eph
Copy link
Collaborator

The premerge CI does not check for the python bindings I believe, so no point in waiting for it to complete here.

@joker-eph joker-eph merged commit 51a4f31 into llvm:main Oct 31, 2024
6 of 7 checks passed
@hawkinsp
Copy link
Contributor Author

hawkinsp commented Nov 1, 2024

Sweet!

This saves almost 1s in a JAX-based LLM benchmark building a mixture of upstream dialects and StableHLO.

The 1s was over how much? (seconds seems arbitrary without a denominator somehow?)

It was about 1% of the Python time in this particular benchmark.

smallp-o-p pushed a commit to smallp-o-p/llvm-project that referenced this pull request Nov 3, 2024
…lue wrappers (llvm#114491)

If we know the output arity at tablegen time, we can often just call
.result or .results directly.

This saves almost 1s in a JAX-based LLM benchmark building a mixture of
upstream dialects and StableHLO.
NoumanAmir657 pushed a commit to NoumanAmir657/llvm-project that referenced this pull request Nov 4, 2024
…lue wrappers (llvm#114491)

If we know the output arity at tablegen time, we can often just call
.result or .results directly.

This saves almost 1s in a JAX-based LLM benchmark building a mixture of
upstream dialects and StableHLO.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants