[MLIR] Don't drop attached discardable attributes #111261

Merged
merged 1 commit into from
Oct 7, 2024
13 changes: 9 additions & 4 deletions mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4337,11 +4337,16 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
     dest =
         rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
   }
-  Value newOp = rewriter.create<tensor::PackOp>(
-      loc, source, dest, packOp.getInnerDimsPos(), packOp.getMixedTiles(),
-      packOp.getPaddingValue(), packOp.getOuterDimsPerm());
+  auto clonedPackOp = cast<PackOp>(rewriter.clone(*packOp));
Contributor

Sorry, I didn't get to review this before it landed, but you could use

auto discardableAttributes = getPrunedAttributeList(packOp, PackOp::getAttributesList());
clonedPackOp.setAttrs(discardableAttributes);
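
To make the idea behind the suggestion concrete, here is a minimal toy sketch of "pruning" an attribute list: everything whose name the op itself defines is treated as inherent, and whatever remains is discardable and can be re-attached to a rebuilt op. `AttrDict`, `pruneInherent`, and `rebuildKeepingDiscardable` are hypothetical names for this illustration and are not MLIR APIs; attributes are modeled as plain strings.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>

// Toy stand-in for an operation's attribute dictionary (not an MLIR type).
using AttrDict = std::map<std::string, std::string>;

// Hypothetical helper: keep only the attributes whose names are NOT in
// `inherentNames`, i.e. the ones the op definition does not know about.
AttrDict pruneInherent(const AttrDict &all,
                       const std::set<std::string> &inherentNames) {
  AttrDict discardable;
  for (const auto &kv : all)
    if (inherentNames.count(kv.first) == 0)
      discardable.insert(kv);
  return discardable;
}

// Hypothetical helper: an op rebuilt from its inherent parameters alone
// starts with `newInherent` only; re-attaching the pruned list from the old
// op restores the discardable attributes that would otherwise be dropped.
AttrDict rebuildKeepingDiscardable(const AttrDict &oldAttrs,
                                   const AttrDict &newInherent,
                                   const std::set<std::string> &inherentNames) {
  AttrDict result = newInherent;
  for (const auto &kv : pruneInherent(oldAttrs, inherentNames))
    result.insert(kv);
  return result;
}
```

In this model, a `test_attr` entry on the old op survives the rebuild because its name is not in the inherent set, which is the behavior the new `pack_dont_drop_attributes` test checks for.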

Member Author

Thanks for the suggestion. Looks clean!

Collaborator

The Operation::getDiscardableAttrDictionary() method seems like a more direct implementation than getPrunedAttributeList() (why isn't this one living in mlir/IR by the way??)

That said, I didn't notice in the review that you're cloning the op here: why is that? Why aren't you just modifying it in place? Since you were doing rewriter.startOpModification I was assuming that this is what is happening.
When we can avoid recreating an operation and destroying the original, it's just more efficient.

Member Author

Could you suggest how to modify it in place, without cloning, in this case?

Collaborator (@joker-eph, Oct 8, 2024)

Shortest code should be:

rewriter.modifyOpInPlace(packOp, [&] {
  packOp.getSourceMutable().assign(source);
  packOp.getDestMutable().assign(dest);
  packOp.getResult().setType(dest.getType());
});

Member Author

The problem was adding the tensor.cast after the modification. The pattern ends with a call that replaces the original op:

rewriter.replaceOpWithNewOp<tensor::CastOp>(

and I can't replace it with tensor.pack.

Contributor

I think it is not trivial to do in-place modification. There are two situations:

  1. The dest tensor type is the same. In this case, we do not need a tensor.cast consumer.
  2. The dest tensor type is changed. In this case, we need to create the tensor.cast op which makes the types consistent.

In the (1) situation, we can do in-place modification -- which is very simple.

In the (2) situation, it is not trivial because you need to replace the original op with the new tensor.cast op. If we do in-place modification, I don't see a trivial way to replace the op. Perhaps, in that case, we can replace the uses of the tensor.pack op with the new tensor.cast op.

IMO, cloning an op is cheap in this case. Instead of adding complexity to the logic, I'm +1 on the approach of cloning the op.

Note that this is also what we're doing for LinalgOps and it's been there for a couple years. I'm not saying that this is the correct way, but it's more like providing data points.

// Clone op.
Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
SmallVector<Value> replacements;
replacements.reserve(newOp->getNumResults());
for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
  Value newResult = std::get<1>(it);
  Value oldResult = std::get<0>(it);
  Type newType = newResult.getType();
  Type oldType = oldResult.getType();
  replacements.push_back(
      (newType != oldType)
          ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
          : newResult);
}
rewriter.replaceOp(linalgOp, replacements);
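
The core of the snippet above (reuse a new result when its type is unchanged, otherwise cast it back to the old type) can be exercised without MLIR. The following is a toy sketch under that assumption: `ToyValue`, `makeCast`, and `buildReplacements` are hypothetical names, and types are modeled as plain strings rather than `mlir::Type`.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Toy value: a name plus a type string (not mlir::Value).
struct ToyValue {
  std::string name;
  std::string type;
};

// Hypothetical stand-in for creating a tensor.cast of `v` to `targetType`.
ToyValue makeCast(const ToyValue &v, const std::string &targetType) {
  return ToyValue{"cast(" + v.name + ")", targetType};
}

// For each (old, new) result pair, reuse the new result when the types agree
// and wrap it in a cast back to the old type when they differ.
std::vector<ToyValue> buildReplacements(const std::vector<ToyValue> &oldResults,
                                        const std::vector<ToyValue> &newResults) {
  assert(oldResults.size() == newResults.size());
  std::vector<ToyValue> replacements;
  replacements.reserve(newResults.size());
  for (std::size_t i = 0; i < newResults.size(); ++i) {
    const ToyValue &oldR = oldResults[i];
    const ToyValue &newR = newResults[i];
    replacements.push_back(newR.type != oldR.type ? makeCast(newR, oldR.type)
                                                  : newR);
  }
  return replacements;
}
```

Users of the old op only ever see values of the old types, so nothing downstream has to change when the cloned op gets more static result types.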

Collaborator (@joker-eph, Oct 8, 2024)

There are two situations:

  • The dest tensor type is the same. In this case, we do not need a tensor.cast consumer.
  • The dest tensor type is changed. In this case, we need to create the tensor.cast op which makes the types consistent.

I believe the current code already does not try to differentiate between these: it always creates the cast, which is folded later if the types match.

I think it is not trivial to do in-place modification.

I don't understand the complexity you're foreseeing actually?
I sent a PR implementing it: #111593
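
For readers following the thread, here is one way the in-place route could handle the changed-type case, using the idea floated earlier of redirecting the uses of tensor.pack to a new tensor.cast. This is a toy model and not the code from #111593: `ToyOp`, `replaceAllUsesExcept`, and `retypeInPlace` are hypothetical names defined here, with ops as plain structs and types as strings.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Toy operation with a single result (not mlir::Operation).
struct ToyOp {
  std::string name;
  std::string resultType;
  std::vector<ToyOp *> operands; // ops whose result this op consumes
};

// Redirect every use of `from` to `to`, skipping the uses held by `except`.
void replaceAllUsesExcept(std::vector<ToyOp *> &ops, ToyOp *from, ToyOp *to,
                          ToyOp *except) {
  for (ToyOp *op : ops) {
    if (op == except)
      continue;
    for (ToyOp *&operand : op->operands)
      if (operand == from)
        operand = to;
  }
}

// Change the result type of `op` in place, then place a cast back to the
// original type behind it and point all previous users at that cast.
// `cast` is caller-provided storage for the new op.
void retypeInPlace(std::vector<ToyOp *> &ops, ToyOp &op,
                   const std::string &newType, ToyOp &cast) {
  const std::string oldType = op.resultType;
  op.resultType = newType; // the op itself is kept, not recreated
  cast.name = "cast";
  cast.resultType = oldType;
  cast.operands = {&op};
  ops.push_back(&cast);
  // The cast must keep using `op`, so it is excluded from the rewrite.
  replaceAllUsesExcept(ops, &op, &cast, &cast);
}
```

Because the op is never recreated in this model, anything attached to it stays attached, and the exclusion of the cast's own operand is what keeps the cast from ending up consuming itself.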

+  Value res = clonedPackOp.getResult();
+  rewriter.startOpModification(clonedPackOp);
+  clonedPackOp.getSourceMutable().assign(source);
+  clonedPackOp.getDestMutable().assign(dest);
+  res.setType(dest.getType());
+  rewriter.finalizeOpModification(clonedPackOp);
+
   rewriter.replaceOpWithNewOp<tensor::CastOp>(
-      packOp, packOp.getResult().getType(), newOp);
+      packOp, packOp.getResult().getType(), clonedPackOp);
   return success();
 }

13 changes: 12 additions & 1 deletion mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2357,7 +2357,7 @@ func.func @unpack_pack_with_padding_no_canonicalization(%t: tensor<256x512xbf16>
   %tensor_empty = tensor.empty() : tensor<4x16x64x32xbf16>
   %tensor_empty1 = tensor.empty() : tensor<224x512xbf16>
   %packed = tensor.pack %t outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 32] into %tensor_empty : tensor<256x512xbf16> -> tensor<4x16x64x32xbf16>
-  %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [64, 32] into %tensor_empty1 : tensor<4x16x64x32xbf16> -> tensor<224x512xbf16>
+  %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [64, 32] into %tensor_empty1 : tensor<4x16x64x32xbf16> -> tensor<224x512xbf16>
   return %unpacked : tensor<224x512xbf16>
 }

@@ -2707,3 +2707,14 @@ func.func @test_destination_multiple_result(%arg0: tensor<2x2xf32>, %arg1: tenso
   %0:2 = test.destination_style_op ins(%cast : tensor<?x2xf32>) outs(%cast_0 : tensor<?x2xf32>) -> tensor<?x2xf32>, index
   return %0#1 : index
 }
+
+// -----
+
+// CHECK-LABEL: func.func @pack_dont_drop_attributes(
+// CHECK: tensor.pack {{.*}} {test_attr}
+func.func @pack_dont_drop_attributes(%arg0: tensor<?x?x?xf16>, %arg1: tensor<128x?x100x16x1xf16>) -> tensor<128x?x100x16x1xf16> {
+  %c32_i64 = arith.constant 32 : i64
+  %cst = arith.constant 0.000000e+00 : f16
+  %pack = tensor.pack %arg0 padding_value(%cst : f16) outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 1] into %arg1 {test_attr} : tensor<?x?x?xf16> -> tensor<128x?x100x16x1xf16>
+  return %pack : tensor<128x?x100x16x1xf16>
+}