Skip to content

Commit c2ddc12

Browse files
committed
review comments
1 parent e2dd80a commit c2ddc12

File tree

2 files changed

+27
-24
lines changed

2 files changed

+27
-24
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -902,6 +902,8 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
902902
};
903903

904904
/// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
905+
///
906+
/// Example:
905907
/// ```
906908
/// %a = vector.broadcast %arg1 : index to vector<1x4xindex>
907909
/// %b = vector.broadcast %arg2 : index to vector<1x4xindex>
@@ -987,6 +989,8 @@ struct ReorderElementwiseOpsOnBroadcast final
987989
/// This may result in cleaner code when extracting a single value
988990
/// from multi-element vector and also to help canonicalize 1-element vectors to
989991
/// scalars.
992+
///
993+
/// Example:
990994
/// ```
991995
/// %0 = arith.addf %arg0, %arg1 : vector<4xf32>
992996
/// %1 = vector.extract %0[1] : f32 from vector<4xf32>
@@ -1044,6 +1048,8 @@ class ExtractOpFromElementwise final
10441048
};
10451049

10461050
/// Pattern to rewrite vector.extract(vector.load) -> vector/memref.load.
1051+
///
1052+
/// Example:
10471053
/// ```
10481054
/// vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
10491055
/// vector.extract %0[1] : f32 from vector<4xf32>
@@ -1062,13 +1068,14 @@ class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
10621068
PatternRewriter &rewriter) const override {
10631069
auto loadOp = op.getVector().getDefiningOp<vector::LoadOp>();
10641070
if (!loadOp)
1065-
return rewriter.notifyMatchFailure(op, "not a load op");
1071+
return rewriter.notifyMatchFailure(op, "expected a load op");
10661072

1073+
// Checking for single use so we won't duplicate load ops.
10671074
if (!loadOp->hasOneUse())
10681075
return rewriter.notifyMatchFailure(op, "expected single op use");
10691076

1070-
VectorType memVecType = loadOp.getVectorType();
1071-
if (memVecType.isScalable())
1077+
VectorType loadVecType = loadOp.getVectorType();
1078+
if (loadVecType.isScalable())
10721079
return rewriter.notifyMatchFailure(op,
10731080
"scalable vectors are not supported");
10741081

@@ -1077,7 +1084,7 @@ class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
10771084
return rewriter.notifyMatchFailure(
10781085
op, "memrefs of vectors are not supported");
10791086

1080-
int64_t rankOffset = memType.getRank() - memVecType.getRank();
1087+
int64_t rankOffset = memType.getRank() - loadVecType.getRank();
10811088
if (rankOffset < 0)
10821089
return rewriter.notifyMatchFailure(op, "unsupported ranks combination");
10831090

@@ -1089,6 +1096,9 @@ class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
10891096
SmallVector<Value> indices = loadOp.getIndices();
10901097
SmallVector<OpFoldResult> extractPos = op.getMixedPosition();
10911098

1099+
// There may be memory stores between the load and the extract op, so we
1100+
// need to make sure that the new load op is inserted at the same place as
1101+
// the original load op.
10921102
OpBuilder::InsertionGuard g(rewriter);
10931103
rewriter.setInsertionPoint(loadOp);
10941104
Location loc = loadOp.getLoc();
@@ -1110,12 +1120,15 @@ class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
11101120
} else {
11111121
rewriter.replaceOpWithNewOp<memref::LoadOp>(op, base, indices);
11121122
}
1123+
// We checked for single use so we can safely erase the load op.
11131124
rewriter.eraseOp(loadOp);
11141125
return success();
11151126
}
11161127
};
11171128

