Skip to content

Commit 273d71d

Browse files
ramiro050Shukla-Gaurav
authored andcommitted
[MLIR] Generalize expand_shape to take shape as explicit input
*DO NOT SUBMIT* (This patch is for early design feedback only. Notably, tests have not been updated and the implementation is incomplete in some cases.) This patch generalizes tensor.expand_shape and memref.expand_shape to consume the output shape as a list of SSA values. This enables us to implement generic reshape operations with dynamic shapes using collapse_shape/expand_shape pairs. The output_shape input to expand_shape follows the static/dynamic representation that's also used in `tensor.extract_slice`. Differential Revision: https://reviews.llvm.org/D140821
1 parent 206799f commit 273d71d

File tree

18 files changed

+432
-122
lines changed

18 files changed

+432
-122
lines changed

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 68 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1554,7 +1554,6 @@ def MemRef_ReshapeOp: MemRef_Op<"reshape", [
15541554
class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
15551555
MemRef_Op<mnemonic, !listconcat(traits,
15561556
[Pure, ViewLikeOpInterface])>,
1557-
Arguments<(ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation)>,
15581557
Results<(outs AnyStridedMemRef:$result)>{
15591558

15601559
code commonExtraClassDeclaration = [{
@@ -1579,10 +1578,6 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
15791578
Value getViewSource() { return getSrc(); }
15801579
}];
15811580

1582-
let assemblyFormat = [{
1583-
$src $reassociation attr-dict `:` type($src) `into` type($result)
1584-
}];
1585-
15861581
let hasFolder = 1;
15871582
let hasCanonicalizer = 1;
15881583
let hasVerifier = 1;
@@ -1604,14 +1599,10 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
16041599
Example:
16051600

16061601
```mlir
1607-
%r = memref.expand_shape %0 [[0, 1], [2]]
1608-
: memref<?x?xf32> into memref<?x5x?xf32>
1602+
%r = memref.expand_shape %0 [[0, 1], [2]] [%sz0, %sz1, 32]
1603+
: memref<?x32xf32> into memref<?x?x32xf32>
16091604
```
16101605

1611-
At most one dimension of a reassociation group (e.g., [0, 1] above) may be
1612-
dynamic in the result type. Otherwise, the op would be ambiguous, as it
1613-
would not be clear how the source dimension is extended.
1614-
16151606
If an op can be statically proven to be invalid (e.g, an expansion from
16161607
`memref<10xf32>` to `memref<2x6xf32>`), it is rejected by the verifier. If
16171608
it cannot statically be proven invalid (e.g., the full example above; it is
@@ -1628,29 +1619,72 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
16281619
there must be a dynamic result dimension in the corresponding reassociation
16291620
group. Same for strides.
16301621

1622+
The representation for the output shape supports a partially-static
1623+
specification via attributes specified through the `static_output_shape`
1624+
argument. A special sentinel value `ShapedType::kDynamic` encodes that the
1625+
corresponding entry has a dynamic value. There must be exactly as many SSA
1626+
inputs in `output_shape` as there are `ShapedType::kDynamic` entries in
1627+
`static_output_shape`.
1628+
16311629
Note: This op currently assumes that the inner strides are of the
16321630
source/result layout map are the faster-varying ones.
16331631
}];
16341632

1633+
let arguments = (ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation,
1634+
Variadic<Index>:$output_shape,
1635+
DenseI64ArrayAttr:$static_output_shape);
1636+
1637+
let assemblyFormat = [{
1638+
$src $reassociation `output_shape`
1639+
custom<DynamicIndexList>($output_shape, $static_output_shape) attr-dict `:`
1640+
type($src) `into` type($result)
1641+
}];
1642+
16351643
let builders = [
16361644
// Builders using ReassociationIndices.
1645+
OpBuilder<(ins "Type":$resultType, "Value":$src,
1646+
"ArrayRef<ReassociationIndices>":$reassociation),
1647+
[{
1648+
SmallVector<OpFoldResult> inputShape =
1649+
getMixedSizes($_builder, $_state.location, src);
1650+
std::pair<SmallVector<int64_t>, SmallVector<Value>> outputShape;
1651+
auto status =
1652+
inferOutputShape($_builder, $_state.location,
1653+
resultType.cast<MemRefType>(),
1654+
reassociation, inputShape, outputShape);
1655+
(void) status;
1656+
assert(succeeded(status) && "unable to infer output shape");
1657+
build($_builder, $_state, resultType.cast<MemRefType>(), src,
1658+
getReassociationIndicesAttribute($_builder, reassociation),
1659+
outputShape.second, outputShape.first);
1660+
}]>,
16371661
OpBuilder<(ins "Type":$resultType, "Value":$src,
16381662
"ArrayRef<ReassociationIndices>":$reassociation,
1639-
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
1663+
"ArrayRef<OpFoldResult>":$outputShape),
16401664
[{
1641-
build($_builder, $_state, resultType, src, attrs);
1642-
$_state.addAttribute("reassociation",
1643-
getReassociationIndicesAttribute($_builder, reassociation));
1665+
auto [staticOutputShape, dynamicOutputShape] =
1666+
decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
1667+
build($_builder, $_state, resultType, src,
1668+
getReassociationIndicesAttribute($_builder, reassociation),
1669+
dynamicOutputShape, staticOutputShape);
16441670
}]>,
16451671

16461672
// Builder using ReassociationExprs.
1673+
OpBuilder<(ins "Type":$resultType, "Value":$src,
1674+
"ArrayRef<ReassociationExprs>":$reassociation),
1675+
[{
1676+
auto reassociationIndices =
1677+
convertReassociationMapsToIndices(reassociation);
1678+
build($_builder, $_state, resultType, src, reassociationIndices);
1679+
}]>,
16471680
OpBuilder<(ins "Type":$resultType, "Value":$src,
16481681
"ArrayRef<ReassociationExprs>":$reassociation,
1649-
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
1682+
"ArrayRef<OpFoldResult>":$outputShape),
16501683
[{
16511684
auto reassociationMaps =
1652-
convertReassociationMapsToIndices($_builder, reassociation);
1653-
build($_builder, $_state, resultType, src, reassociationMaps, attrs);
1685+
convertReassociationMapsToIndices(reassociation);
1686+
build($_builder, $_state, resultType, src, reassociationMaps,
1687+
outputShape);
16541688
}]>,
16551689

16561690
// Builder that infers the result layout map. The result shape must be
@@ -1663,6 +1697,14 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
16631697
static FailureOr<MemRefType> computeExpandedType(
16641698
MemRefType srcType, ArrayRef<int64_t> resultShape,
16651699
ArrayRef<ReassociationIndices> reassociation);
1700+
1701+
// Infer the output shape for a memref.expand_shape when it is possible
1702+
// to do so.
1703+
static LogicalResult inferOutputShape(
1704+
OpBuilder &b, Location loc, MemRefType expandedType,
1705+
ArrayRef<ReassociationIndices> reassociation,
1706+
ArrayRef<OpFoldResult> inputShape,
1707+
std::pair<SmallVector<int64_t>, SmallVector<Value>> &outputShape);
16661708
}];
16671709

