Skip to content

Commit 8c0341d

Browse files
authored
Revert "[MLIR] Generalize expand_shape to take shape as explicit input" (#89540)
Reverts #69267 this broke some bots.
1 parent e095d97 commit 8c0341d

File tree

52 files changed

+634
-1193
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

52 files changed

+634
-1193
lines changed

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 22 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1548,6 +1548,7 @@ def MemRef_ReshapeOp: MemRef_Op<"reshape", [
15481548
class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
15491549
MemRef_Op<mnemonic, !listconcat(traits,
15501550
[Pure, ViewLikeOpInterface])>,
1551+
Arguments<(ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation)>,
15511552
Results<(outs AnyStridedMemRef:$result)>{
15521553

15531554
code commonExtraClassDeclaration = [{
@@ -1572,6 +1573,10 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
15721573
Value getViewSource() { return getSrc(); }
15731574
}];
15741575

1576+
let assemblyFormat = [{
1577+
$src $reassociation attr-dict `:` type($src) `into` type($result)
1578+
}];
1579+
15751580
let hasFolder = 1;
15761581
let hasCanonicalizer = 1;
15771582
let hasVerifier = 1;
@@ -1593,10 +1598,14 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
15931598
Example:
15941599

15951600
```mlir
1596-
%r = memref.expand_shape %0 [[0, 1], [2]] output_shape [%sz0, %sz1, 32]
1597-
: memref<?x32xf32> into memref<?x?x32xf32>
1601+
%r = memref.expand_shape %0 [[0, 1], [2]]
1602+
: memref<?x?xf32> into memref<?x5x?xf32>
15981603
```
15991604

1605+
At most one dimension of a reassociation group (e.g., [0, 1] above) may be
1606+
dynamic in the result type. Otherwise, the op would be ambiguous, as it
1607+
would not be clear how the source dimension is extended.
1608+
16001609
If an op can be statically proven to be invalid (e.g, an expansion from
16011610
`memref<10xf32>` to `memref<2x6xf32>`), it is rejected by the verifier. If
16021611
it cannot statically be proven invalid (e.g., the full example above; it is
@@ -1613,80 +1622,41 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
16131622
there must be a dynamic result dimension in the corresponding reassociation
16141623
group. Same for strides.
16151624

1616-
The representation for the output shape supports a partially-static
1617-
specification via attributes specified through the `static_output_shape`
1618-
argument. A special sentinel value `ShapedType::kDynamic` encodes that the
1619-
corresponding entry has a dynamic value. There must be exactly as many SSA
1620-
inputs in `output_shape` as there are `ShapedType::kDynamic` entries in
1621-
`static_output_shape`.
1622-
16231625
Note: This op currently assumes that the inner strides are of the
16241626
source/result layout map are the faster-varying ones.
16251627
}];
16261628

1627-
let arguments = (ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation,
1628-
Variadic<Index>:$output_shape,
1629-
DenseI64ArrayAttr:$static_output_shape);
1630-
1631-
let assemblyFormat = [{
1632-
$src $reassociation `output_shape`
1633-
custom<DynamicIndexList>($output_shape, $static_output_shape) attr-dict `:`
1634-
type($src) `into` type($result)
1635-
}];
1636-
16371629
let builders = [
16381630
// Builders using ReassociationIndices.
16391631
OpBuilder<(ins "Type":$resultType, "Value":$src,
16401632
"ArrayRef<ReassociationIndices>":$reassociation,
1641-
"ArrayRef<OpFoldResult>":$outputShape)>,
1642-
1643-
// It will infer output shape using inferOutputShape() method.
1644-
OpBuilder<(ins "Type":$resultType, "Value":$src,
1645-
"ArrayRef<ReassociationIndices>":$reassociation)>,
1646-
1647-
// Builder using ReassociationExprs.
1648-
OpBuilder<(ins "Type":$resultType, "Value":$src,
1649-
"ArrayRef<ReassociationExprs>":$reassociation),
1633+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
16501634
[{
1651-
auto reassociationIndices =
1652-
convertReassociationMapsToIndices(reassociation);
1653-
build($_builder, $_state, resultType, src, reassociationIndices);
1635+
build($_builder, $_state, resultType, src, attrs);
1636+
$_state.addAttribute("reassociation",
1637+
getReassociationIndicesAttribute($_builder, reassociation));
16541638
}]>,
16551639

1640+
// Builder using ReassociationExprs.
16561641
OpBuilder<(ins "Type":$resultType, "Value":$src,
16571642
"ArrayRef<ReassociationExprs>":$reassociation,
1658-
"ArrayRef<OpFoldResult>":$outputShape),
1643+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
16591644
[{
16601645
auto reassociationMaps =
1661-
convertReassociationMapsToIndices(reassociation);
1662-
build($_builder, $_state, resultType, src, reassociationMaps,
1663-
outputShape);
1646+
convertReassociationMapsToIndices($_builder, reassociation);
1647+
build($_builder, $_state, resultType, src, reassociationMaps, attrs);
16641648
}]>,
16651649

1666-
// Builder that infers the result layout map. The result shape must be
1667-
// specified. Otherwise, the op may be ambiguous. The output shape for
1668-
// the op will be inferred using the inferOutputShape() method.
1669-
OpBuilder<(ins "ArrayRef<int64_t>":$resultShape, "Value":$src,
1670-
"ArrayRef<ReassociationIndices>":$reassociation)>,
1671-
16721650
// Builder that infers the result layout map. The result shape must be
16731651
// specified. Otherwise, the op may be ambiguous.
16741652
OpBuilder<(ins "ArrayRef<int64_t>":$resultShape, "Value":$src,
1675-
"ArrayRef<ReassociationIndices>":$reassociation,
1676-
"ArrayRef<OpFoldResult>":$outputShape)>
1653+
"ArrayRef<ReassociationIndices>":$reassociation)>
16771654
];
16781655

16791656
let extraClassDeclaration = commonExtraClassDeclaration # [{
16801657
static FailureOr<MemRefType> computeExpandedType(
16811658
MemRefType srcType, ArrayRef<int64_t> resultShape,
16821659
ArrayRef<ReassociationIndices> reassociation);
1683-
1684-
// Infer the output shape for a memref.expand_shape when it is possible
1685-
// to do so.
1686-
static FailureOr<SmallVector<OpFoldResult>> inferOutputShape(
1687-
OpBuilder &b, Location loc, MemRefType expandedType,
1688-
ArrayRef<ReassociationIndices> reassociation,
1689-
ArrayRef<OpFoldResult> inputShape);
16901660
}];
16911661

