Skip to content

Commit 97069a8

Browse files
[MLIR] Generalize expand_shape to take shape as explicit input (#90040)
This patch generalizes tensor.expand_shape and memref.expand_shape to consume the output shape as a list of SSA values. This enables us to implement generic reshape operations with dynamic shapes using collapse_shape/expand_shape pairs. The output_shape input to expand_shape follows the static/dynamic representation that's also used in `tensor.extract_slice`. Differential Revision: https://reviews.llvm.org/D140821 --------- Signed-off-by: Gaurav Shukla<[email protected]> Signed-off-by: Gaurav Shukla <[email protected]> Co-authored-by: Ramiro Leal-Cavazos <[email protected]>
1 parent 539f626 commit 97069a8

File tree

55 files changed

+1214
-633
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+1214
-633
lines changed

mlir/include/mlir/Dialect/Arith/Utils/Utils.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,29 @@
2424

2525
namespace mlir {
2626

27+
using ReassociationIndices = SmallVector<int64_t, 2>;
28+
29+
/// Infer the output shape for a {memref|tensor}.expand_shape when it is
30+
/// possible to do so.
31+
///
32+
/// Note: This should *only* be used to implement
33+
/// `ExpandShapeOp::inferOutputShape` in both the memref and tensor namespaces.
34+
/// If you need to infer the output shape you should use the static method of
35+
/// `ExpandShapeOp` instead of calling this.
36+
///
37+
/// `inputShape` is the shape of the tensor or memref being expanded as a
38+
/// sequence of SSA values or constants. `expandedType` is the output shape of
39+
/// the expand_shape operation. `reassociation` is the reassociation denoting
40+
/// the output dims each input dim is mapped to.
41+
///
42+
/// Returns the output shape in `outputShape` and `staticOutputShape`, following
43+
/// the conventions for the output_shape and static_output_shape inputs to the
44+
/// expand_shape ops.
45+
std::optional<SmallVector<OpFoldResult>>
46+
inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType,
47+
ArrayRef<ReassociationIndices> reassociation,
48+
ArrayRef<OpFoldResult> inputShape);
49+
2750
/// Matches a ConstantIndexOp.
2851
detail::op_matcher<arith::ConstantIndexOp> matchConstantIndex();
2952

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 58 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1548,7 +1548,6 @@ def MemRef_ReshapeOp: MemRef_Op<"reshape", [
15481548
class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
15491549
MemRef_Op<mnemonic, !listconcat(traits,
15501550
[Pure, ViewLikeOpInterface])>,
1551-
Arguments<(ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation)>,
15521551
Results<(outs AnyStridedMemRef:$result)>{
15531552

15541553
code commonExtraClassDeclaration = [{
@@ -1573,10 +1572,6 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
15731572
Value getViewSource() { return getSrc(); }
15741573
}];
15751574

1576-
let assemblyFormat = [{
1577-
$src $reassociation attr-dict `:` type($src) `into` type($result)
1578-
}];
1579-
15801575
let hasFolder = 1;
15811576
let hasCanonicalizer = 1;
15821577
let hasVerifier = 1;
@@ -1598,14 +1593,10 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
15981593
Example:
15991594

16001595
```mlir
1601-
%r = memref.expand_shape %0 [[0, 1], [2]]
1602-
: memref<?x?xf32> into memref<?x5x?xf32>
1596+
%r = memref.expand_shape %0 [[0, 1], [2]] output_shape [%sz0, %sz1, 32]
1597+
: memref<?x32xf32> into memref<?x?x32xf32>
16031598
```
16041599

1605-
At most one dimension of a reassociation group (e.g., [0, 1] above) may be
1606-
dynamic in the result type. Otherwise, the op would be ambiguous, as it
1607-
would not be clear how the source dimension is extended.
1608-
16091600
If an op can be statically proven to be invalid (e.g, an expansion from
16101601
`memref<10xf32>` to `memref<2x6xf32>`), it is rejected by the verifier. If
16111602
it cannot statically be proven invalid (e.g., the full example above; it is
@@ -1622,41 +1613,80 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
16221613
there must be a dynamic result dimension in the corresponding reassociation
16231614
group. Same for strides.
16241615

1616+
The representation for the output shape supports a partially-static
1617+
specification via attributes specified through the `static_output_shape`
1618+
argument. A special sentinel value `ShapedType::kDynamic` encodes that the
1619+
corresponding entry has a dynamic value. There must be exactly as many SSA
1620+
inputs in `output_shape` as there are `ShapedType::kDynamic` entries in
1621+
`static_output_shape`.
1622+
16251623
Note: This op currently assumes that the inner strides are of the
16261624
source/result layout map are the faster-varying ones.
16271625
}];
16281626

1627+
let arguments = (ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation,
1628+
Variadic<Index>:$output_shape,
1629+
DenseI64ArrayAttr:$static_output_shape);
1630+
1631+
let assemblyFormat = [{
1632+
$src $reassociation `output_shape`
1633+
custom<DynamicIndexList>($output_shape, $static_output_shape) attr-dict `:`
1634+
type($src) `into` type($result)
1635+
}];
1636+
16291637
let builders = [
16301638
// Builders using ReassociationIndices.
16311639
OpBuilder<(ins "Type":$resultType, "Value":$src,
16321640
"ArrayRef<ReassociationIndices>":$reassociation,
1633-
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
1641+
"ArrayRef<OpFoldResult>":$outputShape)>,
1642+
1643+
// It will infer output shape using inferOutputShape() method.
1644+
OpBuilder<(ins "Type":$resultType, "Value":$src,
1645+
"ArrayRef<ReassociationIndices>":$reassociation)>,
1646+
1647+
// Builder using ReassociationExprs.
1648+
OpBuilder<(ins "Type":$resultType, "Value":$src,
1649+
"ArrayRef<ReassociationExprs>":$reassociation),
16341650
[{
1635-
build($_builder, $_state, resultType, src, attrs);
1636-
$_state.addAttribute("reassociation",
1637-
getReassociationIndicesAttribute($_builder, reassociation));
1651+
auto reassociationIndices =
1652+
convertReassociationMapsToIndices(reassociation);
1653+
build($_builder, $_state, resultType, src, reassociationIndices);
16381654
}]>,
16391655

1640-
// Builder using ReassociationExprs.
16411656
OpBuilder<(ins "Type":$resultType, "Value":$src,
16421657
"ArrayRef<ReassociationExprs>":$reassociation,
1643-
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
1658+
"ArrayRef<OpFoldResult>":$outputShape),
16441659
[{
16451660
auto reassociationMaps =
1646-
convertReassociationMapsToIndices($_builder, reassociation);
1647-
build($_builder, $_state, resultType, src, reassociationMaps, attrs);
1661+
convertReassociationMapsToIndices(reassociation);
1662+
build($_builder, $_state, resultType, src, reassociationMaps,
1663+
outputShape);
16481664
}]>,
16491665

1666+
// Builder that infers the result layout map. The result shape must be
1667+
// specified. Otherwise, the op may be ambiguous. The output shape for
1668+
// the op will be inferred using the inferOutputShape() method.
1669+
OpBuilder<(ins "ArrayRef<int64_t>":$resultShape, "Value":$src,
1670+
"ArrayRef<ReassociationIndices>":$reassociation)>,
1671+
16501672
// Builder that infers the result layout map. The result shape must be
16511673
// specified. Otherwise, the op may be ambiguous.
16521674
OpBuilder<(ins "ArrayRef<int64_t>":$resultShape, "Value":$src,
1653-
"ArrayRef<ReassociationIndices>":$reassociation)>
1675+
"ArrayRef<ReassociationIndices>":$reassociation,
1676+
"ArrayRef<OpFoldResult>":$outputShape)>
16541677
];
16551678

16561679
let extraClassDeclaration = commonExtraClassDeclaration # [{
16571680
static FailureOr<MemRefType> computeExpandedType(
16581681
MemRefType srcType, ArrayRef<int64_t> resultShape,
16591682
ArrayRef<ReassociationIndices> reassociation);
1683+
1684+
// Infer the output shape for a memref.expand_shape when it is possible
1685+
// to do so.
1686+
static FailureOr<SmallVector<OpFoldResult>> inferOutputShape(
1687+
OpBuilder &b, Location loc, MemRefType expandedType,
1688+
ArrayRef<ReassociationIndices> reassociation,
1689+
ArrayRef<OpFoldResult> inputShape);
16601690
}];
16611691