16681710
let hasVerifier = 1;
@@ -1713,6 +1755,12 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
17131755
source/result layout map are the faster-varying ones.
17141756
}];
17151757

1758+
let arguments = (ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation);
1759+
1760+
let assemblyFormat = [{
1761+
$src $reassociation attr-dict `:` type($src) `into` type($result)
1762+
}];
1763+
17161764
let builders = [
17171765
// Builders for a contracting reshape whose result type is computed from
17181766
// `src` and `reassociation`.
@@ -1724,7 +1772,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
17241772
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
17251773
[{
17261774
auto reassociationMaps =
1727-
convertReassociationMapsToIndices($_builder, reassociation);
1775+
convertReassociationMapsToIndices(reassociation);
17281776
build($_builder, $_state, src, reassociationMaps, attrs);
17291777
}]>,
17301778

@@ -1742,7 +1790,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
17421790
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
17431791
[{
17441792
auto reassociationMaps =
1745-
convertReassociationMapsToIndices($_builder, reassociation);
1793+
convertReassociationMapsToIndices(reassociation);
17461794
build($_builder, $_state, resultType, src, reassociationMaps, attrs);
17471795
}]>
17481796
];

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 75 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -998,8 +998,7 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
998998
Tensor_Op<mnemonic, !listconcat(traits, [
999999
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
10001000
Pure])>,
1001-
Arguments<(ins AnyRankedTensor:$src, IndexListArrayAttr:$reassociation)>,
1002-
Results<(outs AnyRankedTensor:$result)> {
1001+
Results<(outs AnyTensor:$result)> {
10031002

10041003
code commonExtraClassDeclaration = [{
10051004
static StringRef getReassociationAttrStrName() { return "reassociation"; }
@@ -1022,10 +1021,6 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
10221021
}
10231022
}];
10241023

1025-
let assemblyFormat = [{
1026-
$src $reassociation attr-dict `:` type($src) `into` type($result)
1027-
}];
1028-
10291024
let hasFolder = 1;
10301025
let hasCanonicalizer = 1;
10311026
let hasVerifier = 1;
@@ -1038,11 +1033,16 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
10381033
rank whose sizes are a reassociation of the original `src`.
10391034

10401035
A reassociation is defined as a continuous grouping of dimensions and is
1041-
represented with an array of DenseI64ArrayAttr attribute.
1036+
represented with an array of DenseI64ArrayAttr attribute. The reassociation
1037+
maps applied to the result tensor with the higher rank must result in the
1038+
operand tensor with the smaller rank.
10421039

1043-
The verification rule is that the reassociation maps are applied to the
1044-
result tensor with the higher rank to obtain the operand tensor with the
1045-
smaller rank.
1040+
The representation for the output shape supports a partially-static
1041+
specification via attributes specified through the `static_output_shape`
1042+
argument. A special sentinel value `ShapedType::kDynamic` encodes that the
1043+
corresponding entry has a dynamic value. There must be exactly as many SSA
1044+
inputs in `output_shape` as there are `ShapedType::kDynamic` entries in
1045+
`static_output_shape`.
10461046

10471047
The operand tensor type of a reshape can be zero-ranked if the result
10481048
tensor type is statically shaped with all dimensions being unit extent. In
@@ -1052,39 +1052,87 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
10521052

10531053
```mlir
10541054
// Dimension expansion i -> (i', j') and (k) -> (k')
1055-
%b = tensor.expand_shape %a [[0, 1], [2]]
1056-
: tensor<?x?xf32> into tensor<?x?x?xf32>
1055+
%b = tensor.expand_shape %a [[0, 1], [2]] [%sz0, %sz1, 32]
1056+
: tensor<?x32xf32> into tensor<?x?x32xf32>
10571057
```
10581058
}];
1059+
1060+
let arguments = (ins AnyTensor:$src, IndexListArrayAttr:$reassociation,
1061+
Variadic<Index>:$output_shape,
1062+
DenseI64ArrayAttr:$static_output_shape);
1063+
1064+
let assemblyFormat = [{
1065+
$src $reassociation `output_shape`
1066+
custom<DynamicIndexList>($output_shape, $static_output_shape) attr-dict `:`
1067+
type($src) `into` type($result)
1068+
}];
1069+
10591070
let builders = [
10601071
// Builders using ReassociationIndices.
1072+
OpBuilder<(ins "Type":$resultType, "Value":$src,
1073+
"ArrayRef<ReassociationIndices>":$reassociation),
1074+
[{
1075+
SmallVector<OpFoldResult> inputShape =
1076+
getMixedSizes($_builder, $_state.location, src);
1077+
std::pair<SmallVector<int64_t>, SmallVector<Value>> outputShape;
1078+
auto status =
1079+
inferOutputShape($_builder, $_state.location,
1080+
resultType.cast<RankedTensorType>(),
1081+
reassociation, inputShape, outputShape);
1082+
(void) status;
1083+
assert(succeeded(status) && "unable to infer output shape");
1084+
build($_builder, $_state, resultType.cast<RankedTensorType>(), src,
1085+
getReassociationIndicesAttribute($_builder, reassociation),
1086+
outputShape.second, outputShape.first);
1087+
}]>,
10611088
OpBuilder<(ins "Type":$resultType, "Value":$src,
10621089
"ArrayRef<ReassociationIndices>":$reassociation,
1063-
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
1090+
"ArrayRef<OpFoldResult>":$outputShape),
10641091
[{
1065-
build($_builder, $_state, resultType, src, attrs);
1066-
$_state.addAttribute("reassociation",
1067-
getReassociationIndicesAttribute($_builder, reassociation));
1092+
auto [staticOutputShape, dynamicOutputShape] =
1093+
decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
1094+
build($_builder, $_state, resultType, src,
1095+
getReassociationIndicesAttribute($_builder, reassociation),
1096+
dynamicOutputShape, staticOutputShape);
1097+
}]>,
1098+
1099+
// Builder using ReassociationExprs.
1100+
OpBuilder<(ins "Type":$resultType, "Value":$src,
1101+
"ArrayRef<ReassociationExprs>":$reassociation),
1102+
[{
1103+
auto reassociationIndices =
1104+
convertReassociationMapsToIndices(reassociation);
1105+
build($_builder, $_state, resultType, src, reassociationIndices);
10681106
}]>,
10691107
OpBuilder<(ins "Type":$resultType, "Value":$src,
10701108
"ArrayRef<ReassociationExprs>":$reassociation,
1071-
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
1109+
"ArrayRef<OpFoldResult>":$outputShape),
10721110
[{
1073-
auto reassociationMaps =
1074-
convertReassociationMapsToIndices($_builder, reassociation);
1075-
build($_builder, $_state, resultType, src, reassociationMaps, attrs);
1111+
auto reassociationIndices =
1112+
convertReassociationMapsToIndices(reassociation);
1113+
build($_builder, $_state, resultType, src, reassociationIndices,
1114+
outputShape);
10761115
}]>
10771116
];
10781117

