[mlir][tensor] Add support for tensor.pack static shapes inference. #80848

Merged 2 commits on Feb 14, 2024

60 changes: 60 additions & 0 deletions mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3983,6 +3983,41 @@ static bool paddingIsNotNeeded(PackOp op) {
op.getMixedTiles());
}

/// Returns true if the `srcShape` or `destShape` is different from the one in
/// `packOp` and populates each with the inferred static shape.
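/// For example (mirroring the test added in this patch): with
/// `outer_dims_perm = [2, 1, 3, 0]`, `inner_dims_pos = [2]`, a source of type
/// `tensor<?x?x?x?xf32>`, and a destination of type
/// `tensor<10x20x30x40x16xf32>`, the inferred `srcShape` is `[30, 20, ?, 10]`;
/// dimension 2 stays dynamic because tiled (inner) dimensions are not
/// propagated.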
static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
SmallVectorImpl<int64_t> &destShape) {
bool changeNeeded = false;
srcShape.assign(packOp.getSourceType().getShape().begin(),
packOp.getSourceType().getShape().end());
destShape.assign(packOp.getDestType().getShape().begin(),
packOp.getDestType().getShape().end());
llvm::SmallSetVector<int64_t, 4> innerDims;
innerDims.insert(packOp.getInnerDimsPos().begin(),
packOp.getInnerDimsPos().end());
auto outerDimsPerm = packOp.getOuterDimsPerm();
int srcRank = packOp.getSourceRank();
for (auto i : llvm::seq<int64_t>(0, srcRank)) {
if (innerDims.contains(i))
continue;
int64_t srcPos = i;
int64_t destPos = i;
if (!outerDimsPerm.empty())
destPos = outerDimsPerm[srcPos];
if (ShapedType::isDynamic(srcShape[srcPos]) ==
ShapedType::isDynamic(destShape[destPos])) {
continue;
}
int64_t size = srcShape[srcPos];
if (ShapedType::isDynamic(size))
size = destShape[destPos];
srcShape[srcPos] = size;
destShape[destPos] = size;
changeNeeded = true;
}
return changeNeeded;
}

LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
// Fold an unpack(pack(x)) to x.
if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
@@ -4003,6 +4038,31 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
rewriter.finalizeOpModification(packOp);
return success();
}

  // Insert tensor.cast ops if static shape inference is available.
SmallVector<int64_t> srcShape, destShape;
if (inferStaticShape(packOp, srcShape, destShape)) {
Location loc = packOp.getLoc();
Value source = packOp.getSource();
if (srcShape != packOp.getSourceType().getShape()) {
auto newSrcType = packOp.getSourceType().clone(srcShape);
source =
rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
}
Value dest = packOp.getDest();
if (destShape != packOp.getDestType().getShape()) {
auto newDestType = packOp.getDestType().clone(destShape);
dest =
rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
}
Value newOp = rewriter.create<tensor::PackOp>(
loc, source, dest, packOp.getInnerDimsPos(), packOp.getMixedTiles(),
packOp.getPaddingValue(), packOp.getOuterDimsPerm());
rewriter.replaceOpWithNewOp<tensor::CastOp>(
packOp, packOp.getResult().getType(), newOp);

Contributor:

Is this well suited for a canonicalization? I'm wondering about cases where a pack and unpack could have folded away but this pattern introduces a tensor.cast in the middle. Maybe we need the same pattern for unpack too?


Contributor (Author):

This is common in Linalg and I think we can add such functionality to tensor ops. Good point on pack/unpack folding. I think they can be folded if the order of applying patterns is correct. To make the result IR deterministic, we might need it for unpack. So yes, I will prepare a patch and send it out for review. Do you think it is better to land both patterns together? If so, I will put the update to the PR. If it does not matter, I will land it as a follow-up.


Contributor:

If there is a possibility of non-deterministic IR probably best to land at the same time. In this case the pack-unpack canonicalization seems to always apply before this casting pattern so maybe it is ok here? Hard to predict exactly what the pattern applicator could do for any input IR though.
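
A minimal sketch of the scenario discussed above (hypothetical IR, not taken from this patch; %x, %empty1, and %empty2 stand for suitably-typed values):

// A pack fed by a matching unpack; the existing fold can rewrite %repacked
// back to %x.
%unpacked = tensor.unpack %x inner_dims_pos = [1] inner_tiles = [16]
    into %empty1 : tensor<8x32x16xf32> -> tensor<?x?xf32>
%repacked = tensor.pack %unpacked inner_dims_pos = [1] inner_tiles = [16]
    into %empty2 : tensor<?x?xf32> -> tensor<8x32x16xf32>
// If the new cast-insertion pattern were to fire first, it would infer
// srcShape = [8, ?] and rewrite the pack to consume
//   tensor.cast %unpacked : tensor<?x?xf32> to tensor<8x?xf32>
// so the fold no longer sees the unpack as the pack's direct source.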

return success();
}

return failure();
}

39 changes: 39 additions & 0 deletions mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -791,6 +791,45 @@ func.func @fold_padding_value_pack(%arg0: tensor<1200x500000xf32>) -> tensor<312

// -----

func.func @infer_src_shape_pack(%src: tensor<?x?x?x?xf32>, %dest: tensor<10x20x30x40x16xf32>) -> tensor<10x20x30x40x16xf32> {
%cst = arith.constant 0.000000e+00 : f32
%pack = tensor.pack %src
padding_value(%cst : f32)
outer_dims_perm = [2, 1, 3, 0]
inner_dims_pos = [2]
inner_tiles = [16]
into %dest : tensor<?x?x?x?xf32> -> tensor<10x20x30x40x16xf32>
return %pack : tensor<10x20x30x40x16xf32>
}
// CHECK-LABEL: func.func @infer_src_shape_pack
// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
// CHECK-SAME: %[[DEST:[0-9a-zA-Z]+]]
// CHECK: %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?xf32> to tensor<30x20x?x10xf32>
// CHECK: %[[PACK:.+]] = tensor.pack %[[CAST_SRC]] {{.+}} into %[[DEST]]
// CHECK: return %[[PACK]]

// -----

func.func @infer_dest_shape_pack(%src: tensor<30x20x?x10xf32>, %dest: tensor<?x?x?x?x16xf32>) -> tensor<?x?x?x?x16xf32> {
%cst = arith.constant 0.000000e+00 : f32
%pack = tensor.pack %src
padding_value(%cst : f32)
outer_dims_perm = [2, 1, 3, 0]
inner_dims_pos = [2]
inner_tiles = [16]
into %dest : tensor<30x20x?x10xf32> -> tensor<?x?x?x?x16xf32>
return %pack : tensor<?x?x?x?x16xf32>
}
// CHECK-LABEL: func.func @infer_dest_shape_pack
// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
// CHECK-SAME: %[[DEST:[0-9a-zA-Z]+]]
// CHECK: %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?x16xf32> to tensor<10x20x30x?x16xf32>
// CHECK: %[[PACK:.+]] = tensor.pack %[[SRC]] {{.+}} into %[[CAST_DEST]]
// CHECK: %[[CAST_PACK:.+]] = tensor.cast %[[PACK]] : tensor<10x20x30x?x16xf32> to tensor<?x?x?x?x16xf32>
// CHECK: return %[[CAST_PACK]]

// -----

func.func @fold_padding_value_pack_negative1(%arg0: tensor<1200x499999xf32>) -> tensor<31250x1200x16x1xf32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.empty() : tensor<31250x1200x16x1xf32>