-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[mlir][python] fix value builders #68764
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,3 +33,16 @@ def testFastMathFlags(): | |
) | ||
# CHECK: %0 = arith.addf %cst, %cst fastmath<nnan,ninf> : f32 | ||
print(r) | ||
|
||
|
||
# CHECK-LABEL: TEST: testArithValueBuilder | ||
@run | ||
def testArithValueBuilder(): | ||
with Context() as ctx, Location.unknown(): | ||
module = Module.create() | ||
f32_t = F32Type.get() | ||
|
||
with InsertionPoint(module.body): | ||
a = arith.constant(value=FloatAttr.get(f32_t, 42.42)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
# CHECK: %cst = arith.constant 4.242000e+01 : f32 | ||
print(a) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -39,7 +39,7 @@ except ImportError: | |
_ods_ext_module = None | ||
|
||
import builtins | ||
from typing import Sequence as _Sequence, Union as _Union | ||
from typing import Sequence as _Sequence | ||
|
||
)Py"; | ||
|
||
|
@@ -269,7 +269,14 @@ constexpr const char *regionAccessorTemplate = R"Py( | |
|
||
constexpr const char *valueBuilderTemplate = R"Py( | ||
def {0}({2}) -> {4}: | ||
return _get_op_result_or_op_results({1}({3})) | ||
op = {1}.__base__ if getattr({1}, "__has_mixin__", False) else {1} | ||
return _get_op_result_or_op_results(op({3})) | ||
)Py"; | ||
|
||
constexpr const char *valueBuilderNoResultsTemplate = R"Py( | ||
def {0}({2}) -> {4}: | ||
op = {1}.__base__ if getattr({1}, "__has_mixin__", False) else {1} | ||
return op({3}) | ||
)Py"; | ||
|
||
static llvm::cl::OptionCategory | ||
|
@@ -1009,7 +1016,8 @@ static void emitValueBuilder(const Operator &op, | |
return (lhs + "=" + llvm::convertToSnakeFromCamelCase(lhs)).str(); | ||
}); | ||
os << llvm::formatv( | ||
valueBuilderTemplate, | ||
op.getNumResults() > 0 ? valueBuilderTemplate | ||
: valueBuilderNoResultsTemplate, | ||
Comment on lines
+1019
to
+1020
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should the |
||
// Drop dialect name and then sanitize again (to catch e.g. func.return). | ||
sanitizeName(llvm::join(++splitName.begin(), splitName.end(), "_")), | ||
op.getCppClassName(), llvm::join(valueBuilderParams, ", "), | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
how about something like
LocalOpView.__ods_base__ = parent_opview_cls
+ generatinggetattr(OpViewCls, "__ods_base__", OpViewCls)
? then you don't have to use__base__
(which I had to look up, and apparently is slightly different from__bases__[0]
)