Commit 67ef4ae

[MLIR][Tensor,MemRef] Fold expand_shape and collapse_shape if identity (#80658)
Before: op verifiers failed if the input and output ranks were the same (i.e., no expansion or collapse). This behavior required users of these shape ops to verify manually, every time they built one, that they were not creating an identity version of the op -- problematic. This PR removes this strict verification and introduces folders for the identity cases. The PR also removes the special-case handling of rank-0 tensors for expand_shape and collapse_shape; there doesn't seem to be any reason to treat them differently.
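To illustrate the new behavior, here is a minimal sketch (the function name `@identity_expand` is hypothetical; the op forms mirror the tests added in this PR) of an identity `tensor.expand_shape` that now verifies and folds away under `-canonicalize`:

```mlir
// Source and result types are identical, so the new folder replaces %0 with %arg0.
func.func @identity_expand(%arg0 : tensor<5x4xf32>) -> tensor<5x4xf32> {
  %0 = tensor.expand_shape %arg0 [[0], [1]] : tensor<5x4xf32> into tensor<5x4xf32>
  return %0 : tensor<5x4xf32>
}
// After canonicalization the body is just: return %arg0 : tensor<5x4xf32>
```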
1 parent 8d61f82 commit 67ef4ae

9 files changed (+152, -109 lines)


mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 1 addition & 1 deletion

@@ -641,7 +641,7 @@ def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
   let summary = "non-blocking DMA operation that starts a transfer";
   let description = [{
     Syntax:
-
+
     ```
     operation ::= `memref.dma_start` ssa-use`[`ssa-use-list`]` `,`
                   ssa-use`[`ssa-use-list`]` `,` ssa-use `,`

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 15 additions & 24 deletions

@@ -1098,21 +1098,18 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
 def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
   let summary = "operation to produce a tensor with a higher rank";
   let description = [{
-    The `tensor.expand_shape` op produces a new tensor with a higher
-    rank whose sizes are a reassociation of the original `src`.
+    The `tensor.expand_shape` op produces a tensor of higher (or equal)
+    rank than the operand `src` whose dimension sizes are a reassociation of
+    `src`.
 
-    A reassociation is defined as a continuous grouping of dimensions and is
-    represented with an array of DenseI64ArrayAttr attribute.
-
-    The verification rule is that the reassociation maps are applied to the
-    result tensor with the higher rank to obtain the operand tensor with the
-    smaller rank.
+    A reassociation is defined as a continuous grouping of dimensions. It is
+    represented with an array of DenseI64ArrayAttr attribute. Entries in the
+    array are referred to as reassociation maps.
 
-    The operand tensor type of a reshape can be zero-ranked if the result
-    tensor type is statically shaped with all dimensions being unit extent. In
-    such cases the reassociation map is empty.
+    The reassociation maps are applied to the result shape to obtain the operand
+    shape.
 
-    Examples:
+    Example:
 
     ```mlir
     // Dimension expansion i -> (i', j') and (k) -> (k')

@@ -1150,21 +1147,15 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
 def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
   let summary = "operation to produce a tensor with a smaller rank";
   let description = [{
-    The `tensor.collapse_shape` op produces a new tensor with a smaller
-    rank whose sizes are a reassociation of the original `src`.
+    The `tensor.collapse_shape` op produces a new tensor of lower (or equal)
+    rank whose dimension sizes are a reassociation of the original `src` dimensions.
 
     A reassociation is defined as a continuous grouping of dimensions and is
-    represented with an array of DenseI64ArrayAttr attribute.
+    represented by an array of DenseI64ArrayAttr attribute. The reassociation
+    maps are applied to the operand shape to obtain the result shape.
 
-    The verification rule is that the reassociation maps are applied to the
-    operand tensor with the higher rank to obtain the result tensor with the
-    smaller rank.
 
-    The result tensor type of a reshape can be zero-ranked if the operand
-    tensor type is statically shaped with all dimensions being unit extent. In
-    such case the reassociation map is empty.
-
-    Examples:
+    Example:
 
     ```mlir
     // Dimension collapse (i, j) -> i' and k -> k'

@@ -1841,7 +1832,7 @@ def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
     and optionally transposes the tiled source tensor dimensions.
 
     `inner_dims_pos` (mandatory) specifies `k` source tensor dimensions that are
-    being tiled, where `0 < k <= n`. The order of the dimensions matters: 
+    being tiled, where `0 < k <= n`. The order of the dimensions matters:
     - The tiled dimensions (of size `inner_tiles`) are added to the end of the result
       tensor in the order in which they appear in `inner_dims_pos`.
     - `inner_dims_pos[i]` specifies the source tensor dimension tiled by
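As a concrete illustration of the reassociation semantics documented above (this example is not part of the diff; the function name and shapes are chosen for exposition), applying the reassociation maps to the result shape yields the operand shape:

```mlir
// Reassociation [[0, 1], [2]]: result dims 0 and 1 regroup into operand dim 0,
// and result dim 2 corresponds to operand dim 1, so the result shape
// tensor<?x4x5xf32> reassociates back to the operand shape tensor<?x5xf32>.
func.func @expand_example(%src : tensor<?x5xf32>) -> tensor<?x4x5xf32> {
  %0 = tensor.expand_shape %src [[0, 1], [2]] : tensor<?x5xf32> into tensor<?x4x5xf32>
  return %0 : tensor<?x4x5xf32>
}
```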

mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h

Lines changed: 27 additions & 37 deletions

@@ -85,16 +85,21 @@ bool isReassociationValid(ArrayRef<AffineMap> reassociation,
 template <typename ReshapeOpTy, typename InverseReshapeOpTy>
 static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
                                   ArrayRef<Attribute> operands) {
-  // Fold producer-consumer reshape ops that where the operand type of the
+
+  if (reshapeOp.getSrcType() == reshapeOp.getType())
+    return reshapeOp.getSrc();
+
+  // Fold producer-consumer reshape ops where the operand type of the
   // producer is same as the return type of the consumer.
   auto reshapeSrcOp =
       reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
   if (reshapeSrcOp && reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
     return reshapeSrcOp.getSrc();
+
   // Reshape of a constant can be replaced with a new constant.
-  if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front())) {
+  if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front()))
     return elements.reshape(cast<ShapedType>(reshapeOp.getResult().getType()));
-  }
+
   return nullptr;
 }

@@ -103,41 +108,36 @@ static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
 template <typename Op, typename T>
 static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
                                             T collapsedType, bool isExpansion) {
+
   unsigned expandedRank = expandedType.getRank();
   unsigned collapsedRank = collapsedType.getRank();
   if (expandedRank < collapsedRank)
-    return op.emitOpError("expected the type ")
-           << expandedType
-           << " to have higher rank than the type = " << collapsedType;
-  if (expandedRank == 0)
-    return op.emitOpError("expected non-zero memref ranks");
-  if (expandedRank == collapsedRank)
-    return op.emitOpError("expected to collapse or expand dims");
-
-  if (collapsedRank == 0) {
-    // If collapsed rank is 0, then expanded type must be static shaped and of
-    // sizes 1.
-    if (llvm::any_of(expandedType.getShape(),
-                     [](int64_t dim) -> bool { return dim != 1; }))
-      return op.emitOpError("invalid to reshape tensor/memref with non-unit "
-                            "extent dimensions to zero-rank tensor/memref");
-    return success();
-  }
+    return op.emitOpError("expected the expanded type, ")
+           << expandedType << " to have a higher (or same) rank "
+           << "than the collapsed type, " << collapsedType << '.';
+
   if (collapsedRank != op.getReassociation().size())
-    return op.emitOpError("expected rank of the collapsed type(")
-           << collapsedRank << ") to be the number of reassociation maps("
-           << op.getReassociation().size() << ")";
+    return op.emitOpError("expected collapsed rank (")
+           << collapsedRank << ") to equal the number of reassociation maps ("
+           << op.getReassociation().size() << ").";
+
   auto maps = op.getReassociationMaps();
   for (auto it : llvm::enumerate(maps))
     if (it.value().getNumDims() != expandedRank)
       return op.emitOpError("expected reassociation map #")
-             << it.index() << " of same rank as expanded memref("
-             << expandedRank << "), but got " << it.value().getNumDims();
+             << it.index() << " to have size equal to the expanded rank ("
+             << expandedRank << "), but it is " << it.value().getNumDims()
+             << '.';
+
   int invalidIdx = 0;
   if (!isReassociationValid(maps, &invalidIdx))
     return op.emitOpError("expected reassociation map #")
-           << invalidIdx << " to be valid and contiguous";
-  return verifyReshapeLikeShapes(op, collapsedType, expandedType, isExpansion);
+           << invalidIdx << " to be valid and contiguous.";
+
+  return reshapeLikeShapesAreCompatible(
+      [&](const Twine &msg) { return op->emitOpError(msg); },
+      collapsedType.getShape(), expandedType.getShape(),
+      op.getReassociationIndices(), isExpansion);
 }

@@ -153,16 +153,6 @@ LogicalResult reshapeLikeShapesAreCompatible(
     ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
     ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape);
 
-template <typename OpTy>
-static LogicalResult verifyReshapeLikeShapes(OpTy op, ShapedType collapsedType,
-                                             ShapedType expandedType,
-                                             bool isExpandingReshape) {
-  return reshapeLikeShapesAreCompatible(
-      [&](const Twine &msg) { return op->emitOpError(msg); },
-      collapsedType.getShape(), expandedType.getShape(),
-      op.getReassociationIndices(), isExpandingReshape);
-}
-
 /// Returns true iff the type is a MemRefType and has a non-identity layout.
 bool hasNonIdentityLayout(Type type);

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 14 additions & 6 deletions

@@ -2224,9 +2224,13 @@ LogicalResult ExpandShapeOp::verify() {
   MemRefType srcType = getSrcType();
   MemRefType resultType = getResultType();
 
-  if (srcType.getRank() >= resultType.getRank())
-    return emitOpError("expected rank expansion, but found source rank ")
-           << srcType.getRank() << " >= result rank " << resultType.getRank();
+  if (srcType.getRank() > resultType.getRank()) {
+    auto r0 = srcType.getRank();
+    auto r1 = resultType.getRank();
+    return emitOpError("has source rank ")
+           << r0 << " and result rank " << r1 << ". This is not an expansion ("
+           << r0 << " > " << r1 << ").";
+  }
 
   // Verify result shape.
   if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),

@@ -2378,9 +2382,13 @@ LogicalResult CollapseShapeOp::verify() {
   MemRefType srcType = getSrcType();
   MemRefType resultType = getResultType();
 
-  if (srcType.getRank() <= resultType.getRank())
-    return emitOpError("expected rank reduction, but found source rank ")
-           << srcType.getRank() << " <= result rank " << resultType.getRank();
+  if (srcType.getRank() < resultType.getRank()) {
+    auto r0 = srcType.getRank();
+    auto r1 = resultType.getRank();
+    return emitOpError("has source rank ")
+           << r0 << " and result rank " << r1 << ". This is not a collapse ("
+           << r0 << " < " << r1 << ").";
+  }
 
   // Verify result shape.
   if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 0 additions & 12 deletions

@@ -1656,22 +1656,10 @@ static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
 }
 
 LogicalResult ExpandShapeOp::verify() {
-  auto srcType = getSrcType();
-  auto resultType = getResultType();
-  if (srcType.getRank() >= resultType.getRank())
-    return emitOpError("expected rank expansion, but found source rank ")
-           << srcType.getRank() << " >= result rank " << resultType.getRank();
-
   return verifyTensorReshapeOp(*this, getResultType(), getSrcType());
 }
 
 LogicalResult CollapseShapeOp::verify() {
-  auto srcType = getSrcType();
-  auto resultType = getResultType();
-  if (srcType.getRank() <= resultType.getRank())
-    return emitOpError("expected rank reduction, but found source rank ")
-           << srcType.getRank() << " <= result rank " << resultType.getRank();
-
   return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
 }

mlir/test/Dialect/MemRef/canonicalize.mlir

Lines changed: 29 additions & 0 deletions

@@ -1,5 +1,34 @@
 // RUN: mlir-opt %s -canonicalize="test-convergence" --split-input-file -allow-unregistered-dialect | FileCheck %s
 
+
+// CHECK-LABEL: collapse_shape_identity_fold
+// CHECK-NEXT: return
+func.func @collapse_shape_identity_fold(%arg0 : memref<5xi8>) -> memref<5xi8> {
+  %0 = memref.collapse_shape %arg0 [[0]] : memref<5xi8> into memref<5xi8>
+  return %0 : memref<5xi8>
+}
+
+// -----
+
+// CHECK-LABEL: expand_shape_identity_fold
+// CHECK-NEXT: return
+func.func @expand_shape_identity_fold(%arg0 : memref<5x4xi8>) -> memref<5x4xi8> {
+  %0 = memref.expand_shape %arg0 [[0], [1]] : memref<5x4xi8> into memref<5x4xi8>
+  return %0 : memref<5x4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: collapse_expand_rank0_cancel
+// CHECK-NEXT: return
+func.func @collapse_expand_rank0_cancel(%arg0 : memref<1x1xi8>) -> memref<1x1xi8> {
+  %0 = memref.collapse_shape %arg0 [] : memref<1x1xi8> into memref<i8>
+  %1 = memref.expand_shape %0 [] : memref<i8> into memref<1x1xi8>
+  return %1 : memref<1x1xi8>
+}
+
+// -----
+
 // CHECK-LABEL: func @subview_of_size_memcast
 //  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
 //       CHECK:   %[[S:.+]] = memref.subview %[[ARG0]][0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, strided{{.*}}>

mlir/test/Dialect/MemRef/invalid.mlir

Lines changed: 28 additions & 14 deletions

@@ -415,20 +415,6 @@ func.func @collapse_shape_out_of_bounds(%arg0: memref<?x?xf32>) {
 
 // -----
 
-func.func @expand_shape_invalid_ranks(%arg0: memref<?x?xf32>) {
-  // expected-error @+1 {{op expected rank expansion, but found source rank 2 >= result rank 2}}
-  %0 = memref.expand_shape %arg0 [[0], [1]] : memref<?x?xf32> into memref<?x?xf32>
-}
-
-// -----
-
-func.func @collapse_shape_invalid_ranks(%arg0: memref<?x?xf32>) {
-  // expected-error @+1 {{op expected rank reduction, but found source rank 2 <= result rank 2}}
-  %0 = memref.collapse_shape %arg0 [[0], [1]] : memref<?x?xf32> into memref<?x?xf32>
-}
-
-// -----
-
 func.func @expand_shape_out_of_bounds(%arg0: memref<?xf32>) {
   // expected-error @+1 {{op reassociation index 2 is out of bounds}}
   %0 = memref.expand_shape %arg0 [[0, 1, 2]] : memref<?xf32> into memref<4x?xf32>

@@ -462,6 +448,34 @@ func.func @collapse_shape_invalid_reassociation(%arg0: memref<?x?x?xf32>) {
 
 // -----
 
+// An (invalid) attempt at using collapse_shape to increase the rank might look
+// like this. Verify that a sensible error is emitted in this case.
+func.func @collapse_shape_invalid_reassociation_expansion(%arg0: memref<?xf32>) {
+  // expected-error @+1 {{'memref.collapse_shape' op has source rank 1 and result rank 2. This is not a collapse (1 < 2)}}
+  %0 = memref.collapse_shape %arg0 [[0], [0]] :
+    memref<?xf32> into memref<?x?xf32>
+}
+
+// -----
+
+// An (invalid) attempt at using expand_shape to reduce the rank might look
+// like this. Verify that a sensible error is emitted in this case.
+func.func @expand_shape_invalid_reassociation(%arg0: memref<2x3x1xf32>) {
+  // expected-error @+1 {{'memref.expand_shape' op has source rank 3 and result rank 2. This is not an expansion (3 > 2)}}
+  %0 = memref.expand_shape %arg0 [[0], [1], [1]] :
+    memref<2x3x1xf32> into memref<2x3xf32>
+}
+
+// -----
+
+func.func @collapse_shape_invalid_reassociation_expansion(%arg0: memref<?x?xf32>) {
+  // expected-error @+1 {{reassociation indices must be contiguous}}
+  %0 = memref.collapse_shape %arg0 [[1], [0]] :
+    memref<?x?xf32> into memref<?x?xf32>
+}
+
+// -----
+
 func.func @collapse_shape_reshaping_non_contiguous(
     %arg0: memref<3x4x5xf32, strided<[270, 50, 10], offset: 0>>) {
   // expected-error @+1 {{invalid source layout map or collapsing non-contiguous dims}}

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 38 additions & 1 deletion

@@ -1,5 +1,42 @@
 // RUN: mlir-opt %s -split-input-file -canonicalize="test-convergence" | FileCheck %s
 
+
+// CHECK-LABEL: expand_shape_identity_fold
+// CHECK-NEXT: return
+func.func @expand_shape_identity_fold(%arg0 : tensor<5xf32>) -> tensor<5xf32> {
+  %0 = tensor.expand_shape %arg0 [[0]] : tensor<5xf32> into tensor<5xf32>
+  return %0 : tensor<5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: expand_shape_rank0_identity_fold
+// CHECK-NEXT: return
+func.func @expand_shape_rank0_identity_fold(%arg0 : tensor<f32>) -> tensor<f32> {
+  %0 = tensor.expand_shape %arg0 [] : tensor<f32> into tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: collapse_shape_identity_fold
+// CHECK-NEXT: return
+func.func @collapse_shape_identity_fold(%arg0 : tensor<5x4xf32>) -> tensor<5x4xf32> {
+  %0 = tensor.collapse_shape %arg0 [[0], [1]] : tensor<5x4xf32> into tensor<5x4xf32>
+  return %0 : tensor<5x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: collapse_shape_rank0_identity_fold
+// CHECK-NEXT: return
+func.func @collapse_shape_rank0_identity_fold(%arg0 : tensor<f32>) -> tensor<f32> {
+  %0 = tensor.collapse_shape %arg0 [] : tensor<f32> into tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
 // CHECK-LABEL: @tensor_bitcast_chain_ok
 // CHECK-SAME: %[[IN:.*]]: tensor<2xi32>
 func.func @tensor_bitcast_chain_ok(%input: tensor<2xi32>) -> tensor<2xf32> {

@@ -2092,7 +2129,7 @@ func.func @unpack_pack(%t: tensor<128x128xf32>) -> tensor<128x128xf32> {
 
 // Chain: NC -> NCnc -> NCnc -> NC
 // CHECK: func.func @unpack_pack(
-// CHECK-SAME: %[[T:.+]]: tensor<128x128xf32>, 
+// CHECK-SAME: %[[T:.+]]: tensor<128x128xf32>,
 // CHECK: return %[[T]] : tensor<128x128xf32>
 func.func @unpack_pack(%t: tensor<128x128xf32>, %tile1: index, %tile2: index) -> tensor<128x128xf32> {
   %tensor_empty = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32>

mlir/test/Dialect/Tensor/invalid.mlir

Lines changed: 0 additions & 14 deletions

@@ -343,20 +343,6 @@ func.func @illegal_collapsing_reshape_mixed_tensor_2(%arg0 : tensor<?x4x5xf32>)
 
 // -----
 
-func.func @expand_shape_invalid_ranks(%arg0: tensor<?x?xf32>) {
-  // expected-error @+1 {{op expected rank expansion, but found source rank 2 >= result rank 2}}
-  %0 = tensor.expand_shape %arg0 [[0], [1]] : tensor<?x?xf32> into tensor<?x?xf32>
-}
-
-// -----
-
-func.func @collapse_shape_invalid_ranks(%arg0: tensor<?x?xf32>) {
-  // expected-error @+1 {{op expected rank reduction, but found source rank 2 <= result rank 2}}
-  %0 = tensor.collapse_shape %arg0 [[0], [1]] : tensor<?x?xf32> into tensor<?x?xf32>
-}
-
-// -----
-
 func.func @rank(%0: f32) {
   // expected-error@+1 {{'tensor.rank' op operand #0 must be tensor of any type values}}
   "tensor.rank"(%0): (f32)->index
