
Commit 22afbba

Tai78641 authored and Jerry-Ge committed
[mlir][tosa] Update TileOp infer shape
Update to use getConstShapeValues in TileOp's shape inference.

Signed-off-by: Tai Ly <[email protected]>
Change-Id: Ie40d127eec16f7edfa3121ed9dcd5e3134138754
1 parent ad9f15a commit 22afbba
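
In practice, the change lets TileOp's shape inference read the multiples from a tosa.const_shape producer and treat a -1 entry as an unknown multiple instead of giving up. A minimal MLIR sketch of the refined behavior, adapted from the test added below (the function name @tile_infer_sketch is illustrative, and %arg0 is assumed to be a tensor<2x3x?xi32> argument):

func.func @tile_infer_sketch(%arg0 : tensor<2x3x?xi32>) -> () {
  // Multiples come from tosa.const_shape; the -1 entry marks an unknown multiple.
  %multiples = tosa.const_shape {values = dense<[2, -1, 5]> : tensor<3xindex>} : () -> !tosa.shape<3>
  // Shape inference refines the result to tensor<4x?x?xi32>:
  //   dim 0: 2 * 2 = 4; dim 1: unknown multiple -> ?; dim 2: unknown input extent -> ?.
  %0 = tosa.tile %arg0, %multiples : (tensor<2x3x?xi32>, !tosa.shape<3>) -> tensor<?x?x?xi32>
  return
}

If the multiples value is not produced by a constant shape op at all, the updated inference falls back to an all-dynamic result of the multiples' rank rather than failing.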

File tree: 2 files changed (+34, -13 lines)


mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 23 additions & 13 deletions
@@ -1612,33 +1612,43 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     TileOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  DenseIntElementsAttr multiplesAttr;
-  if (!matchPattern(adaptor.getMultiples(), m_Constant(&multiplesAttr)))
-    return failure();
-
-  SmallVector<int64_t> multiples = llvm::to_vector(
-      llvm::map_range(multiplesAttr.getValues<APInt>(),
-                      [](const APInt &val) { return val.getSExtValue(); }));
+  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
+  SmallVector<int64_t> multiples;
+  if (!tosa::getConstShapeValues(adaptor.getMultiples().getDefiningOp(),
+                                 multiples)) {
+    auto rank =
+        cast<tosa::shapeType>(adaptor.getMultiples().getType()).getRank();
+    SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
+    inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
+    return success();
+  } else {
+    multiples = convertToMlirShape(multiples);
+  }
 
   ShapeAdaptor inputShape(adaptor.getInput1().getType());
   SmallVector<int64_t> outputShape;
   if (!inputShape.hasRank()) {
     outputShape.resize(multiples.size(), ShapedType::kDynamic);
-    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+    inferredReturnShapes.push_back(
+        ShapedTypeComponents(outputShape, inputType));
     return success();
   } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
     return failure();
 
   // Any non dynamic dimension can be multiplied to a known size.
   outputShape.reserve(multiples.size());
   for (int i = 0, s = inputShape.getRank(); i < s; i++) {
-    int64_t dim = inputShape.getDimSize(i);
-    if (dim != ShapedType::kDynamic)
-      dim *= multiples[i];
-    outputShape.push_back(dim);
+    if (multiples[i] == ShapedType::kDynamic) {
+      outputShape.push_back(ShapedType::kDynamic);
+    } else {
+      int64_t dim = inputShape.getDimSize(i);
+      if (dim != ShapedType::kDynamic)
+        dim *= multiples[i];
+      outputShape.push_back(dim);
+    }
   }
 
-  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
   return success();
 }
 
mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

Lines changed: 11 additions & 0 deletions
@@ -599,6 +599,17 @@ func.func @test_tile(%arg0 : tensor<2x3x?xi32>) -> () {
 
 // -----
 
+// CHECK-LABEL: @test_tile_unknown_multiples
+func.func @test_tile_unknown_multiples(%arg0 : tensor<2x3x?xi32>) -> () {
+  // CHECK: %[[CST:.*]] = tosa.const_shape {values = dense<[2, -1, 5]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  // CHECK: tosa.tile %arg0, %[[CST]] : (tensor<2x3x?xi32>, !tosa.shape<3>) -> tensor<4x?x?xi32>
+  %cst = tosa.const_shape {values = dense<[2, -1, 5]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %0 = tosa.tile %arg0, %cst : (tensor<2x3x?xi32>, !tosa.shape<3>) -> tensor<?x?x?xi32>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @test_transpose_static
 func.func @test_transpose_static(%arg0 : tensor<3x4x5xi32>) -> () {
   // CHECK: tosa.transpose %arg0 {perms = array<i32: 2, 1, 0>} : (tensor<3x4x5xi32>) -> tensor<5x4x3xi32>