16621692
let hasVerifier = 1;
@@ -1707,6 +1737,12 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
17071737
source/result layout map are the faster-varying ones.
17081738
}];
17091739

1740+
let arguments = (ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation);
1741+
1742+
let assemblyFormat = [{
1743+
$src $reassociation attr-dict `:` type($src) `into` type($result)
1744+
}];
1745+
17101746
let builders = [
17111747
// Builders for a contracting reshape whose result type is computed from
17121748
// `src` and `reassociation`.
@@ -1718,7 +1754,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
17181754
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
17191755
[{
17201756
auto reassociationMaps =
1721-
convertReassociationMapsToIndices($_builder, reassociation);
1757+
convertReassociationMapsToIndices(reassociation);
17221758
build($_builder, $_state, src, reassociationMaps, attrs);
17231759
}]>,
17241760

@@ -1736,7 +1772,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
17361772
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
17371773
[{
17381774
auto reassociationMaps =
1739-
convertReassociationMapsToIndices($_builder, reassociation);
1775+
convertReassociationMapsToIndices(reassociation);
17401776
build($_builder, $_state, resultType, src, reassociationMaps, attrs);
17411777
}]>
17421778
];

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 56 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1062,8 +1062,7 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
10621062
Tensor_Op<mnemonic, !listconcat(traits, [
10631063
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
10641064
Pure])>,
1065-
Arguments<(ins AnyRankedTensor:$src, IndexListArrayAttr:$reassociation)>,
1066-
Results<(outs AnyRankedTensor:$result)> {
1065+
Results<(outs AnyTensor:$result)> {
10671066

10681067
code commonExtraClassDeclaration = [{
10691068
static StringRef getReassociationAttrStrName() { return "reassociation"; }
@@ -1086,10 +1085,6 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
10861085
}
10871086
}];
10881087

1089-
let assemblyFormat = [{
1090-
$src $reassociation attr-dict `:` type($src) `into` type($result)
1091-
}];
1092-
10931088
let hasFolder = 1;
10941089
let hasCanonicalizer = 1;
10951090
let hasVerifier = 1;
@@ -1102,50 +1097,83 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
11021097
rank than the operand `src` whose dimension sizes are a reassociation of
11031098
`src`.
11041099

1105-
A reassociation is defined as a continuous grouping of dimensions. It is
1106-
represented with an array of DenseI64ArrayAttr attribute. Entries in the
1107-
array are referred to as reassociation maps.
1100+
A reassociation is defined as a continuous grouping of dimensions and is
1101+
represented with an array of DenseI64ArrayAttr attribute. The reassociation
1102+
maps applied to the result tensor with the higher rank must result in the
1103+
operand tensor with the smaller rank.
11081104

1109-
The reassociation maps are applied to the result shape to obtain the operand
1110-
shape.
1105+
The representation for the output shape supports a partially-static
1106+
specification via attributes specified through the `static_output_shape`
1107+
argument. A special sentinel value `ShapedType::kDynamic` encodes that the
1108+
corresponding entry has a dynamic value. There must be exactly as many SSA
1109+
inputs in `output_shape` as there are `ShapedType::kDynamic` entries in
1110+
`static_output_shape`.
11111111

11121112
Example:
11131113

11141114
```mlir
11151115
// Dimension expansion i -> (i', j') and (k) -> (k')
1116-
%b = tensor.expand_shape %a [[0, 1], [2]]
1117-
: tensor<?x?xf32> into tensor<?x?x?xf32>
1116+
%b = tensor.expand_shape %a [[0, 1], [2]] output_shape [%sz0, %sz1, 32]
1117+
: tensor<?x32xf32> into tensor<?x?x32xf32>
11181118
```
11191119
}];
1120+
1121+
let arguments = (ins AnyTensor:$src, IndexListArrayAttr:$reassociation,
1122+
Variadic<Index>:$output_shape,
1123+
DenseI64ArrayAttr:$static_output_shape);
1124+
1125+
let assemblyFormat = [{
1126+
$src $reassociation `output_shape`
1127+
custom<DynamicIndexList>($output_shape, $static_output_shape) attr-dict `:`
1128+
type($src) `into` type($result)
1129+
}];
1130+
11201131
let builders = [
11211132
// Builders using ReassociationIndices.
11221133
OpBuilder<(ins "Type":$resultType, "Value":$src,
11231134
"ArrayRef<ReassociationIndices>":$reassociation,
1124-
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
1135+
"ArrayRef<OpFoldResult>":$outputShape)>,
1136+
1137+
// It will infer output shape using inferOutputShape() method.
1138+
OpBuilder<(ins "Type":$resultType, "Value":$src,
1139+
"ArrayRef<ReassociationIndices>":$reassociation)>,
1140+
1141+
// Builder using ReassociationExprs.
1142+
OpBuilder<(ins "Type":$resultType, "Value":$src,
1143+
"ArrayRef<ReassociationExprs>":$reassociation),
11251144
[{
1126-
build($_builder, $_state, resultType, src, attrs);
1127-
$_state.addAttribute("reassociation",
1128-
getReassociationIndicesAttribute($_builder, reassociation));
1145+
auto reassociationIndices =
1146+
convertReassociationMapsToIndices(reassociation);
1147+
build($_builder, $_state, resultType, src, reassociationIndices);
11291148
}]>,
11301149
OpBuilder<(ins "Type":$resultType, "Value":$src,
11311150
"ArrayRef<ReassociationExprs>":$reassociation,
1132-
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
1151+
"ArrayRef<OpFoldResult>":$outputShape),
11331152
[{
1134-
auto reassociationMaps =
1135-
convertReassociationMapsToIndices($_builder, reassociation);
1136-
build($_builder, $_state, resultType, src, reassociationMaps, attrs);
1153+
auto reassociationIndices =
1154+
convertReassociationMapsToIndices(reassociation);
1155+
build($_builder, $_state, resultType, src, reassociationIndices,
1156+
outputShape);
11371157
}]>
11381158
];
11391159

