Skip to content

Commit 76d71f3

Browse files
committed
Revert "[mlir][Vector] Extend xfer drop unit dim patterns"
This reverts commit a53cd03. The reverted commit exposes some implementation gaps in other patterns, so it is being reverted for now.
1 parent 3ab3671 commit 76d71f3

File tree

2 files changed

+15
-151
lines changed

2 files changed

+15
-151
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp

Lines changed: 15 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ class TransferOptimization {
6363
std::vector<Operation *> opToErase;
6464
};
6565

66-
} // namespace
6766
/// Return true if there is a path from start operation to dest operation,
6867
/// otherwise return false. The operations have to be in the same region.
6968
bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
@@ -289,25 +288,14 @@ static int getReducedRank(ArrayRef<int64_t> shape) {
289288
return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
290289
}
291290

292-
/// Returns a copy of `shape` without unit dims.
293-
static SmallVector<int64_t> getReducedShape(ArrayRef<int64_t> shape) {
294-
SmallVector<int64_t> reducedShape;
295-
llvm::copy_if(shape, std::back_inserter(reducedShape),
296-
[](int64_t dimSize) { return dimSize != 1; });
297-
return reducedShape;
298-
}
299-
300291
/// Returns true if `v` is defined by an `arith.constant 0 : index`.
301292
static bool isZero(Value v) {
302293
auto cst = v.getDefiningOp<arith::ConstantIndexOp>();
303294
return cst && cst.value() == 0;
304295
}
305296

306-
namespace {
307-
308-
/// Rewrites `vector.transfer_read` ops where the source has unit dims, by
309-
/// inserting a memref.subview dropping those unit dims. The vector shapes are
310-
/// also reduced accordingly.
297+
/// Rewrites vector.transfer_read ops where the source has unit dims, by
298+
/// inserting a memref.subview dropping those unit dims.
311299
class TransferReadDropUnitDimsPattern
312300
: public OpRewritePattern<vector::TransferReadOp> {
313301
using OpRewritePattern::OpRewritePattern;
@@ -329,15 +317,12 @@ class TransferReadDropUnitDimsPattern
329317
return failure();
330318
if (!transferReadOp.getPermutationMap().isMinorIdentity())
331319
return failure();
332-
// Check if the source shape can be further reduced.
333320
int reducedRank = getReducedRank(sourceType.getShape());
334321
if (reducedRank == sourceType.getRank())
335-
return failure();
336-
// Check if the reduced vector shape matches the reduced source shape.
337-
// Otherwise, this case is not supported yet.
338-
int vectorReducedRank = getReducedRank(vectorType.getShape());
339-
if (reducedRank != vectorReducedRank)
340-
return failure();
322+
return failure(); // The source shape can't be further reduced.
323+
if (reducedRank != vectorType.getRank())
324+
return failure(); // This pattern requires the vector shape to match the
325+
// reduced source shape.
341326
if (llvm::any_of(transferReadOp.getIndices(),
342327
[](Value v) { return !isZero(v); }))
343328
return failure();
@@ -346,22 +331,14 @@ class TransferReadDropUnitDimsPattern
346331
Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
347332
SmallVector<Value> zeros(reducedRank, c0);
348333
auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
349-
auto reducedVectorType = VectorType::get(
350-
getReducedShape(vectorType.getShape()), vectorType.getElementType());
351-
352-
auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
353-
loc, reducedVectorType, reducedShapeSource, zeros, identityMap);
354-
auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
355-
loc, vectorType, newTransferReadOp);
356-
rewriter.replaceOp(transferReadOp, shapeCast);
357-
334+
rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
335+
transferReadOp, vectorType, reducedShapeSource, zeros, identityMap);
358336
return success();
359337
}
360338
};
361339