16921662
let hasVerifier = 1;
@@ -1737,12 +1707,6 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
17371707
source/result layout map are the faster-varying ones.
17381708
}];
17391709

1740-
let arguments = (ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation);
1741-
1742-
let assemblyFormat = [{
1743-
$src $reassociation attr-dict `:` type($src) `into` type($result)
1744-
}];
1745-
17461710
let builders = [
17471711
// Builders for a contracting reshape whose result type is computed from
17481712
// `src` and `reassociation`.
@@ -1754,7 +1718,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
17541718
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
17551719
[{
17561720
auto reassociationMaps =
1757-
convertReassociationMapsToIndices(reassociation);
1721+
convertReassociationMapsToIndices($_builder, reassociation);
17581722
build($_builder, $_state, src, reassociationMaps, attrs);
17591723
}]>,
17601724

@@ -1772,7 +1736,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
17721736
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
17731737
[{
17741738
auto reassociationMaps =
1775-
convertReassociationMapsToIndices(reassociation);
1739+
convertReassociationMapsToIndices($_builder, reassociation);
17761740
build($_builder, $_state, resultType, src, reassociationMaps, attrs);
17771741
}]>
17781742
];

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 23 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1062,7 +1062,8 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
10621062
Tensor_Op<mnemonic, !listconcat(traits, [
10631063
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
10641064
Pure])>,
1065-
Results<(outs AnyTensor:$result)> {
1065+
Arguments<(ins AnyRankedTensor:$src, IndexListArrayAttr:$reassociation)>,
1066+
Results<(outs AnyRankedTensor:$result)> {
10661067

10671068
code commonExtraClassDeclaration = [{
10681069
static StringRef getReassociationAttrStrName() { return "reassociation"; }
@@ -1085,6 +1086,10 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
10851086
}
10861087
}];
10871088

1089+
let assemblyFormat = [{
1090+
$src $reassociation attr-dict `:` type($src) `into` type($result)
1091+
}];
1092+
10881093
let hasFolder = 1;
10891094
let hasCanonicalizer = 1;
10901095
let hasVerifier = 1;
@@ -1097,83 +1102,50 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
10971102
rank than the operand `src` whose dimension sizes are a reassociation of
10981103
`src`.
10991104

1100-
A reassociation is defined as a continuous grouping of dimensions and is
1101-
represented with an array of DenseI64ArrayAttr attribute. The reassociation
1102-
maps applied to the result tensor with the higher rank must result in the
1103-
operand tensor with the smaller rank.
1105+
A reassociation is defined as a continuous grouping of dimensions. It is
1106+
represented with an array of DenseI64ArrayAttr attribute. Entries in the
1107+
array are referred to as reassociation maps.
11041108

1105-
The representation for the output shape supports a partially-static
1106-
specification via attributes specified through the `static_output_shape`
1107-
argument. A special sentinel value `ShapedType::kDynamic` encodes that the
1108-
corresponding entry has a dynamic value. There must be exactly as many SSA
1109-
inputs in `output_shape` as there are `ShapedType::kDynamic` entries in
1110-
`static_output_shape`.
1109+
The reassociation maps are applied to the result shape to obtain the operand
1110+
shape.
11111111

11121112
Example:
11131113

11141114
```mlir
11151115
// Dimension expansion i -> (i', j') and (k) -> (k')
1116-
%b = tensor.expand_shape %a [[0, 1], [2]] output_shape [%sz0, %sz1, 32]
1117-
: tensor<?x32xf32> into tensor<?x?x32xf32>
1116+
%b = tensor.expand_shape %a [[0, 1], [2]]
1117+
: tensor<?x?xf32> into tensor<?x?x?xf32>
11181118
```
11191119
}];
1120-
1121-
let arguments = (ins AnyTensor:$src, IndexListArrayAttr:$reassociation,
1122-
Variadic<Index>:$output_shape,
1123-
DenseI64ArrayAttr:$static_output_shape);
1124-
1125-
let assemblyFormat = [{
1126-
$src $reassociation `output_shape`
1127-
custom<DynamicIndexList>($output_shape, $static_output_shape) attr-dict `:`
1128-
type($src) `into` type($result)
1129-
}];
1130-
11311120
let builders = [
11321121
// Builders using ReassociationIndices.
11331122
OpBuilder<(ins "Type":$resultType, "Value":$src,
11341123
"ArrayRef<ReassociationIndices>":$reassociation,
1135-
"ArrayRef<OpFoldResult>":$outputShape)>,
1136-
1137-
// It will infer output shape using inferOutputShape() method.
1138-
OpBuilder<(ins "Type":$resultType, "Value":$src,
1139-
"ArrayRef<ReassociationIndices>":$reassociation)>,
1140-
1141-
// Builder using ReassociationExprs.
1142-
OpBuilder<(ins "Type":$resultType, "Value":$src,
1143-
"ArrayRef<ReassociationExprs>":$reassociation),
1124+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
11441125
[{
1145-
auto reassociationIndices =
1146-
convertReassociationMapsToIndices(reassociation);
1147-
build($_builder, $_state, resultType, src, reassociationIndices);
1126+
build($_builder, $_state, resultType, src, attrs);
1127+
$_state.addAttribute("reassociation",
1128+
getReassociationIndicesAttribute($_builder, reassociation));
11481129
}]>,
11491130
OpBuilder<(ins "Type":$resultType, "Value":$src,
11501131
"ArrayRef<ReassociationExprs>":$reassociation,
1151-
"ArrayRef<OpFoldResult>":$outputShape),
1132+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
11521133
[{
1153-
auto reassociationIndices =
1154-
convertReassociationMapsToIndices(reassociation);
1155-
build($_builder, $_state, resultType, src, reassociationIndices,
1156-
outputShape);
1134+
auto reassociationMaps =
1135+
convertReassociationMapsToIndices($_builder, reassociation);
1136+
build($_builder, $_state, resultType, src, reassociationMaps, attrs);
11571137
}]>
11581138
];
11591139

