Skip to content

Commit 24ee8aa

Browse files
ramiro050Shukla-Gaurav
authored andcommitted
[MLIR] Generalize expand_shape to take shape as explicit input
*DO NOT SUBMIT* (This patch needs to be tested in IREE backend) This patch generalizes tensor.expand_shape and memref.expand_shape to consume the output shape as a list of SSA values. This enables us to implement generic reshape operations with dynamic shapes using collapse_shape/expand_shape pairs. The output_shape input to expand_shape follows the static/dynamic representation that's also used in `tensor.extract_slice`. Differential Revision: https://reviews.llvm.org/D140821
1 parent 986435c commit 24ee8aa

File tree

17 files changed

+420
-116
lines changed

17 files changed

+420
-116
lines changed

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 68 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1548,7 +1548,6 @@ def MemRef_ReshapeOp: MemRef_Op<"reshape", [
15481548
class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
15491549
MemRef_Op<mnemonic, !listconcat(traits,
15501550
[Pure, ViewLikeOpInterface])>,
1551-
Arguments<(ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation)>,
15521551
Results<(outs AnyStridedMemRef:$result)>{
15531552

15541553
code commonExtraClassDeclaration = [{
@@ -1573,10 +1572,6 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
15731572
Value getViewSource() { return getSrc(); }
15741573
}];
15751574

1576-
let assemblyFormat = [{
1577-
$src $reassociation attr-dict `:` type($src) `into` type($result)
1578-
}];
1579-
15801575
let hasFolder = 1;
15811576
let hasCanonicalizer = 1;
15821577
let hasVerifier = 1;
@@ -1598,14 +1593,10 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
15981593
Example:
15991594

16001595
```mlir
1601-
%r = memref.expand_shape %0 [[0, 1], [2]]
1602-
: memref<?x?xf32> into memref<?x5x?xf32>
1596+
%r = memref.expand_shape %0 [[0, 1], [2]] [%sz0, %sz1, 32]
1597+
: memref<?x32xf32> into memref<?x?x32xf32>
16031598
```
16041599

1605-
At most one dimension of a reassociation group (e.g., [0, 1] above) may be
1606-
dynamic in the result type. Otherwise, the op would be ambiguous, as it
1607-
would not be clear how the source dimension is extended.
1608-
16091600
If an op can be statically proven to be invalid (e.g, an expansion from
16101601
`memref<10xf32>` to `memref<2x6xf32>`), it is rejected by the verifier. If
16111602
it cannot statically be proven invalid (e.g., the full example above; it is
@@ -1622,29 +1613,72 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
16221613
there must be a dynamic result dimension in the corresponding reassociation
16231614
group. Same for strides.
16241615

1616+
The representation for the output shape supports a partially-static
1617+
specification via attributes specified through the `static_output_shape`
1618+
argument. A special sentinel value `ShapedType::kDynamic` encodes that the
1619+
corresponding entry has a dynamic value. There must be exactly as many SSA
1620+
inputs in `output_shape` as there are `ShapedType::kDynamic` entries in
1621+
`static_output_shape`.
1622+
16251623
Note: This op currently assumes that the inner strides are of the
16261624
source/result layout map are the faster-varying ones.
16271625
}];
16281626

1627+
let arguments = (ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation,
1628+
Variadic<Index>:$output_shape,
1629+
DenseI64ArrayAttr:$static_output_shape);
1630+
1631+
let assemblyFormat = [{
1632+
$src $reassociation `output_shape`
1633+
custom<DynamicIndexList>($output_shape, $static_output_shape) attr-dict `:`
1634+
type($src) `into` type($result)
1635+
}];
1636+
16291637
let builders = [
16301638
// Builders using ReassociationIndices.
1639+
OpBuilder<(ins "Type":$resultType, "Value":$src,
1640+
"ArrayRef<ReassociationIndices>":$reassociation),
1641+
[{
1642+
SmallVector<OpFoldResult> inputShape =
1643+
getMixedSizes($_builder, $_state.location, src);
1644+
std::pair<SmallVector<int64_t>, SmallVector<Value>> outputShape;
1645+
auto status =
1646+
inferOutputShape($_builder, $_state.location,
1647+
resultType.cast<MemRefType>(),
1648+
reassociation, inputShape, outputShape);
1649+
(void) status;
1650+
assert(succeeded(status) && "unable to infer output shape");
1651+
build($_builder, $_state, resultType.cast<MemRefType>(), src,
1652+
getReassociationIndicesAttribute($_builder, reassociation),
1653+
outputShape.second, outputShape.first);
1654+
}]>,
16311655
OpBuilder<(ins "Type":$resultType, "Value":$src,
16321656
"ArrayRef<ReassociationIndices>":$reassociation,
1633-
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
1657+
"ArrayRef<OpFoldResult>":$outputShape),
16341658
[{
1635-
build($_builder, $_state, resultType, src, attrs);
1636-
$_state.addAttribute("reassociation",
1637-
getReassociationIndicesAttribute($_builder, reassociation));
1659+
auto [staticOutputShape, dynamicOutputShape] =
1660+
decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
1661+
build($_builder, $_state, resultType, src,
1662+
getReassociationIndicesAttribute($_builder, reassociation),
1663+
dynamicOutputShape, staticOutputShape);
16381664
}]>,
16391665

16401666
// Builder using ReassociationExprs.
1667+
OpBuilder<(ins "Type":$resultType, "Value":$src,
1668+
"ArrayRef<ReassociationExprs>":$reassociation),
1669+
[{
1670+
auto reassociationIndices =
1671+
convertReassociationMapsToIndices(reassociation);
1672+
build($_builder, $_state, resultType, src, reassociationIndices);
1673+
}]>,
16411674
OpBuilder<(ins "Type":$resultType, "Value":$src,
16421675
"ArrayRef<ReassociationExprs>":$reassociation,
1643-
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
1676+
"ArrayRef<OpFoldResult>":$outputShape),
16441677
[{
16451678
auto reassociationMaps =
1646-
convertReassociationMapsToIndices($_builder, reassociation);
1647-
build($_builder, $_state, resultType, src, reassociationMaps, attrs);
1679+
convertReassociationMapsToIndices(reassociation);
1680+
build($_builder, $_state, resultType, src, reassociationMaps,
1681+
outputShape);
16481682
}]>,
16491683

16501684
// Builder that infers the result layout map. The result shape must be
@@ -1657,6 +1691,14 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
16571691
static FailureOr<MemRefType> computeExpandedType(
16581692
MemRefType srcType, ArrayRef<int64_t> resultShape,
16591693
ArrayRef<ReassociationIndices> reassociation);
1694+
1695+
// Infer the output shape for a memref.expand_shape when it is possible
1696+
// to do so.
1697+
static LogicalResult inferOutputShape(
1698+
OpBuilder &b, Location loc, MemRefType expandedType,
1699+
ArrayRef<ReassociationIndices> reassociation,
1700+
ArrayRef<OpFoldResult> inputShape,
1701+
std::pair<SmallVector<int64_t>, SmallVector<Value>> &outputShape);
16601702
}];
16611703

