Skip to content

Commit 7906997

Browse files
committed
test
1 parent 847a899 commit 7906997

File tree

5 files changed

+47
-3
lines changed

5 files changed

+47
-3
lines changed

mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,6 @@ static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
111111
<< " to have higher rank than the type = " << collapsedType;
112112
if (expandedRank == 0)
113113
return op.emitOpError("expected non-zero memref ranks");
114-
if (expandedRank == collapsedRank)
115-
return op.emitOpError("expected to collapse or expand dims");
116114

117115
if (collapsedRank == 0) {
118116
// If collapsed rank is 0, then expanded type must be static shaped and of

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2448,11 +2448,15 @@ void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
24482448
}
24492449

24502450
OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2451+
if (getSrcType() == getType())
2452+
return getSrc();
24512453
return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
24522454
adaptor.getOperands());
24532455
}
24542456

24552457
OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2458+
if (getSrcType() == getType())
2459+
return getSrc();
24562460
return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
24572461
adaptor.getOperands());
24582462
}

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1860,11 +1860,15 @@ void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
18601860
}
18611861

18621862
OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
1863+
if (getSrcType() == getType())
1864+
return getSrc();
18631865
return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
18641866
adaptor.getOperands());
18651867
}
18661868

18671869
OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
1870+
if (getSrcType() == getType())
1871+
return getSrc();
18681872
return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
18691873
adaptor.getOperands());
18701874
}

mlir/test/Dialect/MemRef/canonicalize.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,24 @@
11
// RUN: mlir-opt %s -canonicalize="test-convergence" --split-input-file -allow-unregistered-dialect | FileCheck %s
22

3+
4+
// CHECK-LABEL: collapse_shape_identity_fold
5+
// CHECK-NEXT: return
6+
func.func @collapse_shape_identity_fold(%arg0 : memref<5xi8>) -> memref<5xi8> {
7+
%0 = memref.collapse_shape %arg0 [[0]] : memref<5xi8> into memref<5xi8>
8+
return %0 : memref<5xi8>
9+
}
10+
11+
// -----
12+
13+
// CHECK-LABEL: expand_shape_identity_fold
14+
// CHECK-NEXT: return
15+
func.func @expand_shape_identity_fold(%arg0 : memref<5x4xi8>) -> memref<5x4xi8> {
16+
%0 = memref.expand_shape %arg0 [[0], [1]] : memref<5x4xi8> into memref<5x4xi8>
17+
return %0 : memref<5x4xi8>
18+
}
19+
20+
// -----
21+
322
// CHECK-LABEL: func @subview_of_size_memcast
423
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
524
// CHECK: %[[S:.+]] = memref.subview %[[ARG0]][0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, strided{{.*}}>

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,24 @@
11
// RUN: mlir-opt %s -split-input-file -canonicalize="test-convergence" | FileCheck %s
22

3+
4+
// CHECK-LABEL: expand_shape_identity_fold
5+
// CHECK-NEXT: return
6+
func.func @expand_shape_identity_fold(%arg0 : tensor<5xf32>) -> tensor<5xf32> {
7+
%0 = tensor.expand_shape %arg0 [[0]] : tensor<5xf32> into tensor<5xf32>
8+
return %0 : tensor<5xf32>
9+
}
10+
11+
// -----
12+
13+
// CHECK-LABEL: collapse_shape_identity_fold
14+
// CHECK-NEXT: return
15+
func.func @collapse_shape_identity_fold(%arg0 : tensor<5x4xf32>) -> tensor<5x4xf32> {
16+
%0 = tensor.collapse_shape %arg0 [[0], [1]] : tensor<5x4xf32> into tensor<5x4xf32>
17+
return %0 : tensor<5x4xf32>
18+
}
19+
20+
// -----
21+
322
// CHECK-LABEL: @tensor_bitcast_chain_ok
423
// CHECK-SAME: %[[IN:.*]]: tensor<2xi32>
524
func.func @tensor_bitcast_chain_ok(%input: tensor<2xi32>) -> tensor<2xf32> {
@@ -2069,7 +2088,7 @@ func.func @unpack_pack(%t: tensor<128x128xf32>) -> tensor<128x128xf32> {
20692088

20702089
// Chain: NC -> NCnc -> NCnc -> NC
20712090
// CHECK: func.func @unpack_pack(
2072-
// CHECK-SAME: %[[T:.+]]: tensor<128x128xf32>,
2091+
// CHECK-SAME: %[[T:.+]]: tensor<128x128xf32>,
20732092
// CHECK: return %[[T]] : tensor<128x128xf32>
20742093
func.func @unpack_pack(%t: tensor<128x128xf32>, %tile1: index, %tile2: index) -> tensor<128x128xf32> {
20752094
%tensor_empty = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32>

0 commit comments

Comments (0)