-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir] Fix TileUsingForOp
attr-dict printing/parsing, cleanup assembly format
#72745
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…bly format `TileUsingForOp` has an optional Attribute `interchange` which was given in curly braces like this: `{ interchange = [...] }`. The way this was parsed meant that no normal `attr-dict` could be attached to the Op. This patch changes the assembly format of the op to represent the `interchange` Attribute more like other array Attributes in the transform Ops and adds printing/parsing of an optional attr-dict. `transform.structured.tile_using_for %0 [5, 6] interchange [1, 0]`
@llvm/pr-subscribers-mlir-linalg Author: Felix Schneider (ubfx) Changes
Full diff: https://github.com/llvm/llvm-project/pull/72745.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index de4965f937162ea..943621dcfeb1739 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2652,26 +2652,20 @@ SmallVector<OpFoldResult> transform::TileUsingForOp::getMixedSizes() {
// `array` prefix to be consistent in the IR with `parseDynamicIndexList`.
ParseResult parseOptionalInterchange(OpAsmParser &parser,
OperationState &result) {
- if (succeeded(parser.parseOptionalLBrace())) {
- if (failed(parser.parseKeyword("interchange")))
- return parser.emitError(parser.getNameLoc()) << "expect `interchange`";
- if (failed(parser.parseEqual()))
- return parser.emitError(parser.getNameLoc()) << "expect `=`";
- result.addAttribute("interchange",
- DenseI64ArrayAttr::parse(parser, Type{}));
- if (failed(parser.parseRBrace()))
- return parser.emitError(parser.getNameLoc()) << "expect `}`";
- }
+ if (succeeded(parser.parseOptionalKeyword("interchange")))
+ result.addAttribute(
+ transform::TileUsingForOp::getInterchangeAttrName(result.name),
+ DenseI64ArrayAttr::parse(parser, Type{}));
return success();
}
void printOptionalInterchange(OpAsmPrinter &p,
ArrayRef<int64_t> interchangeVals) {
if (!interchangeVals.empty()) {
- p << " {interchange = [";
+ p << " interchange [";
llvm::interleaveComma(interchangeVals, p,
[&](int64_t integer) { p << integer; });
- p << "]}";
+ p << "]";
}
}
@@ -2687,6 +2681,7 @@ ParseResult transform::TileUsingForOp::parse(OpAsmParser &parser,
if (parser.parseOperand(target) || parser.getCurrentLocation(&operandLoc) ||
parseDynamicIndexList(parser, dynamicSizes, staticSizes, scalableVals) ||
parseOptionalInterchange(parser, result) ||
+ parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(functionalType))
return ParseResult::failure();
@@ -2721,6 +2716,11 @@ void TileUsingForOp::print(OpAsmPrinter &p) {
/*valueTypes=*/{}, getScalableSizesAttr(),
OpAsmParser::Delimiter::Square);
printOptionalInterchange(p, getInterchange());
+ p.printOptionalAttrDict(
+ (*this)->getAttrs(),
+ /*elidedAttrs=*/{getInterchangeAttrName(getOperation()->getName()),
+ getScalableSizesAttrName(getOperation()->getName()),
+ getStaticSizesAttrName(getOperation()->getName())});
p << " : ";
p.printFunctionalType(getOperands().getTypes(), getResults().getTypes());
}
diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
index 77ce4d0b211f0d7..1c11842fc9b8c30 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -170,7 +170,7 @@ func.func @matvec_perm(%A: memref<?x?xf32, strided<[?, 1], offset: ?>>,
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.matvec"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1, %loops:2 = transform.structured.tile_using_for %0 [5, 6] {interchange = [1, 0]} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ %1, %loops:2 = transform.structured.tile_using_for %0 [5, 6] interchange [1, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
}
@@ -199,8 +199,8 @@ func.func @matmul_perm(%A: memref<?x?xf32, strided<[?, 1], offset: ?>>,
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1, %loops:3 = transform.structured.tile_using_for %0 [2000, 3000, 4000] {interchange = [1, 2, 0]} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
- %2, %loops_2:3 = transform.structured.tile_using_for %1 [200, 300, 400] {interchange = [1, 0, 2]} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+ %1, %loops:3 = transform.structured.tile_using_for %0 [2000, 3000, 4000] interchange [1, 2, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+ %2, %loops_2:3 = transform.structured.tile_using_for %1 [200, 300, 400] interchange [1, 0, 2] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
%3, %loops_3:3 = transform.structured.tile_using_for %2 [20, 30, 40] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
|
@llvm/pr-subscribers-mlir Author: Felix Schneider (ubfx) Changes
Full diff: https://github.com/llvm/llvm-project/pull/72745.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index de4965f937162ea..943621dcfeb1739 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2652,26 +2652,20 @@ SmallVector<OpFoldResult> transform::TileUsingForOp::getMixedSizes() {
// `array` prefix to be consistent in the IR with `parseDynamicIndexList`.
ParseResult parseOptionalInterchange(OpAsmParser &parser,
OperationState &result) {
- if (succeeded(parser.parseOptionalLBrace())) {
- if (failed(parser.parseKeyword("interchange")))
- return parser.emitError(parser.getNameLoc()) << "expect `interchange`";
- if (failed(parser.parseEqual()))
- return parser.emitError(parser.getNameLoc()) << "expect `=`";
- result.addAttribute("interchange",
- DenseI64ArrayAttr::parse(parser, Type{}));
- if (failed(parser.parseRBrace()))
- return parser.emitError(parser.getNameLoc()) << "expect `}`";
- }
+ if (succeeded(parser.parseOptionalKeyword("interchange")))
+ result.addAttribute(
+ transform::TileUsingForOp::getInterchangeAttrName(result.name),
+ DenseI64ArrayAttr::parse(parser, Type{}));
return success();
}
void printOptionalInterchange(OpAsmPrinter &p,
ArrayRef<int64_t> interchangeVals) {
if (!interchangeVals.empty()) {
- p << " {interchange = [";
+ p << " interchange [";
llvm::interleaveComma(interchangeVals, p,
[&](int64_t integer) { p << integer; });
- p << "]}";
+ p << "]";
}
}
@@ -2687,6 +2681,7 @@ ParseResult transform::TileUsingForOp::parse(OpAsmParser &parser,
if (parser.parseOperand(target) || parser.getCurrentLocation(&operandLoc) ||
parseDynamicIndexList(parser, dynamicSizes, staticSizes, scalableVals) ||
parseOptionalInterchange(parser, result) ||
+ parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(functionalType))
return ParseResult::failure();
@@ -2721,6 +2716,11 @@ void TileUsingForOp::print(OpAsmPrinter &p) {
/*valueTypes=*/{}, getScalableSizesAttr(),
OpAsmParser::Delimiter::Square);
printOptionalInterchange(p, getInterchange());
+ p.printOptionalAttrDict(
+ (*this)->getAttrs(),
+ /*elidedAttrs=*/{getInterchangeAttrName(getOperation()->getName()),
+ getScalableSizesAttrName(getOperation()->getName()),
+ getStaticSizesAttrName(getOperation()->getName())});
p << " : ";
p.printFunctionalType(getOperands().getTypes(), getResults().getTypes());
}
diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
index 77ce4d0b211f0d7..1c11842fc9b8c30 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -170,7 +170,7 @@ func.func @matvec_perm(%A: memref<?x?xf32, strided<[?, 1], offset: ?>>,
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.matvec"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1, %loops:2 = transform.structured.tile_using_for %0 [5, 6] {interchange = [1, 0]} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ %1, %loops:2 = transform.structured.tile_using_for %0 [5, 6] interchange [1, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
}
@@ -199,8 +199,8 @@ func.func @matmul_perm(%A: memref<?x?xf32, strided<[?, 1], offset: ?>>,
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1, %loops:3 = transform.structured.tile_using_for %0 [2000, 3000, 4000] {interchange = [1, 2, 0]} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
- %2, %loops_2:3 = transform.structured.tile_using_for %1 [200, 300, 400] {interchange = [1, 0, 2]} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+ %1, %loops:3 = transform.structured.tile_using_for %0 [2000, 3000, 4000] interchange [1, 2, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+ %2, %loops_2:3 = transform.structured.tile_using_for %1 [200, 300, 400] interchange [1, 0, 2] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
%3, %loops_3:3 = transform.structured.tile_using_for %2 [20, 30, 40] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about just dropping special handling for interchange
and having it parsed as part of regular attribute dict? This will not change the existing syntax and fix the op.
Done, I wasn't sure about this one because it seems like the majority of the other transform Ops tend towards special printing of the inherent Attributes. Is there a rule for when to prefer one or the other? |
Yes: please consider parsing inherent attribute as part of the discardable attributes dictionary deprecated. |
OK, but in this case we can keep the old assembly format or should we add the special printing/parsing back? |
I suppose we need new assembly, but it should be fine landing this first (this is a bugfix) and working on a better assembly format that is consistent across all operations in the dialect in a separate stage. |
TileUsingForOp
has an optional Attributeinterchange
which was given in curly braces like this:{interchange = [...]}
. The way this was parsed meant that no normalattr-dict
could be attached to the Op.This patch changes the assembly format of the op to represent the
interchange
Attribute more like other array Attributes in the transform Ops and adds printing/parsing of an optional attr-dict.transform.structured.tile_using_for %0 [5, 6] interchange [1, 0]