@@ -1616,33 +1616,43 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
1616
1616
MLIRContext *context, ::std::optional<Location> location,
1617
1617
TileOp::Adaptor adaptor,
1618
1618
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1619
- DenseIntElementsAttr multiplesAttr;
1620
- if (!matchPattern (adaptor.getMultiples (), m_Constant (&multiplesAttr)))
1621
- return failure ();
1622
-
1623
- SmallVector<int64_t > multiples = llvm::to_vector (
1624
- llvm::map_range (multiplesAttr.getValues <APInt>(),
1625
- [](const APInt &val) { return val.getSExtValue (); }));
1619
+ Type inputType = getElementTypeOrSelf (adaptor.getInput1 ().getType ());
1620
+ SmallVector<int64_t > multiples;
1621
+ if (!tosa::getConstShapeValues (adaptor.getMultiples ().getDefiningOp (),
1622
+ multiples)) {
1623
+ auto rank =
1624
+ cast<tosa::shapeType>(adaptor.getMultiples ().getType ()).getRank ();
1625
+ SmallVector<int64_t > fallback (rank, ShapedType::kDynamic );
1626
+ inferredReturnShapes.push_back (ShapedTypeComponents (fallback, inputType));
1627
+ return success ();
1628
+ } else {
1629
+ multiples = convertToMlirShape (multiples);
1630
+ }
1626
1631
1627
1632
ShapeAdaptor inputShape (adaptor.getInput1 ().getType ());
1628
1633
SmallVector<int64_t > outputShape;
1629
1634
if (!inputShape.hasRank ()) {
1630
1635
outputShape.resize (multiples.size (), ShapedType::kDynamic );
1631
- inferredReturnShapes.push_back (ShapedTypeComponents (outputShape));
1636
+ inferredReturnShapes.push_back (
1637
+ ShapedTypeComponents (outputShape, inputType));
1632
1638
return success ();
1633
1639
} else if (static_cast <size_t >(inputShape.getRank ()) != multiples.size ())
1634
1640
return failure ();
1635
1641
1636
1642
// Any non dynamic dimension can be multiplied to a known size.
1637
1643
outputShape.reserve (multiples.size ());
1638
1644
for (int i = 0 , s = inputShape.getRank (); i < s; i++) {
1639
- int64_t dim = inputShape.getDimSize (i);
1640
- if (dim != ShapedType::kDynamic )
1641
- dim *= multiples[i];
1642
- outputShape.push_back (dim);
1645
+ if (multiples[i] == ShapedType::kDynamic ) {
1646
+ outputShape.push_back (ShapedType::kDynamic );
1647
+ } else {
1648
+ int64_t dim = inputShape.getDimSize (i);
1649
+ if (dim != ShapedType::kDynamic )
1650
+ dim *= multiples[i];
1651
+ outputShape.push_back (dim);
1652
+ }
1643
1653
}
1644
1654
1645
- inferredReturnShapes.push_back (ShapedTypeComponents (outputShape));
1655
+ inferredReturnShapes.push_back (ShapedTypeComponents (outputShape, inputType ));
1646
1656
return success ();
1647
1657
}
1648
1658
0 commit comments