11181129
/// Pattern to rewrite vector.store(vector.splat) -> vector/memref.store.
1130+
///
1131+
/// Example:
11191132
/// ```
11201133
/// %0 = vector.splat %arg2 : vector<1xf32>
11211134
/// vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
@@ -1145,8 +1158,9 @@ class StoreFromSplat final : public OpRewritePattern<vector::StoreOp> {
11451158

11461159
Operation *splat = op.getValueToStore().getDefiningOp();
11471160
if (!isa_and_present<vector::BroadcastOp, vector::SplatOp>(splat))
1148-
return rewriter.notifyMatchFailure(op, "not a splat");
1161+
return rewriter.notifyMatchFailure(op, "neither a splat nor a broadcast");
11491162

1163+
// Checking for single use so we can remove splat.
11501164
if (!splat->hasOneUse())
11511165
return rewriter.notifyMatchFailure(op, "expected single op use");
11521166

mlir/test/Dialect/Vector/vector-sink.mlir

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -587,9 +587,9 @@ func.func @extract_load_vec_high_rank(%arg0: memref<?x?x?xf32>, %arg1: index, %a
587587
return %1 : vector<4xf32>
588588
}
589589

590-
// CHECK-LABEL: @negative_load_scalar_from_vec_memref
590+
// CHECK-LABEL: @negative_extract_load_scalar_from_vec_memref
591591
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xvector<4xf32>>, %[[ARG1:.*]]: index)
592-
func.func @negative_load_scalar_from_vec_memref(%arg0: memref<?xvector<4xf32>>, %arg1: index) -> f32 {
592+
func.func @negative_extract_load_scalar_from_vec_memref(%arg0: memref<?xvector<4xf32>>, %arg1: index) -> f32 {
593593
// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xvector<4xf32>>, vector<4xf32>
594594
// CHECK: %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<4xf32>
595595
// CHECK: return %[[EXT]] : f32
@@ -609,9 +609,9 @@ func.func @negative_extract_load_no_single_use(%arg0: memref<?xf32>, %arg1: inde
609609
return %1, %0 : f32, vector<4xf32>
610610
}
611611

612-
// CHECK-LABEL: @negative_load_scalable
612+
// CHECK-LABEL: @negative_extract_load_scalable
613613
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
614-
func.func @negative_load_scalable(%arg0: memref<?xf32>, %arg1: index) -> f32 {
614+
func.func @negative_extract_load_scalable(%arg0: memref<?xf32>, %arg1: index) -> f32 {
615615
// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<[1]xf32>
616616
// CHECK: %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<[1]xf32>
617617
// CHECK: return %[[EXT]] : f32
@@ -620,17 +620,6 @@ func.func @negative_load_scalable(%arg0: memref<?xf32>, %arg1: index) -> f32 {
620620
return %1 : f32
621621
}
622622

623-
// CHECK-LABEL: @negative_extract_load_unsupported_ranks
624-
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
625-
func.func @negative_extract_load_unsupported_ranks(%arg0: memref<?xf32>, %arg1: index) -> vector<4xf32> {
626-
// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<2x4xf32>
627-
// CHECK: %[[EXT:.*]] = vector.extract %[[RES]][1] : vector<4xf32> from vector<2x4xf32>
628-
// CHECK: return %[[EXT]] : vector<4xf32>
629-
%0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<2x4xf32>
630-
%1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
631-
return %1 : vector<4xf32>
632-
}
633-
634623
//-----------------------------------------------------------------------------
635624
// [Pattern: StoreFromSplat]
636625
//-----------------------------------------------------------------------------
@@ -653,9 +642,9 @@ func.func @store_broadcast(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
653642
return
654643
}
655644

656-
// CHECK-LABEL: @store_broadcast_1d_2d
645+
// CHECK-LABEL: @store_broadcast_1d_to_2d
657646
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: vector<1xf32>)
658-
func.func @store_broadcast_1d_2d(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index, %arg3: vector<1xf32>) {
647+
func.func @store_broadcast_1d_to_2d(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index, %arg3: vector<1xf32>) {
659648
// CHECK: vector.store %[[ARG3]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<?x?xf32>, vector<1xf32>
660649
%0 = vector.broadcast %arg3 : vector<1xf32> to vector<1x1xf32>
661650
vector.store %0, %arg0[%arg1, %arg2] : memref<?x?xf32>, vector<1x1xf32>
@@ -682,9 +671,9 @@ func.func @negative_store_vec_memref(%arg0: memref<?xvector<1xf32>>, %arg1: inde
682671
return
683672
}
684673

685-
// CHECK-LABEL: @negative_store_non_1
674+
// CHECK-LABEL: @negative_store_more_than_one_element
686675
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
687-
func.func @negative_store_non_1(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
676+
func.func @negative_store_more_than_one_element(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
688677
// CHECK: %[[RES:.*]] = vector.splat %[[ARG2]] : vector<4xf32>
689678
// CHECK: vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<4xf32>
690679
%0 = vector.splat %arg2 : vector<4xf32>

0 commit comments

Comments
 (0)