11601140
let extraClassDeclaration = commonExtraClassDeclaration # [{
11611141
int64_t getCorrespondingSourceDim(int64_t resultDim);
1162-
1163-
// Infer the output shape for a tensor.expand_shape when it is possible
1164-
// to do so.
1165-
static FailureOr<SmallVector<OpFoldResult>> inferOutputShape(
1166-
OpBuilder &b, Location loc, RankedTensorType expandedType,
1167-
ArrayRef<ReassociationIndices> reassociation,
1168-
ArrayRef<OpFoldResult> inputShape);
11691142
}];
11701143

11711144
let hasVerifier = 1;
11721145
}
11731146

11741147
def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
11751148
let summary = "operation to produce a tensor with a smaller rank";
1176-
let arguments = (ins AnyTensor:$src, IndexListArrayAttr:$reassociation);
11771149
let description = [{
11781150
The `tensor.collapse_shape` op produces a new tensor of lower (or equal)
11791151
rank whose dimension sizes are a reassociation of the original `src` dimensions.
@@ -1191,11 +1163,6 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
11911163
: tensor<?x?x?xf32> into tensor<?x?xf32>
11921164
```
11931165
}];
1194-
1195-
let assemblyFormat = [{
1196-
$src $reassociation attr-dict `:` type($src) `into` type($result)
1197-
}];
1198-
11991166
let builders = [
12001167
// Builders for a contracting reshape whose result type is computed from
12011168
// `src` and `reassociation`.
@@ -1207,7 +1174,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
12071174
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
12081175
[{
12091176
auto reassociationMaps =
1210-
convertReassociationMapsToIndices(reassociation);
1177+
convertReassociationMapsToIndices($_builder, reassociation);
12111178
build($_builder, $_state, src, reassociationMaps, attrs);
12121179
}]>,
12131180

@@ -1225,7 +1192,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
12251192
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
12261193
[{
12271194
auto reassociationMaps =
1228-
convertReassociationMapsToIndices(reassociation);
1195+
convertReassociationMapsToIndices($_builder, reassociation);
12291196
build($_builder, $_state, resultType, src, reassociationMaps, attrs);
12301197
}]>
12311198
];

0 commit comments

Comments
 (0)