Skip to content

[mlir] Fix TileUsingForOp attr-dict printing/parsing #73261

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 6, 2023

Conversation

ubfx
Copy link
Member

@ubfx ubfx commented Nov 23, 2023

Reland #72745
fixed the failing test

Original message:
[mlir] Fix TileUsingForOp attr-dict printing/parsing (#72745)
TileUsingForOp has an optional Attribute interchange which was given
in curly braces like this: {interchange = [...]}. The way this was
parsed meant that no normal attr-dict could be attached to the Op.
This patch adds printing / parsing of an attr-dict to the Op and treats
the interchange Attribute as part of that dictionary for now.

@llvmbot
Copy link
Member

llvmbot commented Nov 23, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Felix Schneider (ubfx)

Changes

Reland #72745
fixed the failing test

Original message:
[mlir] Fix TileUsingForOp attr-dict printing/parsing (#72745)
TileUsingForOp has an optional Attribute interchange which was given
in curly braces like this: {interchange = [...]}. The way this was
parsed meant that no normal attr-dict could be attached to the Op.
This patch adds printing / parsing of an attr-dict to the Op and treats
the interchange Attribute as part of that dictionary for now.


Full diff: https://github.com/llvm/llvm-project/pull/73261.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+1-1)
  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+8-31)
  • (modified) mlir/test/Dialect/Linalg/transform-ops.mlir (+8)
  • (modified) mlir/test/python/dialects/transform_structured_ext.py (+2-2)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index f1c3d717f1fa951..c8f0806e27a6264 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1819,7 +1819,7 @@ def TileUsingForOp : Op<Transform_Dialect, "structured.tile_using_for",
   let arguments = (ins TransformHandleTypeInterface:$target,
                    Variadic<TransformParamTypeOrAnyHandle>:$dynamic_sizes,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sizes,
-                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$interchange,
+                   DefaultValuedOptionalAttr<I64ArrayAttr, "{}">:$interchange,
                    DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:$scalable_sizes);
   let results = (outs TransformHandleTypeInterface:$tiled_linalg_op,
                       Variadic<TransformHandleTypeInterface>:$loops);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index de4965f937162ea..73de3f22d896f0a 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2477,7 +2477,7 @@ void transform::TileUsingForOp::build(
         /*target=*/target,
         /*dynamic_sizes=*/dynamicTileSizes,
         /*static_sizes=*/staticTileSizesAttr,
-        /*interchange=*/builder.getDenseI64ArrayAttr(interchange),
+        /*interchange=*/builder.getI64ArrayAttr(interchange),
         /*scalable_sizes=*/expandedScalableSizes);
 }
 
@@ -2611,7 +2611,8 @@ transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
       });
     }
 
-    tilingOptions.setInterchange(getInterchange());
+    tilingOptions.setInterchange(
+        extractFromIntegerArrayAttr<int64_t>(getInterchange()));
     FailureOr<scf::SCFTilingResult> maybeTilingResult =
         tileUsingSCFForOp(rewriter, tilingInterface, tilingOptions);
     if (failed(maybeTilingResult))
@@ -2648,33 +2649,6 @@ SmallVector<OpFoldResult> transform::TileUsingForOp::getMixedSizes() {
   return results;
 }
 