362-
/// Rewrites `vector.transfer_write` ops where the "source" (i.e. destination)
363-
/// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
364-
/// vector shapes are also reduced accordingly.
340+
/// Rewrites vector.transfer_write ops where the "source" (i.e. destination) has
341+
/// unit dims, by inserting a memref.subview dropping those unit dims.
365342
class TransferWriteDropUnitDimsPattern
366343
: public OpRewritePattern<vector::TransferWriteOp> {
367344
using OpRewritePattern::OpRewritePattern;
@@ -383,15 +360,12 @@ class TransferWriteDropUnitDimsPattern
383360
return failure();
384361
if (!transferWriteOp.getPermutationMap().isMinorIdentity())
385362
return failure();
386-
// Check if the destination shape can be further reduced.
387363
int reducedRank = getReducedRank(sourceType.getShape());
388364
if (reducedRank == sourceType.getRank())
389-
return failure();
390-
// Check if the reduced vector shape matches the reduced destination shape.
391-
// Otherwise, this case is not supported yet.
392-
int vectorReducedRank = getReducedRank(vectorType.getShape());
393-
if (reducedRank != vectorReducedRank)
394-
return failure();
365+
return failure(); // The source shape can't be further reduced.
366+
if (reducedRank != vectorType.getRank())
367+
return failure(); // This pattern requires the vector shape to match the
368+
// reduced source shape.
395369
if (llvm::any_of(transferWriteOp.getIndices(),
396370
[](Value v) { return !isZero(v); }))
397371
return failure();
@@ -400,20 +374,12 @@ class TransferWriteDropUnitDimsPattern
400374
Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
401375
SmallVector<Value> zeros(reducedRank, c0);
402376
auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
403-
VectorType reducedVectorType = VectorType::get(
404-
getReducedShape(vectorType.getShape()), vectorType.getElementType());
405-
406-
auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
407-
loc, reducedVectorType, vector);
408377
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
409-
transferWriteOp, shapeCast, reducedShapeSource, zeros, identityMap);
410-
378+
transferWriteOp, vector, reducedShapeSource, zeros, identityMap);
411379
return success();
412380
}
413381
};
414382

415-
} // namespace
416-
417383
/// Return true if the memref type has its inner dimension matching the given
418384
/// shape. Otherwise return false.
419385
static int64_t hasMatchingInnerContigousShape(MemRefType memrefType,
@@ -473,8 +439,6 @@ checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
473439
return success();
474440
}
475441

476-
namespace {
477-
478442
/// Rewrites contiguous row-major vector.transfer_read ops by inserting
479443
/// memref.collapse_shape on the source so that the resulting
480444
/// vector.transfer_read has a 1D source. Requires the source shape to be
@@ -768,7 +732,6 @@ class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
768732
return success();
769733
}
770734
};
771-
772735
} // namespace
773736

774737
void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,

mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir

