Commit eac8604
[mlir][tensor] Add support for tensor.unpack static shapes inference. (#81702)
The revision does not refactor inferStaticShape into a single helper shared by the pack and unpack ops because the two implementations can diverge quickly: more dimensions (i.e., those involving inner_tile_sizes) can be inferred when the pack op does not have a padding value. This is a follow-up to #80848.
1 parent c3b87a8 commit eac8604
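As a quick illustration of the new canonicalization, here is a minimal sketch; the function name and shapes below are made up for this note and are not taken from the commit's tests. Whenever one operand of a tensor.unpack op has a static outer dimension that the other operand leaves dynamic, the pattern propagates the static size and inserts tensor.cast ops so users of the op keep their original types.

func.func @unpack_shape_inference_example(%src: tensor<?x20x16xf32>,
                                          %dest: tensor<10x?xf32>) -> tensor<10x?xf32> {
  // Outer dim 0 is dynamic on the source but static (10) on the destination,
  // so canonicalization can cast %src to tensor<10x20x16xf32> before unpacking.
  %0 = tensor.unpack %src inner_dims_pos = [1] inner_tiles = [16]
      into %dest : tensor<?x20x16xf32> -> tensor<10x?xf32>
  return %0 : tensor<10x?xf32>
}

Dimensions covered by inner_dims_pos are skipped by the inference, which is why the destination's second dimension stays dynamic here; the tests added below exercise the permuted (outer_dims_perm) cases.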

File tree

2 files changed: +110, -0 lines

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 59 additions & 0 deletions
@@ -4229,6 +4229,40 @@ UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
                              metadata.outerDimsPerm);
 }
 
+/// Returns true if the `srcShape` or `destShape` is different from the one in
+/// `op` and populates each with the inferred static shape.
+static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
+                             SmallVectorImpl<int64_t> &destShape) {
+  bool changeNeeded = false;
+  srcShape.assign(op.getSourceType().getShape().begin(),
+                  op.getSourceType().getShape().end());
+  destShape.assign(op.getDestType().getShape().begin(),
+                   op.getDestType().getShape().end());
+  llvm::SmallSetVector<int64_t, 4> innerDims;
+  innerDims.insert(op.getInnerDimsPos().begin(), op.getInnerDimsPos().end());
+  auto outerDimsPerm = op.getOuterDimsPerm();
+  int destRank = op.getDestRank();
+  for (auto i : llvm::seq<int64_t>(0, destRank)) {
+    if (innerDims.contains(i))
+      continue;
+    int64_t srcPos = i;
+    int64_t destPos = i;
+    if (!outerDimsPerm.empty())
+      srcPos = outerDimsPerm[destPos];
+    if (ShapedType::isDynamic(srcShape[srcPos]) ==
+        ShapedType::isDynamic(destShape[destPos])) {
+      continue;
+    }
+    int64_t size = srcShape[srcPos];
+    if (ShapedType::isDynamic(size))
+      size = destShape[destPos];
+    srcShape[srcPos] = size;
+    destShape[destPos] = size;
+    changeNeeded = true;
+  }
+  return changeNeeded;
+}
+
 LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                                      PatternRewriter &rewriter) {
   /// pack(unpack(x)) -> x
@@ -4251,6 +4285,31 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                        [&]() { unPackOp.setDpsInitOperand(0, newDest); });
     return success();
   }
+
+  // Insert tensor.cast ops if static shape inference is available.
+  SmallVector<int64_t> srcShape, destShape;
+  if (inferStaticShape(unPackOp, srcShape, destShape)) {
+    Location loc = unPackOp.getLoc();
+    Value source = unPackOp.getSource();
+    if (srcShape != unPackOp.getSourceType().getShape()) {
+      auto newSrcType = unPackOp.getSourceType().clone(srcShape);
+      source = rewriter.create<tensor::CastOp>(loc, newSrcType,
+                                               unPackOp.getSource());
+    }
+    Value dest = unPackOp.getDest();
+    if (destShape != unPackOp.getDestType().getShape()) {
+      auto newDestType = unPackOp.getDestType().clone(destShape);
+      dest =
+          rewriter.create<tensor::CastOp>(loc, newDestType, unPackOp.getDest());
+    }
+    Value newOp = rewriter.create<tensor::UnPackOp>(
+        loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
+        unPackOp.getOuterDimsPerm());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(
+        unPackOp, unPackOp.getResult().getType(), newOp);
+    return success();
+  }
+
   return failure();
 }
 

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 51 additions & 0 deletions
@@ -909,6 +909,41 @@ func.func @fold_unpack_constant_splat(%dest : tensor<128x256xf32>) -> tensor<128
 
 // -----
 
+func.func @infer_dest_shape_unpack(%src: tensor<10x20x30x40x16xf32>, %dest: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %unpack = tensor.unpack %src
+    outer_dims_perm = [2, 1, 3, 0]
+    inner_dims_pos = [2]
+    inner_tiles = [16]
+    into %dest : tensor<10x20x30x40x16xf32> -> tensor<?x?x?x?xf32>
+  return %unpack : tensor<?x?x?x?xf32>
+}
+// CHECK-LABEL: func.func @infer_dest_shape_unpack
+// CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
+// CHECK-SAME:    %[[DEST:[0-9a-zA-Z]+]]
+// CHECK:         %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?xf32> to tensor<30x20x?x10xf32>
+// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[SRC]] {{.+}} into %[[CAST_DEST]]
+// CHECK:         %[[CAST_UNPACK:.+]] = tensor.cast %[[UNPACK]] : tensor<30x20x?x10xf32> to tensor<?x?x?x?xf32>
+// CHECK:         return %[[CAST_UNPACK]]
+
+// -----
+
+func.func @infer_src_shape_unpack(%src: tensor<?x?x?x?x16xf32>, %dest: tensor<30x20x?x10xf32>) -> tensor<30x20x?x10xf32> {
+  %unpack = tensor.unpack %src
+    outer_dims_perm = [2, 1, 3, 0]
+    inner_dims_pos = [2]
+    inner_tiles = [16]
+    into %dest : tensor<?x?x?x?x16xf32> -> tensor<30x20x?x10xf32>
+  return %unpack : tensor<30x20x?x10xf32>
+}
+// CHECK-LABEL: func.func @infer_src_shape_unpack
+// CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
+// CHECK-SAME:    %[[DEST:[0-9a-zA-Z]+]]
+// CHECK:         %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?x16xf32> to tensor<10x20x30x?x16xf32>
+// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[CAST_SRC]]
+// CHECK:         return %[[UNPACK]]
+
+// -----
+
 // CHECK-LABEL: func @fold_overlapping_insert
 // CHECK-SAME: %[[INPUT:.+]]: tensor<?x?x?xf32>, %{{.+}}: tensor<4x?x8xf32>, %[[SLICE2:.+]]: tensor<4x?x8xf32>
 func.func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4x?x8xf32>, %slice2: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<?x?x?xf32>) {
@@ -2176,3 +2211,19 @@ func.func @generate_negative_size_verifies() -> tensor<?x8xi32> {
   } : tensor<?x8xi32>
   return %tensor : tensor<?x8xi32>
 }
+
+// -----
+
+func.func @infer_and_fold_pack_unpack_same_tiles(%t: tensor<10x20x4x4xf32>) -> tensor<10x20x4x4xf32> {
+  %dim1 = arith.constant 40 : index
+  %dim2 = arith.constant 80 : index
+  %tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
+  %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %tensor_empty : tensor<10x20x4x4xf32> -> tensor<?x?xf32>
+  %cast = tensor.cast %unpacked : tensor<?x?xf32> to tensor<40x80xf32>
+  %tensor_empty1 = tensor.empty() : tensor<10x20x4x4xf32>
+  %packed = tensor.pack %cast inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %tensor_empty1 : tensor<40x80xf32> -> tensor<10x20x4x4xf32>
+  return %packed : tensor<10x20x4x4xf32>
+}
+// CHECK-LABEL: func.func @infer_and_fold_pack_unpack_same_tiles
+// CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
+// CHECK:         return %[[SRC]]
