@@ -1612,33 +1612,43 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
1612
1612
MLIRContext *context, ::std::optional<Location> location,
1613
1613
TileOp::Adaptor adaptor,
1614
1614
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1615
- DenseIntElementsAttr multiplesAttr;
1616
- if (!matchPattern (adaptor.getMultiples (), m_Constant (&multiplesAttr)))
1617
- return failure ();
1618
-
1619
- SmallVector<int64_t > multiples = llvm::to_vector (
1620
- llvm::map_range (multiplesAttr.getValues <APInt>(),
1621
- [](const APInt &val) { return val.getSExtValue (); }));
1615
+ Type inputType = getElementTypeOrSelf (adaptor.getInput1 ().getType ());
1616
+ SmallVector<int64_t > multiples;
1617
+ if (!tosa::getConstShapeValues (adaptor.getMultiples ().getDefiningOp (),
1618
+ multiples)) {
1619
+ auto rank =
1620
+ cast<tosa::shapeType>(adaptor.getMultiples ().getType ()).getRank ();
1621
+ SmallVector<int64_t > fallback (rank, ShapedType::kDynamic );
1622
+ inferredReturnShapes.push_back (ShapedTypeComponents (fallback, inputType));
1623
+ return success ();
1624
+ } else {
1625
+ multiples = convertToMlirShape (multiples);
1626
+ }
1622
1627
1623
1628
ShapeAdaptor inputShape (adaptor.getInput1 ().getType ());
1624
1629
SmallVector<int64_t > outputShape;
1625
1630
if (!inputShape.hasRank ()) {
1626
1631
outputShape.resize (multiples.size (), ShapedType::kDynamic );
1627
- inferredReturnShapes.push_back (ShapedTypeComponents (outputShape));
1632
+ inferredReturnShapes.push_back (
1633
+ ShapedTypeComponents (outputShape, inputType));
1628
1634
return success ();
1629
1635
} else if (static_cast <size_t >(inputShape.getRank ()) != multiples.size ())
1630
1636
return failure ();
1631
1637
1632
1638
// Any non dynamic dimension can be multiplied to a known size.
1633
1639
outputShape.reserve (multiples.size ());
1634
1640
for (int i = 0 , s = inputShape.getRank (); i < s; i++) {
1635
- int64_t dim = inputShape.getDimSize (i);
1636
- if (dim != ShapedType::kDynamic )
1637
- dim *= multiples[i];
1638
- outputShape.push_back (dim);
1641
+ if (multiples[i] == ShapedType::kDynamic ) {
1642
+ outputShape.push_back (ShapedType::kDynamic );
1643
+ } else {
1644
+ int64_t dim = inputShape.getDimSize (i);
1645
+ if (dim != ShapedType::kDynamic )
1646
+ dim *= multiples[i];
1647
+ outputShape.push_back (dim);
1648
+ }
1639
1649
}
1640
1650
1641
- inferredReturnShapes.push_back (ShapedTypeComponents (outputShape));
1651
+ inferredReturnShapes.push_back (ShapedTypeComponents (outputShape, inputType ));
1642
1652
return success ();
1643
1653
}
1644
1654
0 commit comments