11401160
let extraClassDeclaration = commonExtraClassDeclaration # [{
11411161
int64_t getCorrespondingSourceDim(int64_t resultDim);
1162+
1163+
// Infer the output shape for a tensor.expand_shape when it is possible
1164+
// to do so.
1165+
static FailureOr<SmallVector<OpFoldResult>> inferOutputShape(
1166+
OpBuilder &b, Location loc, RankedTensorType expandedType,
1167+
ArrayRef<ReassociationIndices> reassociation,
1168+
ArrayRef<OpFoldResult> inputShape);
11421169
}];
11431170

11441171
let hasVerifier = 1;
11451172
}
11461173

11471174
def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
11481175
let summary = "operation to produce a tensor with a smaller rank";
1176+
let arguments = (ins AnyTensor:$src, IndexListArrayAttr:$reassociation);
11491177
let description = [{
11501178
The `tensor.collapse_shape` op produces a new tensor of lower (or equal)
11511179
rank whose dimension sizes are a reassociation of the original `src` dimensions.
@@ -1163,6 +1191,11 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
11631191
: tensor<?x?x?xf32> into tensor<?x?xf32>
11641192
```
11651193
}];
1194+
1195+
let assemblyFormat = [{
1196+
$src $reassociation attr-dict `:` type($src) `into` type($result)
1197+
}];
1198+
11661199
let builders = [
11671200
// Builders for a contracting reshape whose result type is computed from
11681201
// `src` and `reassociation`.
@@ -1174,7 +1207,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
11741207
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
11751208
[{
11761209
auto reassociationMaps =
1177-
convertReassociationMapsToIndices($_builder, reassociation);
1210+
convertReassociationMapsToIndices(reassociation);
11781211
build($_builder, $_state, src, reassociationMaps, attrs);
11791212
}]>,
11801213

@@ -1192,7 +1225,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
11921225
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
11931226
[{
11941227
auto reassociationMaps =
1195-
convertReassociationMapsToIndices($_builder, reassociation);
1228+
convertReassociationMapsToIndices(reassociation);
11961229
build($_builder, $_state, resultType, src, reassociationMaps, attrs);
11971230
}]>
11981231
];

0 commit comments

Comments
 (0)