Skip to content

Commit b3a13e0

Browse files
committed
fix unpack
1 parent dff7ad9 commit b3a13e0

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4242,15 +4242,17 @@ static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
42424242
op.getDestType().getShape().end());
42434243
llvm::SmallSetVector<int64_t, 4> innerDims;
42444244
innerDims.insert(op.getInnerDimsPos().begin(), op.getInnerDimsPos().end());
4245-
auto outerDimsPerm = op.getOuterDimsPerm();
4245+
SmallVector<int64_t> inverseOuterDimsPerm;
4246+
if (!op.getOuterDimsPerm().empty())
4247+
inverseOuterDimsPerm = invertPermutationVector(op.getOuterDimsPerm());
42464248
int destRank = op.getDestRank();
42474249
for (auto i : llvm::seq<int64_t>(0, destRank)) {
42484250
if (innerDims.contains(i))
42494251
continue;
42504252
int64_t srcPos = i;
42514253
int64_t destPos = i;
4252-
if (!outerDimsPerm.empty())
4253-
srcPos = outerDimsPerm[destPos];
4254+
if (!inverseOuterDimsPerm.empty())
4255+
srcPos = inverseOuterDimsPerm[destPos];
42544256
if (ShapedType::isDynamic(srcShape[srcPos]) ==
42554257
ShapedType::isDynamic(destShape[destPos])) {
42564258
continue;

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -920,9 +920,9 @@ func.func @infer_dest_shape_unpack(%src: tensor<10x20x30x40x16xf32>, %dest: tens
920920
// CHECK-LABEL: func.func @infer_dest_shape_unpack
921921
// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
922922
// CHECK-SAME: %[[DEST:[0-9a-zA-Z]+]]
923-
// CHECK: %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?xf32> to tensor<30x20x?x10xf32>
923+
// CHECK: %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?xf32> to tensor<40x20x?x30xf32>
924924
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[SRC]] {{.+}} into %[[CAST_DEST]]
925-
// CHECK: %[[CAST_UNPACK:.+]] = tensor.cast %[[UNPACK]] : tensor<30x20x?x10xf32> to tensor<?x?x?x?xf32>
925+
// CHECK: %[[CAST_UNPACK:.+]] = tensor.cast %[[UNPACK]] : tensor<40x20x?x30xf32> to tensor<?x?x?x?xf32>
926926
// CHECK: return %[[CAST_UNPACK]]
927927

928928
// -----
@@ -938,7 +938,7 @@ func.func @infer_src_shape_unpack(%src: tensor<?x?x?x?x16xf32>, %dest: tensor<30
938938
// CHECK-LABEL: func.func @infer_src_shape_unpack
939939
// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
940940
// CHECK-SAME: %[[DEST:[0-9a-zA-Z]+]]
941-
// CHECK: %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?x16xf32> to tensor<10x20x30x?x16xf32>
941+
// CHECK: %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?x16xf32> to tensor<?x20x10x30x16xf32>
942942
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[CAST_SRC]]
943943
// CHECK: return %[[UNPACK]]
944944

0 commit comments

Comments
 (0)