16621704
let hasVerifier = 1;
@@ -1707,6 +1749,12 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
17071749
source/result layout map are the faster-varying ones.
17081750
}];
17091751

1752+
let arguments = (ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation);
1753+
1754+
let assemblyFormat = [{
1755+
$src $reassociation attr-dict `:` type($src) `into` type($result)
1756+
}];
1757+
17101758
let builders = [
17111759
// Builders for a contracting reshape whose result type is computed from
17121760
// `src` and `reassociation`.
@@ -1718,7 +1766,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
17181766
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
17191767
[{
17201768
auto reassociationMaps =
1721-
convertReassociationMapsToIndices($_builder, reassociation);
1769+
convertReassociationMapsToIndices(reassociation);
17221770
build($_builder, $_state, src, reassociationMaps, attrs);
17231771
}]>,
17241772

@@ -1736,7 +1784,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
17361784
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
17371785
[{
17381786
auto reassociationMaps =
1739-
convertReassociationMapsToIndices($_builder, reassociation);
1787+
convertReassociationMapsToIndices(reassociation);
17401788
build($_builder, $_state, resultType, src, reassociationMaps, attrs);
17411789
}]>
17421790
];

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 76 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1062,8 +1062,7 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
10621062
Tensor_Op<mnemonic, !listconcat(traits, [
10631063
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
10641064
Pure])>,
1065-
Arguments<(ins AnyRankedTensor:$src, IndexListArrayAttr:$reassociation)>,
1066-
Results<(outs AnyRankedTensor:$result)> {
1065+
Results<(outs AnyTensor:$result)> {
10671066

10681067
code commonExtraClassDeclaration = [{
10691068
static StringRef getReassociationAttrStrName() { return "reassociation"; }
@@ -1086,10 +1085,6 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
10861085
}
10871086
}];
10881087

1089-
let assemblyFormat = [{
1090-
$src $reassociation attr-dict `:` type($src) `into` type($result)
1091-
}];
1092-
10931088
let hasFolder = 1;
10941089
let hasCanonicalizer = 1;
10951090
let hasVerifier = 1;
@@ -1102,50 +1097,103 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
11021097
rank than the operand `src` whose dimension sizes are a reassociation of
11031098
`src`.
11041099

1105-
A reassociation is defined as a continuous grouping of dimensions. It is
1106-
represented with an array of DenseI64ArrayAttr attribute. Entries in the
1107-
array are referred to as reassociation maps.
1100+
A reassociation is defined as a continuous grouping of dimensions and is
1101+
represented with an array of DenseI64ArrayAttr attribute. The reassociation
1102+
maps applied to the result tensor with the higher rank must result in the
1103+
operand tensor with the smaller rank.
11081104

1109-
The reassociation maps are applied to the result shape to obtain the operand
1110-
shape.
1105+
The representation for the output shape supports a partially-static
1106+
specification via attributes specified through the `static_output_shape`
1107+
argument. A special sentinel value `ShapedType::kDynamic` encodes that the
1108+
corresponding entry has a dynamic value. There must be exactly as many SSA
1109+
inputs in `output_shape` as there are `ShapedType::kDynamic` entries in
1110+
`static_output_shape`.
11111111

11121112
Example:
11131113

11141114
```mlir
11151115
// Dimension expansion i -> (i', j') and (k) -> (k')
1116-
%b = tensor.expand_shape %a [[0, 1], [2]]
1117-
: tensor<?x?xf32> into tensor<?x?x?xf32>
1116+
%b = tensor.expand_shape %a [[0, 1], [2]] [%sz0, %sz1, 32]
1117+
: tensor<?x32xf32> into tensor<?x?x32xf32>
11181118
```
11191119
}];
1120+
1121+
let arguments = (ins AnyTensor:$src, IndexListArrayAttr:$reassociation,
1122+
Variadic<Index>:$output_shape,
1123+
DenseI64ArrayAttr:$static_output_shape);
1124+
1125+
let assemblyFormat = [{
1126+
$src $reassociation `output_shape`
1127+
custom<DynamicIndexList>($output_shape, $static_output_shape) attr-dict `:`
1128+
type($src) `into` type($result)
1129+
}];
1130+
11201131
let builders = [
11211132
// Builders using ReassociationIndices.
1133+
OpBuilder<(ins "Type":$resultType, "Value":$src,
1134+
"ArrayRef<ReassociationIndices>":$reassociation),
1135+
[{
1136+
SmallVector<OpFoldResult> inputShape =
1137+
getMixedSizes($_builder, $_state.location, src);
1138+
std::pair<SmallVector<int64_t>, SmallVector<Value>> outputShape;
1139+
auto status =
1140+
inferOutputShape($_builder, $_state.location,
1141+
resultType.cast<RankedTensorType>(),
1142+
reassociation, inputShape, outputShape);
1143+
(void) status;
1144+
assert(succeeded(status) && "unable to infer output shape");
1145+
build($_builder, $_state, resultType.cast<RankedTensorType>(), src,
1146+
getReassociationIndicesAttribute($_builder, reassociation),
1147+
outputShape.second, outputShape.first);
1148+
}]>,
11221149
OpBuilder<(ins "Type":$resultType, "Value":$src,
11231150
"ArrayRef<ReassociationIndices>":$reassociation,
1124-
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
1151+
"ArrayRef<OpFoldResult>":$outputShape),
11251152
[{
1126-
build($_builder, $_state, resultType, src, attrs);
1127-
$_state.addAttribute("reassociation",
1128-
getReassociationIndicesAttribute($_builder, reassociation));
1153+
auto [staticOutputShape, dynamicOutputShape] =
1154+
decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
1155+
build($_builder, $_state, resultType, src,
1156+
getReassociationIndicesAttribute($_builder, reassociation),
1157+
dynamicOutputShape, staticOutputShape);
1158+
}]>,
1159+
1160+
// Builder using ReassociationExprs.
1161+
OpBuilder<(ins "Type":$resultType, "Value":$src,
1162+
"ArrayRef<ReassociationExprs>":$reassociation),
1163+
[{
1164+
auto reassociationIndices =
1165+
convertReassociationMapsToIndices(reassociation);
1166+
build($_builder, $_state, resultType, src, reassociationIndices);
11291167
}]>,
11301168
OpBuilder<(ins "Type":$resultType, "Value":$src,
11311169
"ArrayRef<ReassociationExprs>":$reassociation,
1132-
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
1170+
"ArrayRef<OpFoldResult>":$outputShape),
11331171
[{
1134-
auto reassociationMaps =
1135-
convertReassociationMapsToIndices($_builder, reassociation);
1136-
build($_builder, $_state, resultType, src, reassociationMaps, attrs);
1172+
auto reassociationIndices =
1173+
convertReassociationMapsToIndices(reassociation);
1174+
build($_builder, $_state, resultType, src, reassociationIndices,
1175+
outputShape);
11371176
}]>
11381177
];
11391178

