[MLIR] Don't drop attached discardable attributes #111261

Merged Oct 7, 2024 (1 commit)

@pashu123 (Member) commented Oct 5, 2024

The canonicalization pattern for tensor.pack rebuilt the op from scratch, which dropped any discardable attributes attached to the original op.

@llvmbot (Member) commented Oct 5, 2024

@llvm/pr-subscribers-mlir-tensor

Author: Prashant Kumar (pashu123)

Changes

The creation of pack op was dropping custom attached attributes.


Full diff: https://github.com/llvm/llvm-project/pull/111261.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+9-4)
  • (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+12-1)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index defac8308b9092..659eabd2e93880 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4337,11 +4337,16 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
       dest =
           rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
     }
-    Value newOp = rewriter.create<tensor::PackOp>(
-        loc, source, dest, packOp.getInnerDimsPos(), packOp.getMixedTiles(),
-        packOp.getPaddingValue(), packOp.getOuterDimsPerm());
+    auto clonedPackOp = cast<PackOp>(rewriter.clone(*packOp));
+    Value res = clonedPackOp.getResult();
+    rewriter.startOpModification(clonedPackOp);
+    clonedPackOp.getSourceMutable().assign(source);
+    clonedPackOp.getDestMutable().assign(dest);
+    res.setType(dest.getType());
+    rewriter.finalizeOpModification(clonedPackOp);
+
     rewriter.replaceOpWithNewOp<tensor::CastOp>(
-        packOp, packOp.getResult().getType(), newOp);
+        packOp, packOp.getResult().getType(), clonedPackOp);
     return success();
   }
 
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 86754c1c37536d..03ff45380dca9b 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2357,7 +2357,7 @@ func.func @unpack_pack_with_padding_no_canonicalization(%t: tensor<256x512xbf16>
   %tensor_empty = tensor.empty() : tensor<4x16x64x32xbf16>
   %tensor_empty1 = tensor.empty() : tensor<224x512xbf16>
   %packed = tensor.pack %t outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 32] into %tensor_empty : tensor<256x512xbf16> -> tensor<4x16x64x32xbf16>
-  %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [64, 32] into %tensor_empty1 : tensor<4x16x64x32xbf16> -> tensor<224x512xbf16> 
+  %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [64, 32] into %tensor_empty1 : tensor<4x16x64x32xbf16> -> tensor<224x512xbf16>
   return %unpacked : tensor<224x512xbf16>
 }
 
@@ -2707,3 +2707,14 @@ func.func @test_destination_multiple_result(%arg0: tensor<2x2xf32>, %arg1: tenso
   %0:2 = test.destination_style_op ins(%cast : tensor<?x2xf32>) outs(%cast_0 : tensor<?x2xf32>) -> tensor<?x2xf32>, index
   return %0#1 : index
 }
+
+// -----
+
+// CHECK-LABEL:   func.func @pack_dont_drop_attributes(
+// CHECK: tensor.pack {{.*}}  {test_attr = 16 : i64}
+func.func @pack_dont_drop_attributes(%arg0: tensor<?x?x?xf16>, %arg1: tensor<128x?x100x16x1xf16>) -> tensor<128x?x100x16x1xf16> {
+  %c32_i64 = arith.constant 32 : i64
+  %cst = arith.constant 0.000000e+00 : f16
+  %pack = tensor.pack %arg0 padding_value(%cst : f16) outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 1] into %arg1 {test_attr = 16 : i64} : tensor<?x?x?xf16> -> tensor<128x?x100x16x1xf16>
+  return %pack : tensor<128x?x100x16x1xf16>
+}
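
To illustrate why the old code lost attributes, here is a minimal sketch in plain C++. It is a toy model, not the MLIR API: `ToyOp`, `rebuildFromAccessors`, and `cloneAndPatch` are hypothetical names made up for this example. It assumes only that an op carries some attributes the builder knows about and some that were attached from outside.

```cpp
#include <cassert>
#include <map>
#include <string>

// Toy stand-in for an operation. This is NOT the MLIR API.
struct ToyOp {
  std::string source;                       // first operand
  std::string dest;                         // second operand
  std::map<std::string, long> inherent;     // attributes defined by the op itself
  std::map<std::string, long> discardable;  // extra attributes attached by users
};

// Old approach: build a fresh op from the fields the builder takes.
// Anything that is not passed explicitly (the discardable attributes) is lost.
ToyOp rebuildFromAccessors(const ToyOp &op, const std::string &newSource,
                           const std::string &newDest) {
  ToyOp fresh;
  fresh.source = newSource;
  fresh.dest = newDest;
  fresh.inherent = op.inherent;
  return fresh;
}

// New approach: copy the whole op, then overwrite only what changed.
ToyOp cloneAndPatch(const ToyOp &op, const std::string &newSource,
                    const std::string &newDest) {
  ToyOp cloned = op;
  cloned.source = newSource;
  cloned.dest = newDest;
  // The copy still carries everything that was attached to the original.
  assert(cloned.discardable == op.discardable);
  return cloned;
}
```

In this model, an op carrying `test_attr` keeps it through `cloneAndPatch` and loses it through `rebuildFromAccessors`, which mirrors what the new test case checks for.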

@llvmbot (Member) commented Oct 5, 2024

@llvm/pr-subscribers-mlir

@joker-eph (Collaborator)

Nit: these aren't "custom" attributes, but "discardable" attributes, please update title and description.

Also, it's not clear to me why this is correct in the full generality: how do you know that the transformation isn't invalidating some of these attributes potentially?

@pashu123 pashu123 changed the title [MLIR] Don't drop attached custom attributes [MLIR] Don't drop attached discardable attributes Oct 7, 2024
@pashu123 (Member, Author) commented Oct 7, 2024

> Nit: these aren't "custom" attributes, but "discardable" attributes, please update title and description.
>
> Also, it's not clear to me why this is correct in the full generality: how do you know that the transformation isn't invalidating some of these attributes potentially?

I've updated the title. Thanks for the suggestion. This is not correct in full generality. For this patch, the only concern is canonicalize dropping the discardable attrs, which shouldn't be the case.

@hanhanW (Contributor) commented Oct 7, 2024

> Also, it's not clear to me why this is correct in the full generality: how do you know that the transformation isn't invalidating some of these attributes potentially?

> This is not correct in full generality. For this patch, the only concern is canonicalize dropping the discardable attrs, which shouldn't be the case.

+1, it is a canonicalization pattern which infers static shapes when possible. The transformations are not applied yet, so we don't drop the attributes.

This is also what happens in Linalg transformations. Some older patterns look for linalg ops that carry the __linalg_transformation__ attribute and use it as a matcher. We also have static shape inference in the LinalgOp canonicalization patterns, which do not drop the discardable attributes.

@pashu123 pashu123 merged commit 971b579 into llvm:main Oct 7, 2024
9 checks passed
-    Value newOp = rewriter.create<tensor::PackOp>(
-        loc, source, dest, packOp.getInnerDimsPos(), packOp.getMixedTiles(),
-        packOp.getPaddingValue(), packOp.getOuterDimsPerm());
+    auto clonedPackOp = cast<PackOp>(rewriter.clone(*packOp));
Contributor

Sorry, I didn't get to review this before it landed, but you could use

auto discardableAttributes = getPrunedAttributeList(packOp, PackOp::getAttributesList());
clonedPackOp.setAttrs(discardableAttributes);

Member Author

Thanks for the suggestion. Looks clean!

Collaborator

The Operation::getDiscardableAttrDictionary() method seems like a more direct implementation than getPrunedAttributeList() (why isn't this one living in mlir/IR by the way??)

That said, I didn't notice in the review that you're cloning the op here: why is that? Why aren't you just modifying it in place? Since you were doing rewriter.startOpModification I was assuming that this is what is happening.
When we can avoid recreating an operation and destroying the original, it's just more efficient.
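
As a rough illustration of what both helpers named above are after, the sketch below filters an attribute list down to the entries the op does not define itself. It is a toy in plain C++ with hypothetical names (`ToyAttr`, `pruneInherent`); it does not reproduce the signatures of `getPrunedAttributeList` or `Operation::getDiscardableAttrDictionary`, and the attribute names in the usage note are placeholders.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Toy attribute: a name and an integer payload. This is NOT the MLIR API.
using ToyAttr = std::pair<std::string, long>;

// Keep every attribute whose name is not listed in `inherentNames`.
// What remains plays the role of the "discardable" attributes.
std::vector<ToyAttr> pruneInherent(const std::vector<ToyAttr> &all,
                                   const std::vector<std::string> &inherentNames) {
  std::vector<ToyAttr> kept;
  for (const ToyAttr &attr : all) {
    bool isInherent = std::find(inherentNames.begin(), inherentNames.end(),
                                attr.first) != inherentNames.end();
    if (!isInherent)
      kept.push_back(attr);
  }
  // Filtering can only shrink the list.
  assert(kept.size() <= all.size());
  return kept;
}
```

Given a list holding two op-defined attributes plus `test_attr`, and the two op-defined names as `inherentNames`, only `test_attr` survives. A direct accessor for the discardable set, as suggested above, would make this filtering step unnecessary.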

Member Author

Could you suggest how to modify it inplace without cloning in this case?

Collaborator

@joker-eph Oct 8, 2024

Shortest code should be:

  rewriter.modifyOpInPlace(packOp, [&] {
    packOp.getSourceMutable().assign(source);
    packOp.getDestMutable().assign(dest);
    packOp.getResult().setType(dest.getType());
  });

Member Author

The problem was adding tensor.cast after the modification.

rewriter.replaceOpWithNewOp<tensor::CastOp>(
I can't replace it with tensor.pack.

Contributor

I think it is not trivial to do in-place modification. There are two situations:

  1. The dest tensor type is the same. In this case, we do not need a tensor.cast consumer.
  2. The dest tensor type is changed. In this case, we need to create the tensor.cast op which makes the types consistent.

In the (1) situation, we can do in-place modification -- which is very simple.

In the (2) situation, it is not trivial because you need to replace the original op with the new tensor.cast op. If we do in-place modification, I don't see a trivial way to replace the op. Perhaps we can replace the uses of the tensor.pack ops with the new tensor.cast op, when it is the case.

IMO, cloning an op is cheap in this case. Instead of adding complexity to the logic, I'm +1 on the cloning approach.

Note that this is also what we're doing for LinalgOps and it's been there for a couple years. I'm not saying that this is the correct way, but it's more like providing data points.

// Clone op.
Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
SmallVector<Value> replacements;
replacements.reserve(newOp->getNumResults());
for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
  Value newResult = std::get<1>(it);
  Value oldResult = std::get<0>(it);
  Type newType = newResult.getType();
  Type oldType = oldResult.getType();
  replacements.push_back(
      (newType != oldType)
          ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
          : newResult);
}
rewriter.replaceOp(linalgOp, replacements);
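
The use-replacement idea floated for situation (2) can be sketched on a toy graph: update the op in place, add a cast back to the old type, and point every other user at the cast. This is plain C++ with hypothetical names (`ToyNode`, `ToyGraph`, `refineInPlace`), not MLIR code, and it is not a description of what any follow-up patch does. It also ignores insertion order, which a real rewrite would have to respect.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Toy IR node: a name, a result type, and the indices of the nodes it reads.
// This is NOT the MLIR API.
struct ToyNode {
  std::string name;
  std::string resultType;
  std::vector<int> operands;  // indices into the graph
};

using ToyGraph = std::vector<ToyNode>;

// Refine the result type of node `target` in place, append a cast back to the
// old type, and redirect every other user of `target` to that cast.
// Returns the index of the new cast node.
int refineInPlace(ToyGraph &graph, int target, const std::string &newType) {
  assert(target >= 0 && target < static_cast<int>(graph.size()));
  const std::string oldType = graph[target].resultType;
  graph[target].resultType = newType;  // in-place update, nothing else touched

  const int castIndex = static_cast<int>(graph.size());
  graph.push_back({"cast", oldType, {target}});

  // Replace all uses of `target` except the one inside the cast itself.
  for (int i = 0; i < castIndex; ++i)
    for (int &operand : graph[i].operands)
      if (operand == target)
        operand = castIndex;
  return castIndex;
}
```

Because the original node is never destroyed, anything attached to it stays attached without an explicit copy step; the cost is the extra bookkeeping to skip the cast when rerouting uses.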

Collaborator

@joker-eph Oct 8, 2024

> There are two situations:
>
>   • The dest tensor type is the same. In this case, we do not need a tensor.cast consumer.
>   • The dest tensor type is changed. In this case, we need to create the tensor.cast op which makes the types consistent.

I believe the current code already does not try to differentiate between these: it always creates the cast, which is folded later if the types match.

> I think it is not trivial to do in-place modification.

I don't understand the complexity you're foreseeing actually?
I sent a PR implementing it: #111593
