Commit 132b12a
[mlir][Vector] Remove VectorLoadToMemrefLoadLowering and VectorStoreToMemrefStoreLowering
0-d vectors are now supported, so these patterns are no longer required.
Parent: 0965515
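
For context, the net effect of the removal: single-element and 0-d vector.load/vector.store ops now lower directly instead of being scalarized through memref.load/memref.store. A minimal sketch of the 0-d case, based on the updated test expectations below (function name hypothetical):

func.func @zero_d_roundtrip(%mem: memref<f32>) {
  // Post-commit, these ops stay in the vector domain. The removed patterns
  // would have rewritten the load to memref.load + vector.broadcast and the
  // store to vector.extractelement + memref.store.
  %v = vector.load %mem[] : memref<f32>, vector<f32>
  vector.store %v, %mem[] : memref<f32>, vector<f32>
  return
}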

2 files changed: +10 -71 lines

mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
Lines changed: 0 additions & 57 deletions

@@ -492,60 +492,6 @@ struct TransferReadToVectorLoadLowering
   std::optional<unsigned> maxTransferRank;
 };
 
-/// Replace a 0-d vector.load with a memref.load + vector.broadcast.
-// TODO: we shouldn't cross the vector/scalar domains just for this
-// but atm we lack the infra to avoid it. Possible solutions include:
-// - go directly to LLVM + bitcast
-// - introduce a bitcast op and likely a new pointer dialect
-// - let memref.load/store additionally support the 0-d vector case
-// There are still deeper data layout issues lingering even in this
-// trivial case (for architectures for which this matters).
-struct VectorLoadToMemrefLoadLowering
-    : public OpRewritePattern<vector::LoadOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
-                                PatternRewriter &rewriter) const override {
-    auto vecType = loadOp.getVectorType();
-    if (vecType.getNumElements() != 1)
-      return rewriter.notifyMatchFailure(loadOp, "not a single element vector");
-
-    auto memrefLoad = rewriter.create<memref::LoadOp>(
-        loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices());
-    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
-                                                     memrefLoad);
-    return success();
-  }
-};
-
-/// Replace a 0-d vector.store with a vector.extractelement + memref.store.
-struct VectorStoreToMemrefStoreLowering
-    : public OpRewritePattern<vector::StoreOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
-                                PatternRewriter &rewriter) const override {
-    auto vecType = storeOp.getVectorType();
-    if (vecType.getNumElements() != 1)
-      return rewriter.notifyMatchFailure(storeOp, "not single element vector");
-
-    Value extracted;
-    if (vecType.getRank() == 0) {
-      // TODO: Unifiy once ExtractOp supports 0-d vectors.
-      extracted = rewriter.create<vector::ExtractElementOp>(
-          storeOp.getLoc(), storeOp.getValueToStore());
-    } else {
-      SmallVector<int64_t> indices(vecType.getRank(), 0);
-      extracted = rewriter.create<vector::ExtractOp>(
-          storeOp.getLoc(), storeOp.getValueToStore(), indices);
-    }
-
-    rewriter.replaceOpWithNewOp<memref::StoreOp>(
-        storeOp, extracted, storeOp.getBase(), storeOp.getIndices());
-    return success();
-  }
-};
-
 /// Progressive lowering of transfer_write. This pattern supports lowering of
 /// `vector.transfer_write` to `vector.store` if all of the following hold:
 /// - Stride of most minor memref dimension must be 1.
@@ -645,7 +591,4 @@ void mlir::vector::populateVectorTransferLoweringPatterns(
   patterns.add<TransferReadToVectorLoadLowering,
                TransferWriteToVectorStoreLowering>(patterns.getContext(),
                                                    maxTransferRank, benefit);
-  patterns
-      .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
-          patterns.getContext(), benefit);
 }
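
To spell out what the two deleted patterns did: each matched any vector.load/vector.store whose vector type had exactly one element (rank 0, or an all-ones shape such as vector<1x1x1xf32>) and rewrote it through the scalar domain. A sketch of both rewrites, reconstructed from the deleted code above (function name hypothetical):

func.func @single_element_ops(%mem: memref<8x8xf32>, %i: index) {
  // VectorLoadToMemrefLoadLowering rewrote this one-element load ...
  %v = vector.load %mem[%i, %i] : memref<8x8xf32>, vector<1xf32>
  // ... into a scalar load plus a broadcast:
  //   %s = memref.load %mem[%i, %i] : memref<8x8xf32>
  //   %v = vector.broadcast %s : f32 to vector<1xf32>

  // VectorStoreToMemrefStoreLowering rewrote this one-element store ...
  vector.store %v, %mem[%i, %i] : memref<8x8xf32>, vector<1xf32>
  // ... into an extract (vector.extractelement for the rank-0 case) plus a
  // scalar store:
  //   %s = vector.extract %v[0] : f32 from vector<1xf32>
  //   memref.store %s, %mem[%i, %i] : memref<8x8xf32>
  return
}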

mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
Lines changed: 10 additions & 14 deletions

@@ -6,16 +6,13 @@
 func.func @vector_transfer_ops_0d_memref(%mem: memref<f32>, %vec: vector<1x1x1xf32>) {
   %f0 = arith.constant 0.0 : f32
 
-  // CHECK-NEXT: %[[S:.*]] = memref.load %[[MEM]][] : memref<f32>
-  // CHECK-NEXT: %[[V:.*]] = vector.broadcast %[[S]] : f32 to vector<f32>
+  // CHECK-NEXT: %[[S:.*]] = vector.load %[[MEM]][] : memref<f32>, vector<f32>
   %0 = vector.transfer_read %mem[], %f0 : memref<f32>, vector<f32>
 
-  // CHECK-NEXT: %[[SS:.*]] = vector.extractelement %[[V]][] : vector<f32>
-  // CHECK-NEXT: memref.store %[[SS]], %[[MEM]][] : memref<f32>
+  // CHECK-NEXT: vector.store %[[S]], %[[MEM]][] : memref<f32>, vector<f32>
   vector.transfer_write %0, %mem[] : vector<f32>, memref<f32>
 
-  // CHECK-NEXT: %[[VV:.*]] = vector.extract %arg1[0, 0, 0] : f32 from vector<1x1x1xf32>
-  // CHECK-NEXT: memref.store %[[VV]], %[[MEM]][] : memref<f32>
+  // CHECK-NEXT: vector.store %[[VEC]], %[[MEM]][] : memref<f32>, vector<1x1x1xf32>
   vector.store %vec, %mem[] : memref<f32>, vector<1x1x1xf32>
 
   return

@@ -191,8 +188,8 @@ func.func @transfer_perm_map(%mem : memref<8x8xf32>, %idx : index) -> vector<4xf
 // CHECK-LABEL: func @transfer_broadcasting(
 // CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
 // CHECK-SAME: %[[IDX:.*]]: index) -> vector<4xf32> {
-// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>
-// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<4xf32>
+// CHECK-NEXT: %[[LOAD:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>, vector<1xf32>
+// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[LOAD]] : vector<1xf32> to vector<4xf32>
 // CHECK-NEXT: return %[[RES]] : vector<4xf32>
 // CHECK-NEXT: }

@@ -208,8 +205,7 @@ func.func @transfer_broadcasting(%mem : memref<8x8xf32>, %idx : index) -> vector
 // CHECK-LABEL: func @transfer_scalar(
 // CHECK-SAME: %[[MEM:.*]]: memref<?x?xf32>,
 // CHECK-SAME: %[[IDX:.*]]: index) -> vector<1xf32> {
-// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<?x?xf32>
-// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<1xf32>
+// CHECK-NEXT: %[[RES:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<?x?xf32>, vector<1xf32>
 // CHECK-NEXT: return %[[RES]] : vector<1xf32>
 // CHECK-NEXT: }
 func.func @transfer_scalar(%mem : memref<?x?xf32>, %idx : index) -> vector<1xf32> {

@@ -222,8 +218,8 @@ func.func @transfer_scalar(%mem : memref<?x?xf32>, %idx : index) -> vector<1xf32
 // CHECK-LABEL: func @transfer_broadcasting_2D(
 // CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
 // CHECK-SAME: %[[IDX:.*]]: index) -> vector<4x4xf32> {
-// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>
-// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<4x4xf32>
+// CHECK-NEXT: %[[LOAD:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>, vector<1x1xf32>
+// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[LOAD]] : vector<1x1xf32> to vector<4x4xf32>
 // CHECK-NEXT: return %[[RES]] : vector<4x4xf32>
 // CHECK-NEXT: }

@@ -322,8 +318,8 @@ func.func @transfer_read_permutations(%mem_0 : memref<?x?xf32>, %mem_1 : memref<
   // CHECK: vector.transpose %{{.*}}, [2, 1, 3, 0] : vector<16x14x7x8xf32> to vector<7x14x8x16xf32>
 
   %6 = vector.transfer_read %mem_0[%c0, %c0], %cst {in_bounds = [true], permutation_map = #map6} : memref<?x?xf32>, vector<8xf32>
-  // CHECK: memref.load %{{.*}}[%[[C0]], %[[C0]]] : memref<?x?xf32>
-  // CHECK: vector.broadcast %{{.*}} : f32 to vector<8xf32>
+  // CHECK: vector.load %{{.*}}[%[[C0]], %[[C0]]] : memref<?x?xf32>, vector<1xf32>
+  // CHECK: vector.broadcast %{{.*}} : vector<1xf32> to vector<8xf32>
 
   return %0, %1, %2, %3, %4, %5, %6 : vector<7x14x8x16xf32>, vector<7x14x8x16xf32>,
       vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>,
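
The @transfer_broadcasting and @transfer_scalar updates above show the same change on the transfer_read path: the one-element part of a broadcasted read is now loaded as a vector<1xf32> and broadcast vector-to-vector rather than scalar-to-vector. A sketch of the input and its post-commit lowering, assuming the patterns from populateVectorTransferLoweringPatterns are applied (function name and permutation-map spelling are assumptions based on the check lines):

func.func @broadcast_read(%mem: memref<8x8xf32>, %idx: index) -> vector<4xf32> {
  %pad = arith.constant 0.0 : f32
  // A 1-D read whose single result dimension is broadcast.
  %res = vector.transfer_read %mem[%idx, %idx], %pad
      {in_bounds = [true], permutation_map = affine_map<(d0, d1) -> (0)>}
      : memref<8x8xf32>, vector<4xf32>
  // Lowers post-commit to (per the CHECK lines above):
  //   %load = vector.load %mem[%idx, %idx] : memref<8x8xf32>, vector<1xf32>
  //   %res  = vector.broadcast %load : vector<1xf32> to vector<4xf32>
  return %res : vector<4xf32>
}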
