@@ -1548,6 +1548,7 @@ def MemRef_ReshapeOp: MemRef_Op<"reshape", [
1548
1548
class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
1549
1549
MemRef_Op<mnemonic, !listconcat(traits,
1550
1550
[Pure, ViewLikeOpInterface])>,
1551
+ Arguments<(ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation)>,
1551
1552
Results<(outs AnyStridedMemRef:$result)>{
1552
1553
1553
1554
code commonExtraClassDeclaration = [{
@@ -1572,6 +1573,10 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
1572
1573
Value getViewSource() { return getSrc(); }
1573
1574
}];
1574
1575
1576
+ let assemblyFormat = [{
1577
+ $src $reassociation attr-dict `:` type($src) `into` type($result)
1578
+ }];
1579
+
1575
1580
let hasFolder = 1;
1576
1581
let hasCanonicalizer = 1;
1577
1582
let hasVerifier = 1;
@@ -1593,10 +1598,14 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
1593
1598
Example:
1594
1599
1595
1600
```mlir
1596
- %r = memref.expand_shape %0 [[0, 1], [2]] output_shape [%sz0, %sz1, 32]
1597
- : memref<?x32xf32 > into memref<?x?x32xf32 >
1601
+ %r = memref.expand_shape %0 [[0, 1], [2]]
1602
+ : memref<?x?xf32 > into memref<?x5x?xf32 >
1598
1603
```
1599
1604
1605
+ At most one dimension of a reassociation group (e.g., [0, 1] above) may be
1606
+ dynamic in the result type. Otherwise, the op would be ambiguous, as it
1607
+ would not be clear how the source dimension is extended.
1608
+
1600
1609
If an op can be statically proven to be invalid (e.g, an expansion from
1601
1610
`memref<10xf32>` to `memref<2x6xf32>`), it is rejected by the verifier. If
1602
1611
it cannot statically be proven invalid (e.g., the full example above; it is
@@ -1613,80 +1622,41 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
1613
1622
there must be a dynamic result dimension in the corresponding reassociation
1614
1623
group. Same for strides.
1615
1624
1616
- The representation for the output shape supports a partially-static
1617
- specification via attributes specified through the `static_output_shape`
1618
- argument. A special sentinel value `ShapedType::kDynamic` encodes that the
1619
- corresponding entry has a dynamic value. There must be exactly as many SSA
1620
- inputs in `output_shape` as there are `ShapedType::kDynamic` entries in
1621
- `static_output_shape`.
1622
-
1623
1625
Note: This op currently assumes that the inner strides are of the
1624
1626
source/result layout map are the faster-varying ones.
1625
1627
}];
1626
1628
1627
- let arguments = (ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation,
1628
- Variadic<Index>:$output_shape,
1629
- DenseI64ArrayAttr:$static_output_shape);
1630
-
1631
- let assemblyFormat = [{
1632
- $src $reassociation `output_shape`
1633
- custom<DynamicIndexList>($output_shape, $static_output_shape) attr-dict `:`
1634
- type($src) `into` type($result)
1635
- }];
1636
-
1637
1629
let builders = [
1638
1630
// Builders using ReassociationIndices.
1639
1631
OpBuilder<(ins "Type":$resultType, "Value":$src,
1640
1632
"ArrayRef<ReassociationIndices>":$reassociation,
1641
- "ArrayRef<OpFoldResult>":$outputShape)>,
1642
-
1643
- // It will infer output shape using inferOutputShape() method.
1644
- OpBuilder<(ins "Type":$resultType, "Value":$src,
1645
- "ArrayRef<ReassociationIndices>":$reassociation)>,
1646
-
1647
- // Builder using ReassociationExprs.
1648
- OpBuilder<(ins "Type":$resultType, "Value":$src,
1649
- "ArrayRef<ReassociationExprs>":$reassociation),
1633
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
1650
1634
[{
1651
- auto reassociationIndices =
1652
- convertReassociationMapsToIndices( reassociation);
1653
- build ($_builder, $_state, resultType, src, reassociationIndices );
1635
+ build($_builder, $_state, resultType, src, attrs);
1636
+ $_state.addAttribute(" reassociation",
1637
+ getReassociationIndicesAttribute ($_builder, reassociation) );
1654
1638
}]>,
1655
1639
1640
+ // Builder using ReassociationExprs.
1656
1641
OpBuilder<(ins "Type":$resultType, "Value":$src,
1657
1642
"ArrayRef<ReassociationExprs>":$reassociation,
1658
- "ArrayRef<OpFoldResult>":$outputShape ),
1643
+ CArg< "ArrayRef<NamedAttribute>", "{}">:$attrs ),
1659
1644
[{
1660
1645
auto reassociationMaps =
1661
- convertReassociationMapsToIndices(reassociation);
1662
- build($_builder, $_state, resultType, src, reassociationMaps,
1663
- outputShape);
1646
+ convertReassociationMapsToIndices($_builder, reassociation);
1647
+ build($_builder, $_state, resultType, src, reassociationMaps, attrs);
1664
1648
}]>,
1665
1649
1666
- // Builder that infers the result layout map. The result shape must be
1667
- // specified. Otherwise, the op may be ambiguous. The output shape for
1668
- // the op will be inferred using the inferOutputShape() method.
1669
- OpBuilder<(ins "ArrayRef<int64_t>":$resultShape, "Value":$src,
1670
- "ArrayRef<ReassociationIndices>":$reassociation)>,
1671
-
1672
1650
// Builder that infers the result layout map. The result shape must be
1673
1651
// specified. Otherwise, the op may be ambiguous.
1674
1652
OpBuilder<(ins "ArrayRef<int64_t>":$resultShape, "Value":$src,
1675
- "ArrayRef<ReassociationIndices>":$reassociation,
1676
- "ArrayRef<OpFoldResult>":$outputShape)>
1653
+ "ArrayRef<ReassociationIndices>":$reassociation)>
1677
1654
];
1678
1655
1679
1656
let extraClassDeclaration = commonExtraClassDeclaration # [{
1680
1657
static FailureOr<MemRefType> computeExpandedType(
1681
1658
MemRefType srcType, ArrayRef<int64_t> resultShape,
1682
1659
ArrayRef<ReassociationIndices> reassociation);
1683
-
1684
- // Infer the output shape for a memref.expand_shape when it is possible
1685
- // to do so.
1686
- static FailureOr<SmallVector<OpFoldResult>> inferOutputShape(
1687
- OpBuilder &b, Location loc, MemRefType expandedType,
1688
- ArrayRef<ReassociationIndices> reassociation,
1689
- ArrayRef<OpFoldResult> inputShape);
1690
1660
}];
1691
1661
1692
1662
let hasVerifier = 1;
@@ -1737,12 +1707,6 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
1737
1707
source/result layout map are the faster-varying ones.
1738
1708
}];
1739
1709
1740
- let arguments = (ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation);
1741
-
1742
- let assemblyFormat = [{
1743
- $src $reassociation attr-dict `:` type($src) `into` type($result)
1744
- }];
1745
-
1746
1710
let builders = [
1747
1711
// Builders for a contracting reshape whose result type is computed from
1748
1712
// `src` and `reassociation`.
@@ -1754,7 +1718,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
1754
1718
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
1755
1719
[{
1756
1720
auto reassociationMaps =
1757
- convertReassociationMapsToIndices(reassociation);
1721
+ convertReassociationMapsToIndices($_builder, reassociation);
1758
1722
build($_builder, $_state, src, reassociationMaps, attrs);
1759
1723
}]>,
1760
1724
@@ -1772,7 +1736,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
1772
1736
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
1773
1737
[{
1774
1738
auto reassociationMaps =
1775
- convertReassociationMapsToIndices(reassociation);
1739
+ convertReassociationMapsToIndices($_builder, reassociation);
1776
1740
build($_builder, $_state, resultType, src, reassociationMaps, attrs);
1777
1741
}]>
1778
1742
];
0 commit comments