@@ -1548,7 +1548,6 @@ def MemRef_ReshapeOp: MemRef_Op<"reshape", [
1548
1548
class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
1549
1549
MemRef_Op<mnemonic, !listconcat(traits,
1550
1550
[Pure, ViewLikeOpInterface])>,
1551
- Arguments<(ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation)>,
1552
1551
Results<(outs AnyStridedMemRef:$result)>{
1553
1552
1554
1553
code commonExtraClassDeclaration = [{
@@ -1573,10 +1572,6 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
1573
1572
Value getViewSource() { return getSrc(); }
1574
1573
}];
1575
1574
1576
- let assemblyFormat = [{
1577
- $src $reassociation attr-dict `:` type($src) `into` type($result)
1578
- }];
1579
-
1580
1575
let hasFolder = 1;
1581
1576
let hasCanonicalizer = 1;
1582
1577
let hasVerifier = 1;
@@ -1598,14 +1593,10 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
1598
1593
Example:
1599
1594
1600
1595
```mlir
1601
- %r = memref.expand_shape %0 [[0, 1], [2]]
1602
- : memref<?x?xf32 > into memref<?x5x?xf32 >
1596
+ %r = memref.expand_shape %0 [[0, 1], [2]] [%sz0, %sz1, 32]
1597
+ : memref<?x32xf32 > into memref<?x?x32xf32 >
1603
1598
```
1604
1599
1605
- At most one dimension of a reassociation group (e.g., [0, 1] above) may be
1606
- dynamic in the result type. Otherwise, the op would be ambiguous, as it
1607
- would not be clear how the source dimension is extended.
1608
-
1609
1600
If an op can be statically proven to be invalid (e.g, an expansion from
1610
1601
`memref<10xf32>` to `memref<2x6xf32>`), it is rejected by the verifier. If
1611
1602
it cannot statically be proven invalid (e.g., the full example above; it is
@@ -1622,29 +1613,72 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
1622
1613
there must be a dynamic result dimension in the corresponding reassociation
1623
1614
group. Same for strides.
1624
1615
1616
+ The representation for the output shape supports a partially-static
1617
+ specification via attributes specified through the `static_output_shape`
1618
+ argument. A special sentinel value `ShapedType::kDynamic` encodes that the
1619
+ corresponding entry has a dynamic value. There must be exactly as many SSA
1620
+ inputs in `output_shape` as there are `ShapedType::kDynamic` entries in
1621
+ `static_output_shape`.
1622
+
1625
1623
Note: This op currently assumes that the inner strides are of the
1626
1624
source/result layout map are the faster-varying ones.
1627
1625
}];
1628
1626
1627
+ let arguments = (ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation,
1628
+ Variadic<Index>:$output_shape,
1629
+ DenseI64ArrayAttr:$static_output_shape);
1630
+
1631
+ let assemblyFormat = [{
1632
+ $src $reassociation `output_shape`
1633
+ custom<DynamicIndexList>($output_shape, $static_output_shape) attr-dict `:`
1634
+ type($src) `into` type($result)
1635
+ }];
1636
+
1629
1637
let builders = [
1630
1638
// Builders using ReassociationIndices.
1639
+ OpBuilder<(ins "Type":$resultType, "Value":$src,
1640
+ "ArrayRef<ReassociationIndices>":$reassociation),
1641
+ [{
1642
+ SmallVector<OpFoldResult> inputShape =
1643
+ getMixedSizes($_builder, $_state.location, src);
1644
+ std::pair<SmallVector<int64_t>, SmallVector<Value>> outputShape;
1645
+ auto status =
1646
+ inferOutputShape($_builder, $_state.location,
1647
+ resultType.cast<MemRefType>(),
1648
+ reassociation, inputShape, outputShape);
1649
+ (void) status;
1650
+ assert(succeeded(status) && "unable to infer output shape");
1651
+ build($_builder, $_state, resultType.cast<MemRefType>(), src,
1652
+ getReassociationIndicesAttribute($_builder, reassociation),
1653
+ outputShape.second, outputShape.first);
1654
+ }]>,
1631
1655
OpBuilder<(ins "Type":$resultType, "Value":$src,
1632
1656
"ArrayRef<ReassociationIndices>":$reassociation,
1633
- CArg< "ArrayRef<NamedAttribute>", "{}">:$attrs ),
1657
+ "ArrayRef<OpFoldResult>":$outputShape ),
1634
1658
[{
1635
- build($_builder, $_state, resultType, src, attrs);
1636
- $_state.addAttribute("reassociation",
1637
- getReassociationIndicesAttribute($_builder, reassociation));
1659
+ auto [staticOutputShape, dynamicOutputShape] =
1660
+ decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
1661
+ build($_builder, $_state, resultType, src,
1662
+ getReassociationIndicesAttribute($_builder, reassociation),
1663
+ dynamicOutputShape, staticOutputShape);
1638
1664
}]>,
1639
1665
1640
1666
// Builder using ReassociationExprs.
1667
+ OpBuilder<(ins "Type":$resultType, "Value":$src,
1668
+ "ArrayRef<ReassociationExprs>":$reassociation),
1669
+ [{
1670
+ auto reassociationIndices =
1671
+ convertReassociationMapsToIndices(reassociation);
1672
+ build($_builder, $_state, resultType, src, reassociationIndices);
1673
+ }]>,
1641
1674
OpBuilder<(ins "Type":$resultType, "Value":$src,
1642
1675
"ArrayRef<ReassociationExprs>":$reassociation,
1643
- CArg< "ArrayRef<NamedAttribute>", "{}">:$attrs ),
1676
+ "ArrayRef<OpFoldResult>":$outputShape ),
1644
1677
[{
1645
1678
auto reassociationMaps =
1646
- convertReassociationMapsToIndices($_builder, reassociation);
1647
- build($_builder, $_state, resultType, src, reassociationMaps, attrs);
1679
+ convertReassociationMapsToIndices(reassociation);
1680
+ build($_builder, $_state, resultType, src, reassociationMaps,
1681
+ outputShape);
1648
1682
}]>,
1649
1683
1650
1684
// Builder that infers the result layout map. The result shape must be
@@ -1657,6 +1691,14 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
1657
1691
static FailureOr<MemRefType> computeExpandedType(
1658
1692
MemRefType srcType, ArrayRef<int64_t> resultShape,
1659
1693
ArrayRef<ReassociationIndices> reassociation);
1694
+
1695
+ // Infer the output shape for a memref.expand_shape when it is possible
1696
+ // to do so.
1697
+ static LogicalResult inferOutputShape(
1698
+ OpBuilder &b, Location loc, MemRefType expandedType,
1699
+ ArrayRef<ReassociationIndices> reassociation,
1700
+ ArrayRef<OpFoldResult> inputShape,
1701
+ std::pair<SmallVector<int64_t>, SmallVector<Value>> &outputShape);
1660
1702
}];
1661
1703
1662
1704
let hasVerifier = 1;
@@ -1707,6 +1749,12 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
1707
1749
source/result layout map are the faster-varying ones.
1708
1750
}];
1709
1751
1752
+ let arguments = (ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation);
1753
+
1754
+ let assemblyFormat = [{
1755
+ $src $reassociation attr-dict `:` type($src) `into` type($result)
1756
+ }];
1757
+
1710
1758
let builders = [
1711
1759
// Builders for a contracting reshape whose result type is computed from
1712
1760
// `src` and `reassociation`.
@@ -1718,7 +1766,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
1718
1766
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
1719
1767
[{
1720
1768
auto reassociationMaps =
1721
- convertReassociationMapsToIndices($_builder, reassociation);
1769
+ convertReassociationMapsToIndices(reassociation);
1722
1770
build($_builder, $_state, src, reassociationMaps, attrs);
1723
1771
}]>,
1724
1772
@@ -1736,7 +1784,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
1736
1784
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
1737
1785
[{
1738
1786
auto reassociationMaps =
1739
- convertReassociationMapsToIndices($_builder, reassociation);
1787
+ convertReassociationMapsToIndices(reassociation);
1740
1788
build($_builder, $_state, resultType, src, reassociationMaps, attrs);
1741
1789
}]>
1742
1790
];
0 commit comments