Lines changed: 0 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,6 @@ func.func @transfer_read_rank_reducing(
1515
// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
1616
// CHECK: vector.transfer_read %[[SUBVIEW]]
1717

18-
transform.sequence failures(propagate) {
19-
^bb1(%module_op: !pdl.operation):
20-
transform.vector.apply_rank_reducing_subview_patterns %module_op
21-
: (!pdl.operation) -> !pdl.operation
22-
}
23-
24-
// -----
25-
2618
func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, %vec : vector<3x2xi8>) {
2719
%c0 = arith.constant 0 : index
2820
vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
@@ -36,97 +28,6 @@ func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6,
3628
// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
3729
// CHECK: vector.transfer_write %{{.*}}, %[[SUBVIEW]]
3830

39-
transform.sequence failures(propagate) {
40-
^bb1(%module_op: !pdl.operation):
41-
transform.vector.apply_rank_reducing_subview_patterns %module_op
42-
: (!pdl.operation) -> !pdl.operation
43-
}
44-
45-
// -----
46-
47-
func.func @transfer_read_and_vector_rank_reducing(
48-
%arg : memref<1x1x3x2x1xf32>) -> vector<3x2x1xf32> {
49-
%c0 = arith.constant 0 : index
50-
%cst = arith.constant 0.0 : f32
51-
%v = vector.transfer_read %arg[%c0, %c0, %c0, %c0, %c0], %cst :
52-
memref<1x1x3x2x1xf32>, vector<3x2x1xf32>
53-
return %v : vector<3x2x1xf32>
54-
}
55-
56-
// CHECK-LABEL: func @transfer_read_and_vector_rank_reducing
57-
// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2x1xf32>
58-
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0] [1, 1, 3, 2, 1] [1, 1, 1, 1, 1]
59-
// CHECK-SAME: memref<1x1x3x2x1xf32> to memref<3x2xf32>
60-
// CHECK: vector.transfer_read %[[SUBVIEW]]{{.*}} {in_bounds = [true, true]} : memref<3x2xf32>, vector<3x2xf32>
61-
62-
transform.sequence failures(propagate) {
63-
^bb1(%module_op: !pdl.operation):
64-
transform.vector.apply_rank_reducing_subview_patterns %module_op
65-
: (!pdl.operation) -> !pdl.operation
66-
}
67-
68-
// -----
69-
70-
func.func @transfer_write_and_vector_rank_reducing(
71-
%arg : memref<1x1x3x2x1xf32>,
72-
%vec : vector<3x2x1xf32>) {
73-
%c0 = arith.constant 0 : index
74-
vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0, %c0] :
75-
vector<3x2x1xf32>, memref<1x1x3x2x1xf32>
76-
return
77-
}
78-
79-
// CHECK-LABEL: func @transfer_write_and_vector_rank_reducing
80-
// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2x1xf32>
81-
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0] [1, 1, 3, 2, 1] [1, 1, 1, 1, 1]
82-
// CHECK-SAME: memref<1x1x3x2x1xf32> to memref<3x2xf32>
83-
// CHECK: vector.transfer_write %{{.*}}, %[[SUBVIEW]]{{.*}} {in_bounds = [true, true]} : vector<3x2xf32>, memref<3x2xf32>
84-
85-
transform.sequence failures(propagate) {
86-
^bb1(%module_op: !transform.any_op):
87-
transform.vector.apply_rank_reducing_subview_patterns %module_op
88-
: (!transform.any_op) -> !transform.any_op
89-
}
90-
91-
// -----
92-
93-
func.func @transfer_read_and_vector_rank_reducing_to_0d(
94-
%arg : memref<1x1x1x1x1xf32>) -> vector<1x1x1xf32> {
95-
%c0 = arith.constant 0 : index
96-
%cst = arith.constant 0.0 : f32
97-
%v = vector.transfer_read %arg[%c0, %c0, %c0, %c0, %c0], %cst :
98-
memref<1x1x1x1x1xf32>, vector<1x1x1xf32>
99-
return %v : vector<1x1x1xf32>
100-
}
101-
102-
// CHECK-LABEL: func @transfer_read_and_vector_rank_reducing_to_0d
103-
// CHECK-SAME: %[[MEMREF:.+]]: memref<1x1x1x1x1xf32>
104-
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[MEMREF]][0, 0, 0, 0, 0] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] : memref<1x1x1x1x1xf32> to memref<f32>
105-
// CHECK: %[[READ:.+]] = vector.transfer_read %[[SUBVIEW]]{{.*}} : memref<f32>, vector<f32>
106-
// CHECK: vector.shape_cast %[[READ]] : vector<f32> to vector<1x1x1xf32>
107-
108-
transform.sequence failures(propagate) {
109-
^bb1(%module_op: !pdl.operation):
110-
transform.vector.apply_rank_reducing_subview_patterns %module_op
111-
: (!pdl.operation) -> !pdl.operation
112-
}
113-
114-
// -----
115-
116-
func.func @transfer_write_and_vector_rank_reducing_to_0d(
117-
%arg : memref<1x1x1x1x1xf32>,
118-
%vec : vector<1x1x1xf32>) {
119-
%c0 = arith.constant 0 : index
120-
vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0, %c0] :
121-
vector<1x1x1xf32>, memref<1x1x1x1x1xf32>
122-
return
123-
}
124-
125-
// CHECK-LABEL: func @transfer_write_and_vector_rank_reducing_to_0d
126-
// CHECK-SAME: %[[MEMREF:.+]]: memref<1x1x1x1x1xf32>, %[[VECTOR:.+]]: vector<1x1x1xf32>
127-
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[MEMREF]][0, 0, 0, 0, 0] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] : memref<1x1x1x1x1xf32> to memref<f32>
128-
// CHECK: %[[SHCAST:.+]] = vector.shape_cast %[[VECTOR]] : vector<1x1x1xf32> to vector<f32>
129-
// CHECK: vector.transfer_write %[[SHCAST]], %[[SUBVIEW]]{{.*}} : vector<f32>, memref<f32>
13031

13132
transform.sequence failures(propagate) {
13233
^bb1(%module_op: !transform.any_op):

0 commit comments

Comments
 (0)