Skip to content

Commit 51a4f31

Browse files
authored
[mlir:python] Avoid calls to get_op_result_or_results in generated value wrappers (llvm#114491)
If we know the output arity at tablegen time, we can often just call .result or .results directly. This saves almost 1s in a JAX-based LLM benchmark building a mixture of upstream dialects and StableHLO.
1 parent 84f5c85 commit 51a4f31

File tree

2 files changed

+45
-26
lines changed

2 files changed

+45
-26
lines changed

mlir/test/mlir-tblgen/op-python-bindings.td

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
6060
}
6161

6262
// CHECK: def attr_sized_operands(variadic1, non_variadic, *, variadic2=None, loc=None, ip=None)
63-
// CHECK: return _get_op_result_or_op_results(AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
63+
// CHECK: return AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)
6464

6565
// CHECK: @_ods_cext.register_operation(_Dialect)
6666
// CHECK: class AttrSizedResultsOp(_ods_ir.OpView):
@@ -157,7 +157,7 @@ def AttributedOp : TestOp<"attributed_op"> {
157157
}
158158

159159
// CHECK: def attributed_op(i32attr, in_, *, optional_f32_attr=None, unit_attr=None, loc=None, ip=None)
160-
// CHECK: return _get_op_result_or_op_results(AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip))
160+
// CHECK: return AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip)
161161

162162
// CHECK: @_ods_cext.register_operation(_Dialect)
163163
// CHECK: class AttributedOpWithOperands(_ods_ir.OpView):
@@ -193,7 +193,7 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
193193
}
194194

195195
// CHECK: def attributed_op_with_operands(_gen_arg_0, _gen_arg_2, *, in_=None, is_=None, loc=None, ip=None)
196-
// CHECK: return _get_op_result_or_op_results(AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip))
196+
// CHECK: return AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip)
197197

198198
// CHECK: @_ods_cext.register_operation(_Dialect)
199199
// CHECK: class DefaultValuedAttrsOp(_ods_ir.OpView):
@@ -217,7 +217,7 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
217217
}
218218

219219
// CHECK: def default_valued_attrs(*, arr=None, unsupported=None, loc=None, ip=None)
220-
// CHECK: return _get_op_result_or_op_results(DefaultValuedAttrsOp(arr=arr, unsupported=unsupported, loc=loc, ip=ip))
220+
// CHECK: return DefaultValuedAttrsOp(arr=arr, unsupported=unsupported, loc=loc, ip=ip)
221221

222222
// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op"
223223
def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> {
@@ -235,7 +235,7 @@ def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResu
235235
}
236236

237237
// CHECK: def derive_result_types_op(type_, *, loc=None, ip=None)
238-
// CHECK: return _get_op_result_or_op_results(DeriveResultTypesOp(type_=type_, loc=loc, ip=ip))
238+
// CHECK: return DeriveResultTypesOp(type_=type_, loc=loc, ip=ip).results
239239

240240
// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op"
241241
def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> {
@@ -262,7 +262,7 @@ def EmptyOp : TestOp<"empty">;
262262
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip))
263263

264264
// CHECK: def empty(*, loc=None, ip=None)
265-
// CHECK: return _get_op_result_or_op_results(EmptyOp(loc=loc, ip=ip))
265+
// CHECK: return EmptyOp(loc=loc, ip=ip)
266266

267267
// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op"
268268
def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
@@ -275,7 +275,7 @@ def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
275275
}
276276

277277
// CHECK: def infer_result_types_implied_op(*, loc=None, ip=None)
278-
// CHECK: return _get_op_result_or_op_results(InferResultTypesImpliedOp(loc=loc, ip=ip))
278+
// CHECK: return InferResultTypesImpliedOp(loc=loc, ip=ip).results
279279

280280
// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_op"
281281
def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> {
@@ -288,7 +288,7 @@ def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]>
288288
}
289289

