Skip to content

Commit 751c3f5

Browse files
Jerry-GeTai78641
andauthored
[mlir][tosa] Update TileOp infer shape (#134732)
update to use getConstShapeValues in TileOp's shape inference Signed-off-by: Tai Ly <[email protected]> Co-authored-by: Tai Ly <[email protected]>
1 parent 2a7f12e commit 751c3f5

File tree

2 files changed

+34
-13
lines changed

2 files changed

+34
-13
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1616,33 +1616,43 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
16161616
MLIRContext *context, ::std::optional<Location> location,
16171617
TileOp::Adaptor adaptor,
16181618
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1619-
DenseIntElementsAttr multiplesAttr;
1620-
if (!matchPattern(adaptor.getMultiples(), m_Constant(&multiplesAttr)))
1621-
return failure();
1622-
1623-
SmallVector<int64_t> multiples = llvm::to_vector(
1624-
llvm::map_range(multiplesAttr.getValues<APInt>(),
1625-
[](const APInt &val) { return val.getSExtValue(); }));
1619+
Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
1620+
SmallVector<int64_t> multiples;
1621+
if (!tosa::getConstShapeValues(adaptor.getMultiples().getDefiningOp(),
1622+
multiples)) {
1623+
auto rank =
1624+
cast<tosa::shapeType>(adaptor.getMultiples().getType()).getRank();
1625+
SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
1626+
inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
1627+
return success();
1628+
} else {
1629+
multiples = convertToMlirShape(multiples);
1630+
}
16261631

16271632
ShapeAdaptor inputShape(adaptor.getInput1().getType());
16281633
SmallVector<int64_t> outputShape;
16291634
if (!inputShape.hasRank()) {
16301635
outputShape.resize(multiples.size(), ShapedType::kDynamic);
1631-
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1636+
inferredReturnShapes.push_back(
1637+
ShapedTypeComponents(outputShape, inputType));
16321638
return success();
16331639
} else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
16341640
return failure();
16351641

16361642
// Any non dynamic dimension can be multiplied to a known size.
16371643
outputShape.reserve(multiples.size());
16381644
for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1639-
int64_t dim = inputShape.getDimSize(i);
1640-
if (dim != ShapedType::kDynamic)
1641-
dim *= multiples[i];
1642-
outputShape.push_back(dim);
1645+
if (multiples[i] == ShapedType::kDynamic) {
1646+
outputShape.push_back(ShapedType::kDynamic);
1647+
} else {
1648+
int64_t dim = inputShape.getDimSize(i);
1649+
if (dim != ShapedType::kDynamic)
1650+
dim *= multiples[i];
1651+
outputShape.push_back(dim);
1652+
}
16431653
}
16441654

1645-
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1655+
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
16461656
return success();
16471657
}
16481658

mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -599,6 +599,17 @@ func.func @test_tile(%arg0 : tensor<2x3x?xi32>) -> () {
599599

600600
// -----
601601

602+
// CHECK-LABEL: @test_tile_unknown_multiples
603+
func.func @test_tile_unknown_multiples(%arg0 : tensor<2x3x?xi32>) -> () {
604+
// CHECK: %[[CST:.*]] = tosa.const_shape {values = dense<[2, -1, 5]> : tensor<3xindex>} : () -> !tosa.shape<3>
605+
// CHECK: tosa.tile %arg0, %[[CST]] : (tensor<2x3x?xi32>, !tosa.shape<3>) -> tensor<4x?x?xi32>
606+
%cst = tosa.const_shape {values = dense<[2, -1, 5]> : tensor<3xindex>} : () -> !tosa.shape<3>
607+
%0 = tosa.tile %arg0, %cst : (tensor<2x3x?xi32>, !tosa.shape<3>) -> tensor<?x?x?xi32>
608+
return
609+
}
610+
611+
// -----
612+
602613
// CHECK-LABEL: @test_transpose_static
603614
func.func @test_transpose_static(%arg0 : tensor<3x4x5xi32>) -> () {
604615
// CHECK: tosa.transpose %arg0 {perms = array<i32: 2, 1, 0>} : (tensor<3x4x5xi32>) -> tensor<5x4x3xi32>

0 commit comments

Comments
 (0)