11401179
let extraClassDeclaration = commonExtraClassDeclaration # [{
11411180
int64_t getCorrespondingSourceDim(int64_t resultDim);
1181+
1182+
// Infer the output shape for a tensor.expand_shape when it is possible
1183+
// to do so.
1184+
static LogicalResult inferOutputShape(
1185+
OpBuilder &b, Location loc, RankedTensorType expandedType,
1186+
ArrayRef<ReassociationIndices> reassociation,
1187+
ArrayRef<OpFoldResult> inputShape,
1188+
std::pair<SmallVector<int64_t>, SmallVector<Value>> &outputShape);
11421189
}];
11431190

11441191
let hasVerifier = 1;
11451192
}
11461193

11471194
def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
11481195
let summary = "operation to produce a tensor with a smaller rank";
1196+
let arguments = (ins AnyTensor:$src, IndexListArrayAttr:$reassociation);
11491197
let description = [{
11501198
The `tensor.collapse_shape` op produces a new tensor of lower (or equal)
11511199
rank whose dimension sizes are a reassociation of the original `src` dimensions.
@@ -1163,6 +1211,11 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
11631211
: tensor<?x?x?xf32> into tensor<?x?xf32>
11641212
```
11651213
}];
1214+
1215+
let assemblyFormat = [{
1216+
$src $reassociation attr-dict `:` type($src) `into` type($result)
1217+
}];
1218+
11661219
let builders = [
11671220
// Builders for a contracting reshape whose result type is computed from
11681221
// `src` and `reassociation`.
@@ -1174,7 +1227,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
11741227
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
11751228
[{
11761229
auto reassociationMaps =
1177-
convertReassociationMapsToIndices($_builder, reassociation);
1230+
convertReassociationMapsToIndices(reassociation);
11781231
build($_builder, $_state, src, reassociationMaps, attrs);
11791232
}]>,
11801233

@@ -1192,7 +1245,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
11921245
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
11931246
[{
11941247
auto reassociationMaps =
1195-
convertReassociationMapsToIndices($_builder, reassociation);
1248+
convertReassociationMapsToIndices(reassociation);
11961249
build($_builder, $_state, resultType, src, reassociationMaps, attrs);
11971250
}]>
11981251
];

0 commit comments

Comments
 (0)