Skip to content

Commit e7bd104

Browse files
Factor out transform dialect attribute name construction.
1 parent 4ad5c20 commit e7bd104

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,14 @@ def Transform_Dialect : Dialect {
4343
constexpr const static ::llvm::StringLiteral kArgReadOnlyAttrName =
4444
"transform.readonly";
4545

46+
/// Above attribute names as `StringAttr`.
47+
StringAttr getConsumedAttrName() const {
48+
return StringAttr::get(getContext(), kArgConsumedAttrName);
49+
}
50+
StringAttr getReadOnlyAttrName() const {
51+
return StringAttr::get(getContext(), kArgReadOnlyAttrName);
52+
}
53+
4654
template <typename DataTy>
4755
const DataTy &getExtraData() const {
4856
return *static_cast<const DataTy *>(

mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -317,12 +317,6 @@ LogicalResult mergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
317317
assert(func1->getParentOp() == func2->getParentOp() &&
318318
"expected func1 and func2 to be in the same parent op");
319319

320-
MLIRContext *context = func1->getContext();
321-
auto consumedName = StringAttr::get(
322-
context, transform::TransformDialect::kArgConsumedAttrName);
323-
auto readOnlyName = StringAttr::get(
324-
context, transform::TransformDialect::kArgReadOnlyAttrName);
325-
326320
// Check that function signatures match.
327321
if (func1.getFunctionType() != func2.getFunctionType()) {
328322
return func1.emitError()
@@ -331,6 +325,10 @@ LogicalResult mergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
331325
}
332326

333327
// Check and merge argument attributes.
328+
MLIRContext *context = func1->getContext();
329+
auto td = context->getLoadedDialect<transform::TransformDialect>();
330+
StringAttr consumedName = td->getConsumedAttrName();
331+
StringAttr readOnlyName = td->getReadOnlyAttrName();
334332
for (unsigned i = 0, e = func1.getNumArguments(); i < e; ++i) {
335333
bool isExternalConsumed = func2.getArgAttr(i, consumedName) != nullptr;
336334
bool isExternalReadonly = func2.getArgAttr(i, readOnlyName) != nullptr;

0 commit comments

Comments
 (0)