290290
// CHECK: def infer_result_types_op(*, loc=None, ip=None)
291-
// CHECK: return _get_op_result_or_op_results(InferResultTypesOp(loc=loc, ip=ip))
291+
// CHECK: return InferResultTypesOp(loc=loc, ip=ip).results
292292

293293
// CHECK: @_ods_cext.register_operation(_Dialect)
294294
// CHECK: class MissingNamesOp(_ods_ir.OpView):
@@ -326,7 +326,7 @@ def MissingNamesOp : TestOp<"missing_names"> {
326326
}
327327

328328
// CHECK: def missing_names(i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None)
329-
// CHECK: return _get_op_result_or_op_results(MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip))
329+
// CHECK: return MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip).results
330330

331331
// CHECK: @_ods_cext.register_operation(_Dialect)
332332
// CHECK: class OneOptionalOperandOp(_ods_ir.OpView):
@@ -357,7 +357,7 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
357357
}
358358

359359
// CHECK: def one_optional_operand(non_optional, *, optional=None, loc=None, ip=None)
360-
// CHECK: return _get_op_result_or_op_results(OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip))
360+
// CHECK: return OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip)
361361

362362
// CHECK: @_ods_cext.register_operation(_Dialect)
363363
// CHECK: class OneVariadicOperandOp(_ods_ir.OpView):
@@ -389,7 +389,7 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
389389
}
390390

391391
// CHECK: def one_variadic_operand(non_variadic, variadic, *, loc=None, ip=None)
392-
// CHECK: return _get_op_result_or_op_results(OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip))
392+
// CHECK: return OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip)
393393

394394
// CHECK: @_ods_cext.register_operation(_Dialect)
395395
// CHECK: class OneVariadicResultOp(_ods_ir.OpView):
@@ -446,7 +446,7 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
446446
}
447447

448448
// CHECK: def python_keyword(in_, *, loc=None, ip=None)
449-
// CHECK: return _get_op_result_or_op_results(PythonKeywordOp(in_=in_, loc=loc, ip=ip))
449+
// CHECK: return PythonKeywordOp(in_=in_, loc=loc, ip=ip)
450450

451451
// CHECK-LABEL: OPERATION_NAME = "test.same_results"
452452
def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
@@ -460,7 +460,7 @@ def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
460460
}
461461

462462
// CHECK: def same_results(in1, in2, *, loc=None, ip=None)
463-
// CHECK: return _get_op_result_or_op_results(SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip))
463+
// CHECK: return SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip)
464464

465465
// CHECK-LABEL: OPERATION_NAME = "test.same_results_variadic"
466466
def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResultType]> {
@@ -497,7 +497,7 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
497497
}
498498

499499
// CHECK: def same_variadic_operand(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
500-
// CHECK: return _get_op_result_or_op_results(SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
500+
// CHECK: return SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)
501501

502502
// CHECK: @_ods_cext.register_operation(_Dialect)
503503
// CHECK: class SameVariadicResultSizeOp(_ods_ir.OpView):
@@ -563,7 +563,7 @@ def SimpleOp : TestOp<"simple"> {
563563
}
564564

565565
// CHECK: def simple(i64, f64, i32, f32, *, loc=None, ip=None)
566-
// CHECK: return _get_op_result_or_op_results(SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip))
566+
// CHECK: return SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip).results
567567

568568
// CHECK: class VariadicAndNormalRegionOp(_ods_ir.OpView):
569569
// CHECK-LABEL: OPERATION_NAME = "test.variadic_and_normal_region"
@@ -590,7 +590,7 @@ def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
590590
}
591591

592592
// CHECK: def variadic_and_normal_region(num_variadic, *, loc=None, ip=None)
593-
// CHECK: return _get_op_result_or_op_results(VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))
593+
// CHECK: return VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip)
594594

595595
// CHECK: class VariadicRegionOp(_ods_ir.OpView):
596596
// CHECK-LABEL: OPERATION_NAME = "test.variadic_region"
@@ -613,7 +613,7 @@ def VariadicRegionOp : TestOp<"variadic_region"> {
613613
}
614614