10791118
let extraClassDeclaration = commonExtraClassDeclaration # [{
10801119
int64_t getCorrespondingSourceDim(int64_t resultDim);
1120+
1121+
// Infer the output shape for a tensor.expand_shape when it is possible
1122+
// to do so.
1123+
static LogicalResult inferOutputShape(
1124+
OpBuilder &b, Location loc, RankedTensorType expandedType,
1125+
ArrayRef<ReassociationIndices> reassociation,
1126+
ArrayRef<OpFoldResult> inputShape,
1127+
std::pair<SmallVector<int64_t>, SmallVector<Value>> &outputShape);
10811128
}];
10821129

10831130
let hasVerifier = 1;
10841131
}
10851132

10861133
def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
10871134
let summary = "operation to produce a tensor with a smaller rank";
1135+
let arguments = (ins AnyTensor:$src, IndexListArrayAttr:$reassociation);
10881136
let description = [{
10891137
The `tensor.collapse_shape` op produces a new tensor with a smaller
10901138
rank whose sizes are a reassociation of the original `src`.
@@ -1108,6 +1156,11 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
11081156
: tensor<?x?x?xf32> into tensor<?x?xf32>
11091157
```
11101158
}];
1159+
1160+
let assemblyFormat = [{
1161+
$src $reassociation attr-dict `:` type($src) `into` type($result)
1162+
}];
1163+
11111164
let builders = [
11121165
// Builders for a contracting reshape whose result type is computed from
11131166
// `src` and `reassociation`.
@@ -1119,7 +1172,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
11191172
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
11201173
[{
11211174
auto reassociationMaps =
1122-
convertReassociationMapsToIndices($_builder, reassociation);
1175+
convertReassociationMapsToIndices(reassociation);
11231176
build($_builder, $_state, src, reassociationMaps, attrs);
11241177
}]>,
11251178

@@ -1137,7 +1190,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
11371190
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
11381191
[{
11391192
auto reassociationMaps =
1140-
convertReassociationMapsToIndices($_builder, reassociation);
1193+
convertReassociationMapsToIndices(reassociation);
11411194
build($_builder, $_state, resultType, src, reassociationMaps, attrs);
11421195
}]>
11431196
];

0 commit comments

Comments
 (0)