@@ -1554,7 +1554,6 @@ def MemRef_ReshapeOp: MemRef_Op<"reshape", [
1554
1554
class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
1555
1555
MemRef_Op<mnemonic, !listconcat(traits,
1556
1556
[Pure, ViewLikeOpInterface])>,
1557
- Arguments<(ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation)>,
1558
1557
Results<(outs AnyStridedMemRef:$result)>{
1559
1558
1560
1559
code commonExtraClassDeclaration = [{
@@ -1579,10 +1578,6 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
1579
1578
Value getViewSource() { return getSrc(); }
1580
1579
}];
1581
1580
1582
- let assemblyFormat = [{
1583
- $src $reassociation attr-dict `:` type($src) `into` type($result)
1584
- }];
1585
-
1586
1581
let hasFolder = 1;
1587
1582
let hasCanonicalizer = 1;
1588
1583
let hasVerifier = 1;
@@ -1604,14 +1599,10 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
1604
1599
Example:
1605
1600
1606
1601
```mlir
1607
- %r = memref.expand_shape %0 [[0, 1], [2]]
1608
- : memref<?x?xf32 > into memref<?x5x?xf32 >
1602
+ %r = memref.expand_shape %0 [[0, 1], [2]] [%sz0, %sz1, 32]
1603
+ : memref<?x32xf32 > into memref<?x?x32xf32 >
1609
1604
```
1610
1605
1611
- At most one dimension of a reassociation group (e.g., [0, 1] above) may be
1612
- dynamic in the result type. Otherwise, the op would be ambiguous, as it
1613
- would not be clear how the source dimension is extended.
1614
-
1615
1606
If an op can be statically proven to be invalid (e.g, an expansion from
1616
1607
`memref<10xf32>` to `memref<2x6xf32>`), it is rejected by the verifier. If
1617
1608
it cannot statically be proven invalid (e.g., the full example above; it is
@@ -1628,29 +1619,72 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
1628
1619
there must be a dynamic result dimension in the corresponding reassociation
1629
1620
group. Same for strides.
1630
1621
1622
+ The representation for the output shape supports a partially-static
1623
+ specification via attributes specified through the `static_output_shape`
1624
+ argument. A special sentinel value `ShapedType::kDynamic` encodes that the
1625
+ corresponding entry has a dynamic value. There must be exactly as many SSA
1626
+ inputs in `output_shape` as there are `ShapedType::kDynamic` entries in
1627
+ `static_output_shape`.
1628
+
1631
1629
Note: This op currently assumes that the inner strides are of the
1632
1630
source/result layout map are the faster-varying ones.
1633
1631
}];
1634
1632
1633
+ let arguments = (ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation,
1634
+ Variadic<Index>:$output_shape,
1635
+ DenseI64ArrayAttr:$static_output_shape);
1636
+
1637
+ let assemblyFormat = [{
1638
+ $src $reassociation `output_shape`
1639
+ custom<DynamicIndexList>($output_shape, $static_output_shape) attr-dict `:`
1640
+ type($src) `into` type($result)
1641
+ }];
1642
+
1635
1643
let builders = [
1636
1644
// Builders using ReassociationIndices.
1645
+ OpBuilder<(ins "Type":$resultType, "Value":$src,
1646
+ "ArrayRef<ReassociationIndices>":$reassociation),
1647
+ [{
1648
+ SmallVector<OpFoldResult> inputShape =
1649
+ getMixedSizes($_builder, $_state.location, src);
1650
+ std::pair<SmallVector<int64_t>, SmallVector<Value>> outputShape;
1651
+ auto status =
1652
+ inferOutputShape($_builder, $_state.location,
1653
+ resultType.cast<MemRefType>(),
1654
+ reassociation, inputShape, outputShape);
1655
+ (void) status;
1656
+ assert(succeeded(status) && "unable to infer output shape");
1657
+ build($_builder, $_state, resultType.cast<MemRefType>(), src,
1658
+ getReassociationIndicesAttribute($_builder, reassociation),
1659
+ outputShape.second, outputShape.first);
1660
+ }]>,
1637
1661
OpBuilder<(ins "Type":$resultType, "Value":$src,
1638
1662
"ArrayRef<ReassociationIndices>":$reassociation,
1639
- CArg< "ArrayRef<NamedAttribute>", "{}">:$attrs ),
1663
+ "ArrayRef<OpFoldResult>":$outputShape ),
1640
1664
[{
1641
- build($_builder, $_state, resultType, src, attrs);
1642
- $_state.addAttribute("reassociation",
1643
- getReassociationIndicesAttribute($_builder, reassociation));
1665
+ auto [staticOutputShape, dynamicOutputShape] =
1666
+ decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
1667
+ build($_builder, $_state, resultType, src,
1668
+ getReassociationIndicesAttribute($_builder, reassociation),
1669
+ dynamicOutputShape, staticOutputShape);
1644
1670
}]>,
1645
1671
1646
1672
// Builder using ReassociationExprs.
1673
+ OpBuilder<(ins "Type":$resultType, "Value":$src,
1674
+ "ArrayRef<ReassociationExprs>":$reassociation),
1675
+ [{
1676
+ auto reassociationIndices =
1677
+ convertReassociationMapsToIndices(reassociation);
1678
+ build($_builder, $_state, resultType, src, reassociationIndices);
1679
+ }]>,
1647
1680
OpBuilder<(ins "Type":$resultType, "Value":$src,
1648
1681
"ArrayRef<ReassociationExprs>":$reassociation,
1649
- CArg< "ArrayRef<NamedAttribute>", "{}">:$attrs ),
1682
+ "ArrayRef<OpFoldResult>":$outputShape ),
1650
1683
[{
1651
1684
auto reassociationMaps =
1652
- convertReassociationMapsToIndices($_builder, reassociation);
1653
- build($_builder, $_state, resultType, src, reassociationMaps, attrs);
1685
+ convertReassociationMapsToIndices(reassociation);
1686
+ build($_builder, $_state, resultType, src, reassociationMaps,
1687
+ outputShape);
1654
1688
}]>,
1655
1689
1656
1690
// Builder that infers the result layout map. The result shape must be
@@ -1663,6 +1697,14 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
1663
1697
static FailureOr<MemRefType> computeExpandedType(
1664
1698
MemRefType srcType, ArrayRef<int64_t> resultShape,
1665
1699
ArrayRef<ReassociationIndices> reassociation);
1700
+
1701
+ // Infer the output shape for a memref.expand_shape when it is possible
1702
+ // to do so.
1703
+ static LogicalResult inferOutputShape(
1704
+ OpBuilder &b, Location loc, MemRefType expandedType,
1705
+ ArrayRef<ReassociationIndices> reassociation,
1706
+ ArrayRef<OpFoldResult> inputShape,
1707
+ std::pair<SmallVector<int64_t>, SmallVector<Value>> &outputShape);
1666
1708
}];
1667
1709
1668
1710
let hasVerifier = 1;
@@ -1713,6 +1755,12 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
1713
1755
source/result layout map are the faster-varying ones.
1714
1756
}];
1715
1757
1758
+ let arguments = (ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation);
1759
+
1760
+ let assemblyFormat = [{
1761
+ $src $reassociation attr-dict `:` type($src) `into` type($result)
1762
+ }];
1763
+
1716
1764
let builders = [
1717
1765
// Builders for a contracting reshape whose result type is computed from
1718
1766
// `src` and `reassociation`.
@@ -1724,7 +1772,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
1724
1772
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
1725
1773
[{
1726
1774
auto reassociationMaps =
1727
- convertReassociationMapsToIndices($_builder, reassociation);
1775
+ convertReassociationMapsToIndices(reassociation);
1728
1776
build($_builder, $_state, src, reassociationMaps, attrs);
1729
1777
}]>,
1730
1778
@@ -1742,7 +1790,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
1742
1790
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
1743
1791
[{
1744
1792
auto reassociationMaps =
1745
- convertReassociationMapsToIndices($_builder, reassociation);
1793
+ convertReassociationMapsToIndices(reassociation);
1746
1794
build($_builder, $_state, resultType, src, reassociationMaps, attrs);
1747
1795
}]>
1748
1796
];
0 commit comments