615615
// CHECK: def variadic_region(num_variadic, *, loc=None, ip=None)
616-
// CHECK: return _get_op_result_or_op_results(VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))
616+
// CHECK: return VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip)
617617

618618
// CHECK: @_ods_cext.register_operation(_Dialect)
619619
// CHECK: class WithSpecialCharactersOp(_ods_ir.OpView):
@@ -622,7 +622,7 @@ def WithSpecialCharactersOp : TestOp<"123with--special.characters"> {
622622
}
623623

624624
// CHECK: def _123with__special_characters(*, loc=None, ip=None)
625-
// CHECK: return _get_op_result_or_op_results(WithSpecialCharactersOp(loc=loc, ip=ip))
625+
// CHECK: return WithSpecialCharactersOp(loc=loc, ip=ip)
626626

627627
// CHECK: @_ods_cext.register_operation(_Dialect)
628628
// CHECK: class WithSuccessorsOp(_ods_ir.OpView):
@@ -637,4 +637,4 @@ def WithSuccessorsOp : TestOp<"with_successors"> {
637637
}
638638

639639
// CHECK: def with_successors(successor, successors, *, loc=None, ip=None)
640-
// CHECK: return _get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip))
640+
// CHECK: return WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip)

mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,11 @@ constexpr const char *regionAccessorTemplate = R"Py(
271271
)Py";
272272

273273
constexpr const char *valueBuilderTemplate = R"Py(
274+
def {0}({2}) -> {4}:
275+
return {1}({3}){5}
276+
)Py";
277+
278+
constexpr const char *valueBuilderVariadicTemplate = R"Py(
274279
def {0}({2}) -> {4}:
275280
return _get_op_result_or_op_results({1}({3}))
276281
)Py";
@@ -992,15 +997,29 @@ static void emitValueBuilder(const Operator &op,
992997
auto lhs = *llvm::split(arg, "=").begin();
993998
return (lhs + "=" + llvm::convertToSnakeFromCamelCase(lhs)).str();
994999
});
995-
std::string nameWithoutDialect =
996-
op.getOperationName().substr(op.getOperationName().find('.') + 1);
997-
os << formatv(
998-
valueBuilderTemplate, sanitizeName(nameWithoutDialect),
999-
op.getCppClassName(), llvm::join(valueBuilderParams, ", "),
1000-
llvm::join(opBuilderArgs, ", "),
1000+
std::string nameWithoutDialect = sanitizeName(
1001+
op.getOperationName().substr(op.getOperationName().find('.') + 1));
1002+
std::string params = llvm::join(valueBuilderParams, ", ");
1003+
std::string args = llvm::join(opBuilderArgs, ", ");
1004+
const char *type =
10011005
(op.getNumResults() > 1
10021006
? "_Sequence[_ods_ir.Value]"
1003-
: (op.getNumResults() > 0 ? "_ods_ir.Value" : "_ods_ir.Operation")));
1007+
: (op.getNumResults() > 0 ? "_ods_ir.Value" : "_ods_ir.Operation"));
1008+
if (op.getNumVariableLengthResults() > 0) {
1009+
os << formatv(valueBuilderVariadicTemplate, nameWithoutDialect,
1010+
op.getCppClassName(), params, args, type);
1011+
} else {
1012+
const char *results;
1013+
if (op.getNumResults() == 0) {
1014+
results = "";
1015+
} else if (op.getNumResults() == 1) {
1016+
results = ".result";
1017+
} else {
1018+
results = ".results";
1019+
}
1020+
os << formatv(valueBuilderTemplate, nameWithoutDialect,
1021+
op.getCppClassName(), params, args, type, results);
1022+
}
10041023
}
10051024

10061025
/// Emits bindings for a specific Op to the given output stream.

0 commit comments

Comments
 (0)