Commit bc08cc2

[mlir][tensor] Add support for tensor.pack static shapes inference. (#80848)
Fixes iree-org/iree#16317
1 parent 3f738a4 commit bc08cc2
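
In effect, the new canonicalization pattern propagates static sizes between the source and destination types of a tensor.pack: for every outer dimension where one side is static and the other dynamic, the static size wins, and tensor.cast ops are inserted so the surrounding types still line up. A rough before/after sketch, adapted from the first new test case below (value names are illustrative):

    %pack = tensor.pack %src padding_value(%cst : f32)
        outer_dims_perm = [2, 1, 3, 0] inner_dims_pos = [2] inner_tiles = [16]
        into %dest : tensor<?x?x?x?xf32> -> tensor<10x20x30x40x16xf32>

becomes

    %cast = tensor.cast %src : tensor<?x?x?x?xf32> to tensor<30x20x?x10xf32>
    %pack = tensor.pack %cast padding_value(%cst : f32)
        outer_dims_perm = [2, 1, 3, 0] inner_dims_pos = [2] inner_tiles = [16]
        into %dest : tensor<30x20x?x10xf32> -> tensor<10x20x30x40x16xf32>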

File tree

2 files changed: +99 -0 lines changed

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 60 additions & 0 deletions
@@ -3983,6 +3983,41 @@ static bool paddingIsNotNeeded(PackOp op) {
                                       op.getMixedTiles());
 }
 
+/// Returns true if the `srcShape` or `destShape` is different from the one in
+/// `packOp` and populates each with the inferred static shape.
+static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
+                             SmallVectorImpl<int64_t> &destShape) {
+  bool changeNeeded = false;
+  srcShape.assign(packOp.getSourceType().getShape().begin(),
+                  packOp.getSourceType().getShape().end());
+  destShape.assign(packOp.getDestType().getShape().begin(),
+                   packOp.getDestType().getShape().end());
+  llvm::SmallSetVector<int64_t, 4> innerDims;
+  innerDims.insert(packOp.getInnerDimsPos().begin(),
+                   packOp.getInnerDimsPos().end());
+  auto outerDimsPerm = packOp.getOuterDimsPerm();
+  int srcRank = packOp.getSourceRank();
+  for (auto i : llvm::seq<int64_t>(0, srcRank)) {
+    if (innerDims.contains(i))
+      continue;
+    int64_t srcPos = i;
+    int64_t destPos = i;
+    if (!outerDimsPerm.empty())
+      destPos = outerDimsPerm[srcPos];
+    if (ShapedType::isDynamic(srcShape[srcPos]) ==
+        ShapedType::isDynamic(destShape[destPos])) {
+      continue;
+    }
+    int64_t size = srcShape[srcPos];
+    if (ShapedType::isDynamic(size))
+      size = destShape[destPos];
+    srcShape[srcPos] = size;
+    destShape[destPos] = size;
+    changeNeeded = true;
+  }
+  return changeNeeded;
+}
+
 LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
   // Fold an unpack(pack(x)) to x.
   if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
@@ -4003,6 +4038,31 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
     rewriter.finalizeOpModification(packOp);
     return success();
   }
+
+  // Insert tensor.cast ops if static shape inference is available.
+  SmallVector<int64_t> srcShape, destShape;
+  if (inferStaticShape(packOp, srcShape, destShape)) {
+    Location loc = packOp.getLoc();
+    Value source = packOp.getSource();
+    if (srcShape != packOp.getSourceType().getShape()) {
+      auto newSrcType = packOp.getSourceType().clone(srcShape);
+      source =
+          rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
+    }
+    Value dest = packOp.getDest();
+    if (destShape != packOp.getDestType().getShape()) {
+      auto newDestType = packOp.getDestType().clone(destShape);
+      dest =
+          rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
+    }
+    Value newOp = rewriter.create<tensor::PackOp>(
+        loc, source, dest, packOp.getInnerDimsPos(), packOp.getMixedTiles(),
+        packOp.getPaddingValue(), packOp.getOuterDimsPerm());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(
+        packOp, packOp.getResult().getType(), newOp);
+    return success();
+  }
+
   return failure();
 }
 
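How inferStaticShape pairs dimensions: every source dimension not listed in inner_dims_pos is an outer dimension (tiled dimensions are skipped, since packing rewrites their size), and outer source dimension i corresponds to destination dimension outerDimsPerm[i], or to i itself when no permutation is present. Whenever exactly one side of such a pair is static, that size is copied to the other side. Tracing the first new test below (dest type tensor<10x20x30x40x16xf32>, outer_dims_perm = [2, 1, 3, 0], inner_dims_pos = [2]):

    src dim 0 -> dest dim 2: 30
    src dim 1 -> dest dim 1: 20
    src dim 2 -> in inner_dims_pos, skipped: stays ?
    src dim 3 -> dest dim 0: 10
    => inferred source type tensor<30x20x?x10xf32>

Note that the final replaceOpWithNewOp<tensor::CastOp> casts the refined pack result back to the op's original result type, so existing uses continue to type-check; when the types already match, that cast folds away.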

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 39 additions & 0 deletions
@@ -809,6 +809,45 @@ func.func @fold_padding_value_pack(%arg0: tensor<1200x500000xf32>) -> tensor<312
 
 // -----
 
+func.func @infer_src_shape_pack(%src: tensor<?x?x?x?xf32>, %dest: tensor<10x20x30x40x16xf32>) -> tensor<10x20x30x40x16xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %pack = tensor.pack %src
+    padding_value(%cst : f32)
+    outer_dims_perm = [2, 1, 3, 0]
+    inner_dims_pos = [2]
+    inner_tiles = [16]
+    into %dest : tensor<?x?x?x?xf32> -> tensor<10x20x30x40x16xf32>
+  return %pack : tensor<10x20x30x40x16xf32>
+}
+// CHECK-LABEL: func.func @infer_src_shape_pack
+// CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
+// CHECK-SAME:    %[[DEST:[0-9a-zA-Z]+]]
+// CHECK:         %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?xf32> to tensor<30x20x?x10xf32>
+// CHECK:         %[[PACK:.+]] = tensor.pack %[[CAST_SRC]] {{.+}} into %[[DEST]]
+// CHECK:         return %[[PACK]]
+
+// -----
+
+func.func @infer_dest_shape_pack(%src: tensor<30x20x?x10xf32>, %dest: tensor<?x?x?x?x16xf32>) -> tensor<?x?x?x?x16xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %pack = tensor.pack %src
+    padding_value(%cst : f32)
+    outer_dims_perm = [2, 1, 3, 0]
+    inner_dims_pos = [2]
+    inner_tiles = [16]
+    into %dest : tensor<30x20x?x10xf32> -> tensor<?x?x?x?x16xf32>
+  return %pack : tensor<?x?x?x?x16xf32>
+}
+// CHECK-LABEL: func.func @infer_dest_shape_pack
+// CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
+// CHECK-SAME:    %[[DEST:[0-9a-zA-Z]+]]
+// CHECK:         %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?x16xf32> to tensor<10x20x30x?x16xf32>
+// CHECK:         %[[PACK:.+]] = tensor.pack %[[SRC]] {{.+}} into %[[CAST_DEST]]
+// CHECK:         %[[CAST_PACK:.+]] = tensor.cast %[[PACK]] : tensor<10x20x30x?x16xf32> to tensor<?x?x?x?x16xf32>
+// CHECK:         return %[[CAST_PACK]]
+
+// -----
+
 func.func @fold_padding_value_pack_negative1(%arg0: tensor<1200x499999xf32>) -> tensor<31250x1200x16x1xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %0 = tensor.empty() : tensor<31250x1200x16x1xf32>