
Commit 75ddea2

updates. additional testing, doc corrections, remove rank-0 special case handling
Parent: 7906997

8 files changed (+100, -68 lines)


mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 1 addition & 1 deletion

@@ -641,7 +641,7 @@ def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
   let summary = "non-blocking DMA operation that starts a transfer";
   let description = [{
     Syntax:
-    
+
     ```
     operation ::= `memref.dma_start` ssa-use`[`ssa-use-list`]` `,`
                   ssa-use`[`ssa-use-list`]` `,` ssa-use `,`
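
For reference, the grammar above corresponds to IR like the following hedged sketch (shapes, memory spaces, and SSA names are illustrative, not taken from this commit):

```mlir
// Start a non-blocking DMA of %num elements from %src to %dst; the tag
// location %tag[%idx] is written on completion and is later polled by a
// matching memref.dma_wait.
memref.dma_start %src[%i, %j], %dst[%k, %l], %num, %tag[%idx]
    : memref<40x128xf32>, memref<2x1024xf32, 1>, memref<1xi32>
```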

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 15 additions & 24 deletions

@@ -1098,21 +1098,18 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
 def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
   let summary = "operation to produce a tensor with a higher rank";
   let description = [{
-    The `tensor.expand_shape` op produces a new tensor with a higher
-    rank whose sizes are a reassociation of the original `src`.
+    The `tensor.expand_shape` op produces a tensor of higher (or equal)
+    rank than the operand `src` whose dimension sizes are a reassociation of
+    `src`.
 
-    A reassociation is defined as a continuous grouping of dimensions and is
-    represented with an array of DenseI64ArrayAttr attribute.
-
-    The verification rule is that the reassociation maps are applied to the
-    result tensor with the higher rank to obtain the operand tensor with the
-    smaller rank.
+    A reassociation is defined as a continuous grouping of dimensions. It is
+    represented with an array of DenseI64ArrayAttr attribute. Entries in the
+    array are referred to as reassociation maps.
 
-    The operand tensor type of a reshape can be zero-ranked if the result
-    tensor type is statically shaped with all dimensions being unit extent. In
-    such cases the reassociation map is empty.
+    The reassociation maps are applied to the result shape to obtain the operand
+    shape.
 
-    Examples:
+    Example:
 
     ```mlir
     // Dimension expansion i -> (i', j') and (k) -> (k')
@@ -1150,21 +1147,15 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
 def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
   let summary = "operation to produce a tensor with a smaller rank";
   let description = [{
-    The `tensor.collapse_shape` op produces a new tensor with a smaller
-    rank whose sizes are a reassociation of the original `src`.
+    The `tensor.collapse_shape` op produces a new tensor of lower (or equal)
+    rank whose dimension sizes are a reassociation of the original `src` dimensions.
 
     A reassociation is defined as a continuous grouping of dimensions and is
-    represented with an array of DenseI64ArrayAttr attribute.
+    represented by an array of DenseI64ArrayAttr attribute. The reassociation
+    maps are applied to the operand shape to obtain the result shape.
 
-    The verification rule is that the reassociation maps are applied to the
-    operand tensor with the higher rank to obtain the result tensor with the
-    smaller rank.
 
-    The result tensor type of a reshape can be zero-ranked if the operand
-    tensor type is statically shaped with all dimensions being unit extent. In
-    such case the reassociation map is empty.
-
-    Examples:
+    Example:
 
     ```mlir
     // Dimension collapse (i, j) -> i' and k -> k'
@@ -1841,7 +1832,7 @@ def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
     and optionally transposes the tiled source tensor dimensions.
 
     `inner_dims_pos` (mandatory) specifies `k` source tensor dimensions that are
-    being tiled, where `0 < k <= n`.  The order of the dimensions matters:
+    being tiled, where `0 < k <= n`. The order of the dimensions matters:
     - The tiled dimensions (of size `inner_tiles`) are added to the end of the result
       tensor in the order in which they appear in `inner_dims_pos`.
     - `inner_dims_pos[i]` specifies the source tensor dimension tiled by
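
To make the reassociation notation in the revised docs concrete, here is a hedged sketch (illustrative shapes and names, not part of this commit). Each reassociation map lists the contiguous dimensions of the higher-rank type that correspond to one dimension of the lower-rank type:

```mlir
// [[0, 1], [2]]: result dims (0, 1) reassociate to source dim 0, and
// result dim (2) to source dim 1, so 12 = 3 * 4 and 32 carries over.
%e = tensor.expand_shape %t [[0, 1], [2]]
    : tensor<12x32xf32> into tensor<3x4x32xf32>

// The same maps applied in the opposite direction collapse the shape back.
%c = tensor.collapse_shape %e [[0, 1], [2]]
    : tensor<3x4x32xf32> into tensor<12x32xf32>
```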

mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h

Lines changed: 28 additions & 35 deletions

@@ -85,16 +85,21 @@ bool isReassociationValid(ArrayRef<AffineMap> reassociation,
 template <typename ReshapeOpTy, typename InverseReshapeOpTy>
 static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
                                   ArrayRef<Attribute> operands) {
-  // Fold producer-consumer reshape ops that where the operand type of the
+
+  if (reshapeOp.getSrcType() == reshapeOp.getType())
+    return reshapeOp.getSrc();
+
+  // Fold producer-consumer reshape ops where the operand type of the
   // producer is same as the return type of the consumer.
   auto reshapeSrcOp =
       reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
   if (reshapeSrcOp && reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
     return reshapeSrcOp.getSrc();
+
   // Reshape of a constant can be replaced with a new constant.
-  if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front())) {
+  if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front()))
     return elements.reshape(cast<ShapedType>(reshapeOp.getResult().getType()));
-  }
+
   return nullptr;
 }
 
@@ -103,39 +108,37 @@ static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
 template <typename Op, typename T>
 static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
                                             T collapsedType, bool isExpansion) {
+
   unsigned expandedRank = expandedType.getRank();
   unsigned collapsedRank = collapsedType.getRank();
   if (expandedRank < collapsedRank)
-    return op.emitOpError("expected the type ")
-           << expandedType
-           << " to have higher rank than the type = " << collapsedType;
-  if (expandedRank == 0)
-    return op.emitOpError("expected non-zero memref ranks");
-
-  if (collapsedRank == 0) {
-    // If collapsed rank is 0, then expanded type must be static shaped and of
-    // sizes 1.
-    if (llvm::any_of(expandedType.getShape(),
-                     [](int64_t dim) -> bool { return dim != 1; }))
-      return op.emitOpError("invalid to reshape tensor/memref with non-unit "
-                            "extent dimensions to zero-rank tensor/memref");
-    return success();
-  }
+    return op.emitOpError("expected the expanded type, ")
+           << expandedType << " to have a higher (or same) rank "
+           << "than the collapsed type, " << collapsedType << '.';
+
   if (collapsedRank != op.getReassociation().size())
-    return op.emitOpError("expected rank of the collapsed type(")
-           << collapsedRank << ") to be the number of reassociation maps("
-           << op.getReassociation().size() << ")";
+    return op.emitOpError("expected collapsed rank (")
+           << collapsedRank << ") to equal the number of reassociation maps ("
+           << op.getReassociation().size() << ").";
+
   auto maps = op.getReassociationMaps();
   for (auto it : llvm::enumerate(maps))
     if (it.value().getNumDims() != expandedRank)
       return op.emitOpError("expected reassociation map #")
-             << it.index() << " of same rank as expanded memref("
-             << expandedRank << "), but got " << it.value().getNumDims();
+             << it.index() << " to have size equal to the expanded rank ("
+             << expandedRank << "), but it is " << it.value().getNumDims()
+             << '.';
+
   int invalidIdx = 0;
   if (!isReassociationValid(maps, &invalidIdx))
     return op.emitOpError("expected reassociation map #")
-           << invalidIdx << " to be valid and contiguous";
-  return verifyReshapeLikeShapes(op, collapsedType, expandedType, isExpansion);
+           << invalidIdx << " to be valid and contiguous.";
+
+  return reshapeLikeShapesAreCompatible(
+      [&](const Twine &msg) { return op->emitOpError(msg); },
+      collapsedType.getShape(), expandedType.getShape(),
+      op.getReassociationIndices(), isExpansion);
+
 }
 
 /// Verify that shapes of the reshaped types using following rules
@@ -151,16 +154,6 @@ LogicalResult reshapeLikeShapesAreCompatible(
     ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
     ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape);
 
-template <typename OpTy>
-static LogicalResult verifyReshapeLikeShapes(OpTy op, ShapedType collapsedType,
-                                             ShapedType expandedType,
-                                             bool isExpandingReshape) {
-  return reshapeLikeShapesAreCompatible(
-      [&](const Twine &msg) { return op->emitOpError(msg); },
-      collapsedType.getShape(), expandedType.getShape(),
-      op.getReassociationIndices(), isExpandingReshape);
-}
-
 /// Returns true iff the type is a MemRefType and has a non-identity layout.
 bool hasNonIdentityLayout(Type type);
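
With the identity check hoisted into `foldReshapeOp`, the helper now covers three cases uniformly, including the rank-0 reshapes that previously needed special casing in each dialect. A hedged MLIR sketch of the IR each case matches (illustrative, not taken from the patch):

```mlir
// 1. Identity reshape: source and result types are equal; folds to %t.
%0 = tensor.collapse_shape %t [] : tensor<f32> into tensor<f32>

// 2. Inverse producer-consumer pair: the expand undoes the collapse;
//    %2 folds to %arg.
%1 = tensor.collapse_shape %arg [[0, 1]] : tensor<2x3xf32> into tensor<6xf32>
%2 = tensor.expand_shape %1 [[0, 1]] : tensor<6xf32> into tensor<2x3xf32>

// 3. Reshape of a constant: replaced by the reshaped constant
//    dense<1.0> : tensor<2x2xf32>.
%cst = arith.constant dense<1.0> : tensor<4xf32>
%3 = tensor.expand_shape %cst [[0, 1]] : tensor<4xf32> into tensor<2x2xf32>
```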

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 0 additions & 4 deletions

@@ -2448,15 +2448,11 @@ void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
 }
 
 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
-  if (getSrcType() == getType())
-    return getSrc();
   return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
                                                        adaptor.getOperands());
 }
 
 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
-  if (getSrcType() == getType())
-    return getSrc();
   return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
                                                        adaptor.getOperands());
 }

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 0 additions & 4 deletions

@@ -1860,15 +1860,11 @@ void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
 }
 
 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
-  if (getSrcType() == getType())
-    return getSrc();
   return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
                                                        adaptor.getOperands());
 }
 
 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
-  if (getSrcType() == getType())
-    return getSrc();
   return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
                                                        adaptor.getOperands());
 }

mlir/test/Dialect/MemRef/canonicalize.mlir

Lines changed: 10 additions & 0 deletions

@@ -19,6 +19,16 @@ func.func @expand_shape_identity_fold(%arg0 : memref<5x4xi8>) -> memref<5x4xi8>
 
 // -----
 
+// CHECK-LABEL: collapse_expand_rank0_cancel
+// CHECK-NEXT: return
+func.func @collapse_expand_rank0_cancel(%arg0 : memref<1x1xi8>) -> memref<1x1xi8> {
+  %0 = memref.collapse_shape %arg0 [] : memref<1x1xi8> into memref<i8>
+  %1 = memref.expand_shape %0 [] : memref<i8> into memref<1x1xi8>
+  return %1 : memref<1x1xi8>
+}
+
+// -----
+
 // CHECK-LABEL: func @subview_of_size_memcast
 // CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
 // CHECK: %[[S:.+]] = memref.subview %[[ARG0]][0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, strided{{.*}}>

mlir/test/Dialect/MemRef/invalid.mlir

Lines changed: 28 additions & 0 deletions

@@ -448,6 +448,34 @@ func.func @collapse_shape_invalid_reassociation(%arg0: memref<?x?x?xf32>) {
 
 // -----
 
+// An (invalid) attempt at using collapse_shape to increase the rank might look
+// like this. Verify that a sensible error is emitted in this case.
+func.func @collapse_shape_invalid_reassociation_expansion(%arg0: memref<?xf32>) {
+  // expected-error @+1 {{reassociation indices must be contiguous}}
+  %0 = memref.collapse_shape %arg0 [[0], [0]] :
+    memref<?xf32> into memref<?x?xf32>
+}
+
+// -----
+
+// An (invalid) attempt at using expand_shape to reduce the rank might look
+// like this. Verify that a sensible error is emitted in this case.
+func.func @expand_shape_invalid_reassociation(%arg0: memref<2x3x1xf32>) {
+  // expected-error @+1 {{reassociation indices must be contiguous}}
+  %0 = memref.expand_shape %arg0 [[0], [1], [1]] :
+    memref<2x3x1xf32> into memref<2x3xf32>
+}
+
+// -----
+
+func.func @collapse_shape_invalid_reassociation_expansion(%arg0: memref<?x?xf32>) {
+  // expected-error @+1 {{reassociation indices must be contiguous}}
+  %0 = memref.collapse_shape %arg0 [[1], [0]] :
+    memref<?x?xf32> into memref<?x?xf32>
+}
+
+// -----
+
 func.func @collapse_shape_reshaping_non_contiguous(
     %arg0: memref<3x4x5xf32, strided<[270, 50, 10], offset: 0>>) {
   // expected-error @+1 {{invalid source layout map or collapsing non-contiguous dims}}
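
For contrast with these invalid cases, a valid reassociation must cover the higher-rank dimensions contiguously and exactly once, in order. A hedged sketch of the form the verifier accepts (illustrative, not part of the patch):

```mlir
// Valid: [[0, 1], [2]] groups dims (0, 1) and (2) contiguously, with
// 2 * 3 = 6 and 4 carried through.
%ok = memref.collapse_shape %arg [[0, 1], [2]]
    : memref<2x3x4xf32> into memref<6x4xf32>
```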

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 18 additions & 0 deletions

@@ -10,6 +10,15 @@ func.func @expand_shape_identity_fold(%arg0 : tensor<5xf32>) -> tensor<5xf32> {
 
 // -----
 
+// CHECK-LABEL: expand_shape_rank0_identity_fold
+// CHECK-NEXT: return
+func.func @expand_shape_rank0_identity_fold(%arg0 : tensor<f32>) -> tensor<f32> {
+  %0 = tensor.expand_shape %arg0 [] : tensor<f32> into tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
 // CHECK-LABEL: collapse_shape_identity_fold
 // CHECK-NEXT: return
 func.func @collapse_shape_identity_fold(%arg0 : tensor<5x4xf32>) -> tensor<5x4xf32> {
@@ -19,6 +28,15 @@ func.func @collapse_shape_identity_fold(%arg0 : tensor<5x4xf32>) -> tensor<5x4xf
 
 // -----
 
+// CHECK-LABEL: collapse_shape_rank0_identity_fold
+// CHECK-NEXT: return
+func.func @collapse_shape_rank0_identity_fold(%arg0 : tensor<f32>) -> tensor<f32> {
+  %0 = tensor.collapse_shape %arg0 [] : tensor<f32> into tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
 // CHECK-LABEL: @tensor_bitcast_chain_ok
 // CHECK-SAME: %[[IN:.*]]: tensor<2xi32>
 func.func @tensor_bitcast_chain_ok(%input: tensor<2xi32>) -> tensor<2xf32> {
