@@ -1548,7 +1548,6 @@ def MemRef_ReshapeOp: MemRef_Op<"reshape", [
1548
1548
class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
1549
1549
MemRef_Op<mnemonic, !listconcat(traits,
1550
1550
[Pure, ViewLikeOpInterface])>,
1551
- Arguments<(ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation)>,
1552
1551
Results<(outs AnyStridedMemRef:$result)>{
1553
1552
1554
1553
code commonExtraClassDeclaration = [{
@@ -1573,10 +1572,6 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
1573
1572
Value getViewSource() { return getSrc(); }
1574
1573
}];
1575
1574
1576
- let assemblyFormat = [{
1577
- $src $reassociation attr-dict `:` type($src) `into` type($result)
1578
- }];
1579
-
1580
1575
let hasFolder = 1;
1581
1576
let hasCanonicalizer = 1;
1582
1577
let hasVerifier = 1;
@@ -1598,14 +1593,10 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
1598
1593
Example:
1599
1594
1600
1595
```mlir
1601
- %r = memref.expand_shape %0 [[0, 1], [2]]
1602
- : memref<?x?xf32 > into memref<?x5x?xf32 >
1596
+ %r = memref.expand_shape %0 [[0, 1], [2]] output_shape [%sz0, %sz1, 32]
1597
+ : memref<?x32xf32 > into memref<?x?x32xf32 >
1603
1598
```
1604
1599
1605
- At most one dimension of a reassociation group (e.g., [0, 1] above) may be
1606
- dynamic in the result type. Otherwise, the op would be ambiguous, as it
1607
- would not be clear how the source dimension is extended.
1608
-
1609
1600
If an op can be statically proven to be invalid (e.g, an expansion from
1610
1601
`memref<10xf32>` to `memref<2x6xf32>`), it is rejected by the verifier. If
1611
1602
it cannot statically be proven invalid (e.g., the full example above; it is
@@ -1622,41 +1613,80 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
1622
1613
there must be a dynamic result dimension in the corresponding reassociation
1623
1614
group. Same for strides.
1624
1615
1616
+ The representation for the output shape supports a partially-static
1617
+ specification via attributes specified through the `static_output_shape`
1618
+ argument. A special sentinel value `ShapedType::kDynamic` encodes that the
1619
+ corresponding entry has a dynamic value. There must be exactly as many SSA
1620
+ inputs in `output_shape` as there are `ShapedType::kDynamic` entries in
1621
+ `static_output_shape`.
1622
+
1625
1623
Note: This op currently assumes that the inner strides are of the
1626
1624
source/result layout map are the faster-varying ones.
1627
1625
}];
1628
1626
1627
+ let arguments = (ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation,
1628
+ Variadic<Index>:$output_shape,
1629
+ DenseI64ArrayAttr:$static_output_shape);
1630
+
1631
+ let assemblyFormat = [{
1632
+ $src $reassociation `output_shape`
1633
+ custom<DynamicIndexList>($output_shape, $static_output_shape) attr-dict `:`
1634
+ type($src) `into` type($result)
1635
+ }];
1636
+
1629
1637
let builders = [
1630
1638
// Builders using ReassociationIndices.
1631
1639
OpBuilder<(ins "Type":$resultType, "Value":$src,
1632
1640
"ArrayRef<ReassociationIndices>":$reassociation,
1633
- CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
1641
+ "ArrayRef<OpFoldResult>":$outputShape)>,
1642
+
1643
+ // It will infer output shape using inferOutputShape() method.
1644
+ OpBuilder<(ins "Type":$resultType, "Value":$src,
1645
+ "ArrayRef<ReassociationIndices>":$reassociation)>,
1646
+
1647
+ // Builder using ReassociationExprs.
1648
+ OpBuilder<(ins "Type":$resultType, "Value":$src,
1649
+ "ArrayRef<ReassociationExprs>":$reassociation),
1634
1650
[{
1635
- build($_builder, $_state, resultType, src, attrs);
1636
- $_state.addAttribute(" reassociation",
1637
- getReassociationIndicesAttribute ($_builder, reassociation) );
1651
+ auto reassociationIndices =
1652
+ convertReassociationMapsToIndices( reassociation);
1653
+ build ($_builder, $_state, resultType, src, reassociationIndices );
1638
1654
}]>,
1639
1655
1640
- // Builder using ReassociationExprs.
1641
1656
OpBuilder<(ins "Type":$resultType, "Value":$src,
1642
1657
"ArrayRef<ReassociationExprs>":$reassociation,
1643
- CArg< "ArrayRef<NamedAttribute>", "{}">:$attrs ),
1658
+ "ArrayRef<OpFoldResult>":$outputShape ),
1644
1659
[{
1645
1660
auto reassociationMaps =
1646
- convertReassociationMapsToIndices($_builder, reassociation);
1647
- build($_builder, $_state, resultType, src, reassociationMaps, attrs);
1661
+ convertReassociationMapsToIndices(reassociation);
1662
+ build($_builder, $_state, resultType, src, reassociationMaps,
1663
+ outputShape);
1648
1664
}]>,
1649
1665
1666
+ // Builder that infers the result layout map. The result shape must be
1667
+ // specified. Otherwise, the op may be ambiguous. The output shape for
1668
+ // the op will be inferred using the inferOutputShape() method.
1669
+ OpBuilder<(ins "ArrayRef<int64_t>":$resultShape, "Value":$src,
1670
+ "ArrayRef<ReassociationIndices>":$reassociation)>,
1671
+
1650
1672
// Builder that infers the result layout map. The result shape must be
1651
1673
// specified. Otherwise, the op may be ambiguous.
1652
1674
OpBuilder<(ins "ArrayRef<int64_t>":$resultShape, "Value":$src,
1653
- "ArrayRef<ReassociationIndices>":$reassociation)>
1675
+ "ArrayRef<ReassociationIndices>":$reassociation,
1676
+ "ArrayRef<OpFoldResult>":$outputShape)>
1654
1677
];
1655
1678
1656
1679
let extraClassDeclaration = commonExtraClassDeclaration # [{
1657
1680
static FailureOr<MemRefType> computeExpandedType(
1658
1681
MemRefType srcType, ArrayRef<int64_t> resultShape,
1659
1682
ArrayRef<ReassociationIndices> reassociation);
1683
+
1684
+ // Infer the output shape for a memref.expand_shape when it is possible
1685
+ // to do so.
1686
+ static FailureOr<SmallVector<OpFoldResult>> inferOutputShape(
1687
+ OpBuilder &b, Location loc, MemRefType expandedType,
1688
+ ArrayRef<ReassociationIndices> reassociation,
1689
+ ArrayRef<OpFoldResult> inputShape);
1660
1690
}];
1661
1691
1662
1692
let hasVerifier = 1;
@@ -1707,6 +1737,12 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
1707
1737
source/result layout map are the faster-varying ones.
1708
1738
}];
1709
1739
1740
+ let arguments = (ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation);
1741
+
1742
+ let assemblyFormat = [{
1743
+ $src $reassociation attr-dict `:` type($src) `into` type($result)
1744
+ }];
1745
+
1710
1746
let builders = [
1711
1747
// Builders for a contracting reshape whose result type is computed from
1712
1748
// `src` and `reassociation`.
@@ -1718,7 +1754,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
1718
1754
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
1719
1755
[{
1720
1756
auto reassociationMaps =
1721
- convertReassociationMapsToIndices($_builder, reassociation);
1757
+ convertReassociationMapsToIndices(reassociation);
1722
1758
build($_builder, $_state, src, reassociationMaps, attrs);
1723
1759
}]>,
1724
1760
@@ -1736,7 +1772,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
1736
1772
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
1737
1773
[{
1738
1774
auto reassociationMaps =
1739
- convertReassociationMapsToIndices($_builder, reassociation);
1775
+ convertReassociationMapsToIndices(reassociation);
1740
1776
build($_builder, $_state, resultType, src, reassociationMaps, attrs);
1741
1777
}]>
1742
1778
];
0 commit comments