-// We want to parse `DenseI64ArrayAttr` using the short form without the
-// `array` prefix to be consistent in the IR with `parseDynamicIndexList`.
-ParseResult parseOptionalInterchange(OpAsmParser &parser,
-                                     OperationState &result) {
-  if (succeeded(parser.parseOptionalLBrace())) {
-    if (failed(parser.parseKeyword("interchange")))
-      return parser.emitError(parser.getNameLoc()) << "expect `interchange`";
-    if (failed(parser.parseEqual()))
-      return parser.emitError(parser.getNameLoc()) << "expect `=`";
-    result.addAttribute("interchange",
-                        DenseI64ArrayAttr::parse(parser, Type{}));
-    if (failed(parser.parseRBrace()))
-      return parser.emitError(parser.getNameLoc()) << "expect `}`";
-  }
-  return success();
-}
-
-void printOptionalInterchange(OpAsmPrinter &p,
-                              ArrayRef<int64_t> interchangeVals) {
-  if (!interchangeVals.empty()) {
-    p << " {interchange = [";
-    llvm::interleaveComma(interchangeVals, p,
-                          [&](int64_t integer) { p << integer; });
-    p << "]}";
-  }
-}
-
 ParseResult transform::TileUsingForOp::parse(OpAsmParser &parser,
                                              OperationState &result) {
   OpAsmParser::UnresolvedOperand target;
@@ -2686,7 +2660,7 @@ ParseResult transform::TileUsingForOp::parse(OpAsmParser &parser,
 
   if (parser.parseOperand(target) || parser.getCurrentLocation(&operandLoc) ||
       parseDynamicIndexList(parser, dynamicSizes, staticSizes, scalableVals) ||
-      parseOptionalInterchange(parser, result) ||
+      parser.parseOptionalAttrDict(result.attributes) ||
       parser.parseColonType(functionalType))
     return ParseResult::failure();
 
@@ -2720,7 +2694,10 @@ void TileUsingForOp::print(OpAsmPrinter &p) {
   printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes(),
                         /*valueTypes=*/{}, getScalableSizesAttr(),
                         OpAsmParser::Delimiter::Square);
-  printOptionalInterchange(p, getInterchange());
+  p.printOptionalAttrDict(
+      (*this)->getAttrs(),
+      /*elidedAttrs=*/{getScalableSizesAttrName(getOperation()->getName()),
+                       getStaticSizesAttrName(getOperation()->getName())});
   p << " : ";
   p.printFunctionalType(getOperands().getTypes(), getResults().getTypes());
 }
diff --git a/mlir/test/Dialect/Linalg/transform-ops.mlir b/mlir/test/Dialect/Linalg/transform-ops.mlir
index e9f044be5b4ed22..4d7c514dcca62d5 100644
--- a/mlir/test/Dialect/Linalg/transform-ops.mlir
+++ b/mlir/test/Dialect/Linalg/transform-ops.mlir
@@ -6,6 +6,14 @@ transform.sequence failures(propagate) {
   %0, %1:2 = transform.structured.tile_using_for %arg0 [2, 0, 3] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
 }
 
+// check that the Attributes of `tile_using_for` are preserved through printing
+// and parsing.
+transform.sequence failures(propagate) {
+^bb1(%arg0: !transform.any_op):
+  // CHECK %{{.*}}, %{{.*}}:2 = transform.structured.tile %arg0 [2, 0, 3] {interchange = [2, 1], test_attr1 = 1 : i64, test_attr2}
+  %0, %1:2 = transform.structured.tile_using_for %arg0 [2, 0, 3] {test_attr1 = 1 : i64, interchange = [2, 1], test_attr2}: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+}
+
 transform.sequence failures(propagate) {
 ^bb1(%arg0: !transform.any_op):
   %0:2 = transform.structured.split %arg0 after 42 { dimension = 0 } : !transform.any_op
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index c9b7802e1cc4532..1c7552e2bd51a80 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -344,7 +344,7 @@ def testSplit(target):
 @run
 @create_sequence
 def testTileCompact(target):
-    structured.TileUsingForOp(target, sizes=[4, 8], interchange=[0, 1])
+    structured.TileUsingForOp(target, sizes=[4, 8], interchange=Attribute.parse("[0, 1]"))
     # CHECK-LABEL: TEST: testTileCompact
     # CHECK: transform.sequence
     # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile_using_for %{{.*}}[4, 8]
@@ -355,7 +355,7 @@ def testTileCompact(target):
 @create_sequence
 def testTileAttributes(target):
     attr = DenseI64ArrayAttr.get([4, 8])
-    ichange = DenseI64ArrayAttr.get([0, 1])
+    ichange = Attribute.parse("[0, 1]")
     structured.TileUsingForOp(target, sizes=attr, interchange=ichange)
     # CHECK-LABEL: TEST: testTileAttributes
     # CHECK: transform.sequence

Copy link

github-actions bot commented Nov 23, 2023

✅ With the latest revision this PR passed the Python code formatter.

@ubfx ubfx requested a review from makslevental November 24, 2023 13:37
…bly format

`TileUsingForOp` has an optional Attribute `interchange` which was
given in curly braces like this: `{ interchange = [...] }`.
The way this was parsed meant that no normal `attr-dict` could be
attached to the Op.
This patch changes the assembly format of the op to represent the
`interchange` Attribute more like other array Attributes in the
transform Ops and adds printing/parsing of an optional attr-dict.

`transform.structured.tile_using_for %0 [5, 6] interchange [1, 0]`
@ubfx ubfx requested a review from joker-eph December 4, 2023 21:09
@ubfx ubfx changed the title Reland "[mlir] Fix TileUsingForOp attr-dict printing/parsing" [mlir] Fix TileUsingForOp attr-dict printing/parsing Dec 5, 2023
@ubfx ubfx merged commit e07c92a into llvm:main Dec 6, 2023
@ubfx ubfx deleted the forall-assembly branch December 6, 2023 19:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants