
Commit 9b2c084

[mlir][vector] Sink extract/splat into load/store ops
```
vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
vector.extract %0[1] : f32 from vector<4xf32>
```

Gets converted to:

```
%c1 = arith.constant 1 : index
%0 = arith.addi %arg1, %c1 overflow<nsw> : index
%1 = memref.load %arg0[%0] : memref<?xf32>
```

```
%0 = vector.splat %arg2 : vector<1xf32>
vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
```

Gets converted to:

```
memref.store %arg2, %arg0[%arg1] : memref<?xf32>
```
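The extract position may also be dynamic; in that case the offset is added to the load index directly (this is the `@extract_load_scalar_dyn_off` case from the tests added in this commit):

```
// Input: dynamic extract position %arg2.
%0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
%1 = vector.extract %0[%arg2] : f32 from vector<4xf32>

// Result: the position is folded into the load index.
%0 = arith.addi %arg1, %arg2 overflow<nsw> : index
%1 = memref.load %arg0[%0] : memref<?xf32>
```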
1 parent 4fa3b2a commit 9b2c084

7 files changed (+350, -2 lines)

mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td

Lines changed: 22 additions & 2 deletions
@@ -469,8 +469,28 @@ def ApplySinkVectorPatternsOp : Op<Transform_Dialect,
     %0 = arith.addf %a, %b : vector<4x2xf32>
     %r = vector.transpose %0, [1, 0] : vector<2x4xf32>
     ```
-    At the moment, these patterns are limited to vector.broadcast and
-    vector.transpose.
+    At the moment, these patterns are limited to vector.broadcast,
+    vector.transpose and vector.extract.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
+def ApplySinkVectorMemPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.sink_mem_ops",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Patterns that remove redundant Vector Ops by merging them with load/store
+    ops
+    ```
+    vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+    vector.extract %0[1] : f32 from vector<4xf32>
+    ```
+    Gets converted to:
+    ```
+    %c1 = arith.constant 1 : index
+    %0 = arith.addi %arg1, %c1 overflow<nsw> : index
+    %1 = memref.load %arg0[%0] : memref<?xf32>
   }];
 
   let assemblyFormat = "attr-dict";
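For reference, the updated `vector-sink-transform.mlir` test in this commit exercises the new transform op alongside the existing one:

```
transform.apply_patterns to %func {
  transform.apply_patterns.vector.sink_ops
  transform.apply_patterns.vector.sink_mem_ops
} : !transform.any_op
```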

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 5 additions & 0 deletions
@@ -161,6 +161,11 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
 void populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
                                    PatternBenefit benefit = 1);
 
+/// Patterns that remove redundant Vector Ops by re-ordering them with
+/// memory Ops:
+void populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
+                                      PatternBenefit benefit = 1);
+
 /// Patterns that fold chained vector reductions. These patterns assume that
 /// elementwise operations (e.g., `arith.addf` with vector operands) are
 /// cheaper than vector reduction.

mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp

Lines changed: 5 additions & 0 deletions
@@ -212,6 +212,11 @@ void transform::ApplySinkVectorPatternsOp::populatePatterns(
   vector::populateSinkVectorOpsPatterns(patterns);
 }
 
+void transform::ApplySinkVectorMemPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  vector::populateSinkVectorMemOpsPatterns(patterns);
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 127 additions & 0 deletions
@@ -1103,6 +1103,127 @@ class ExtractOpFromElementwise final
   }
 };
 
+/// Pattern to rewrite vector.extract(vector.load) -> vector/memref.load.
+/// ```
+/// vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+/// vector.extract %0[1] : f32 from vector<4xf32>
+/// ```
+/// Gets converted to:
+/// ```
+/// %c1 = arith.constant 1 : index
+/// %0 = arith.addi %arg1, %c1 overflow<nsw> : index
+/// %1 = memref.load %arg0[%0] : memref<?xf32>
+/// ```
+class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loadOp = op.getVector().getDefiningOp<vector::LoadOp>();
+    if (!loadOp)
+      return rewriter.notifyMatchFailure(op, "not a load op");
+
+    if (!loadOp->hasOneUse())
+      return rewriter.notifyMatchFailure(op, "expected single op use");
+
+    VectorType memVecType = loadOp.getVectorType();
+    if (memVecType.isScalable())
+      return rewriter.notifyMatchFailure(op,
+                                         "scalable vectors are not supported");
+
+    MemRefType memType = loadOp.getMemRefType();
+    if (isa<VectorType>(memType.getElementType()))
+      return rewriter.notifyMatchFailure(
+          op, "memrefs of vectors are not supported");
+
+    int64_t rankOffset = memType.getRank() - memVecType.getRank();
+    if (rankOffset < 0)
+      return rewriter.notifyMatchFailure(op, "unsupported ranks combination");
+
+    auto resVecType = dyn_cast<VectorType>(op.getResult().getType());
+    int64_t finalRank = 0;
+    if (resVecType)
+      finalRank = resVecType.getRank();
+
+    SmallVector<Value> indices = loadOp.getIndices();
+    SmallVector<OpFoldResult> extractPos = op.getMixedPosition();
+
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(loadOp);
+    Location loc = loadOp.getLoc();
+    for (auto i : llvm::seq<int64_t>(rankOffset, indices.size() - finalRank)) {
+      OpFoldResult pos = extractPos[i - rankOffset];
+      if (isConstantIntValue(pos, 0))
+        continue;
+
+      Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos);
+
+      auto ovf = arith::IntegerOverflowFlags::nsw;
+      indices[i] = rewriter.create<arith::AddIOp>(loc, indices[i], offset, ovf);
+    }
+
+    Value base = loadOp.getBase();
+    if (resVecType) {
+      rewriter.replaceOpWithNewOp<vector::LoadOp>(op, resVecType, base,
+                                                  indices);
+    } else {
+      rewriter.replaceOpWithNewOp<memref::LoadOp>(op, base, indices);
+    }
+    rewriter.eraseOp(loadOp);
+    return success();
+  }
+};
+
+/// Pattern to rewrite vector.store(vector.splat) -> vector/memref.store.
+/// ```
+/// %0 = vector.splat %arg2 : vector<1xf32>
+/// vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
+/// ```
+/// Gets converted to:
+/// ```
+/// memref.store %arg2, %arg0[%arg1] : memref<?xf32>
+/// ```
+class StoreFromSplat final : public OpRewritePattern<vector::StoreOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::StoreOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType vecType = op.getVectorType();
+    if (vecType.isScalable())
+      return rewriter.notifyMatchFailure(op,
+                                         "scalable vectors are not supported");
+
+    if (isa<VectorType>(op.getMemRefType().getElementType()))
+      return rewriter.notifyMatchFailure(
+          op, "memrefs of vectors are not supported");
+
+    if (vecType.getNumElements() != 1)
+      return rewriter.notifyMatchFailure(
+          op, "only 1-element, vectors are supported");
+
+    Operation *splat = op.getValueToStore().getDefiningOp();
+    if (!isa_and_present<vector::BroadcastOp, vector::SplatOp>(splat))
+      return rewriter.notifyMatchFailure(op, "not a splat");
+
+    if (!splat->hasOneUse())
+      return rewriter.notifyMatchFailure(op, "expected single op use");
+
+    Value source = splat->getOperand(0);
+    Value base = op.getBase();
+    ValueRange indices = op.getIndices();
+
+    if (isa<VectorType>(source.getType())) {
+      rewriter.replaceOpWithNewOp<vector::StoreOp>(op, source, base, indices);
+    } else {
+      rewriter.replaceOpWithNewOp<memref::StoreOp>(op, source, base, indices);
+    }
+    rewriter.eraseOp(splat);
+    return success();
+  }
+};
+
 // Helper that returns a vector comparison that constructs a mask:
 //     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
 //
@@ -2175,6 +2296,12 @@ void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
                       patterns.getContext(), benefit);
 }
 
+void mlir::vector::populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
+                                                    PatternBenefit benefit) {
+  patterns.add<ExtractOpFromLoad, StoreFromSplat>(patterns.getContext(),
+                                                  benefit);
+}
+
 void mlir::vector::populateChainedVectorReductionFoldingPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<ChainedReduction>(patterns.getContext(), benefit);
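Both patterns bail out when the vector producer has more than one use. For instance (mirroring the `@negative_extract_load_no_single_use` test below), the following IR is left untouched because the loaded vector is also returned:

```
// Not rewritten: %0 has a second use besides the extract.
%0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
%1 = vector.extract %0[0] : f32 from vector<4xf32>
return %1, %0 : f32, vector<4xf32>
```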

mlir/test/Dialect/Vector/vector-sink-transform.mlir

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@ module attributes {transform.with_named_sequence} {
     %func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
       transform.apply_patterns.vector.sink_ops
+      transform.apply_patterns.vector.sink_mem_ops
     } : !transform.any_op
     transform.yield
   }

mlir/test/Dialect/Vector/vector-sink.mlir

Lines changed: 189 additions & 0 deletions
@@ -513,3 +513,192 @@ func.func @negative_extract_vec_fma(%arg0: vector<4xf32>, %arg1: vector<4xf32>,
   %1 = vector.extract %0[1] : f32 from vector<4xf32>
   return %1 : f32
 }
+
+//-----------------------------------------------------------------------------
+// [Pattern: ExtractOpFromLoad]
+//-----------------------------------------------------------------------------
+
+// CHECK-LABEL: @extract_load_scalar
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @extract_load_scalar(%arg0: memref<?xf32>, %arg1: index) -> f32 {
+// CHECK:   %[[RES:.*]] = memref.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>
+// CHECK:   return %[[RES]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+  %1 = vector.extract %0[0] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @extract_load_scalar_non_zero_off
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @extract_load_scalar_non_zero_off(%arg0: memref<?xf32>, %arg1: index) -> f32 {
+// CHECK:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK:   %[[OFF:.*]] = arith.addi %[[ARG1]], %[[C1]] overflow<nsw> : index
+// CHECK:   %[[RES:.*]] = memref.load %[[ARG0]][%[[OFF]]] : memref<?xf32>
+// CHECK:   return %[[RES]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @extract_load_scalar_dyn_off
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @extract_load_scalar_dyn_off(%arg0: memref<?xf32>, %arg1: index, %arg2: index) -> f32 {
+// CHECK:   %[[OFF:.*]] = arith.addi %[[ARG1]], %[[ARG2]] overflow<nsw> : index
+// CHECK:   %[[RES:.*]] = memref.load %[[ARG0]][%[[OFF]]] : memref<?xf32>
+// CHECK:   return %[[RES]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+  %1 = vector.extract %0[%arg2] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @extract_load_vec
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @extract_load_vec(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) -> vector<4xf32> {
+// CHECK:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK:   %[[OFF:.*]] = arith.addi %[[ARG1]], %[[C1]] overflow<nsw> : index
+// CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[OFF]], %[[ARG2]]] : memref<?x?xf32>, vector<4xf32>
+// CHECK:   return %[[RES]] : vector<4xf32>
+  %0 = vector.load %arg0[%arg1, %arg2] : memref<?x?xf32>, vector<2x4xf32>
+  %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
+  return %1 : vector<4xf32>
+}
+
+// CHECK-LABEL: @extract_load_scalar_high_rank
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @extract_load_scalar_high_rank(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) -> f32 {
+// CHECK:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK:   %[[OFF:.*]] = arith.addi %[[ARG2]], %[[C1]] overflow<nsw> : index
+// CHECK:   %[[RES:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[OFF]]] : memref<?x?xf32>
+// CHECK:   return %[[RES]] : f32
+  %0 = vector.load %arg0[%arg1, %arg2] : memref<?x?xf32>, vector<4xf32>
+  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @extract_load_vec_high_rank
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?x?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
+func.func @extract_load_vec_high_rank(%arg0: memref<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index) -> vector<4xf32> {
+// CHECK:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK:   %[[OFF:.*]] = arith.addi %[[ARG2]], %[[C1]] overflow<nsw> : index
+// CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]], %[[OFF]], %[[ARG3]]] : memref<?x?x?xf32>, vector<4xf32>
+// CHECK:   return %[[RES]] : vector<4xf32>
+  %0 = vector.load %arg0[%arg1, %arg2, %arg3] : memref<?x?x?xf32>, vector<2x4xf32>
+  %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
+  return %1 : vector<4xf32>
+}
+
+// CHECK-LABEL: @negative_load_scalar_from_vec_memref
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xvector<4xf32>>, %[[ARG1:.*]]: index)
+func.func @negative_load_scalar_from_vec_memref(%arg0: memref<?xvector<4xf32>>, %arg1: index) -> f32 {
+// CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xvector<4xf32>>, vector<4xf32>
+// CHECK:   %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<4xf32>
+// CHECK:   return %[[EXT]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xvector<4xf32>>, vector<4xf32>
+  %1 = vector.extract %0[0] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @negative_extract_load_no_single_use
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @negative_extract_load_no_single_use(%arg0: memref<?xf32>, %arg1: index) -> (f32, vector<4xf32>) {
+// CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<4xf32>
+// CHECK:   %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<4xf32>
+// CHECK:   return %[[EXT]], %[[RES]] : f32, vector<4xf32>
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+  %1 = vector.extract %0[0] : f32 from vector<4xf32>
+  return %1, %0 : f32, vector<4xf32>
+}
+
+// CHECK-LABEL: @negative_load_scalable
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
func.func @negative_load_scalable(%arg0: memref<?xf32>, %arg1: index) -> f32 {
+// CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<[1]xf32>
+// CHECK:   %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<[1]xf32>
+// CHECK:   return %[[EXT]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<[1]xf32>
+  %1 = vector.extract %0[0] : f32 from vector<[1]xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @negative_extract_load_unsupported_ranks
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @negative_extract_load_unsupported_ranks(%arg0: memref<?xf32>, %arg1: index) -> vector<4xf32> {
+// CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<2x4xf32>
+// CHECK:   %[[EXT:.*]] = vector.extract %[[RES]][1] : vector<4xf32> from vector<2x4xf32>
+// CHECK:   return %[[EXT]] : vector<4xf32>
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<2x4xf32>
+  %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
+  return %1 : vector<4xf32>
+}
+
+//-----------------------------------------------------------------------------
+// [Pattern: StoreFromSplat]
+//-----------------------------------------------------------------------------
+
+// CHECK-LABEL: @store_splat
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @store_splat(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
+// CHECK:   memref.store %[[ARG2]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>
+  %0 = vector.splat %arg2 : vector<1xf32>
+  vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
+  return
+}
+
+// CHECK-LABEL: @store_broadcast
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @store_broadcast(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
+// CHECK:   memref.store %[[ARG2]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>
+  %0 = vector.broadcast %arg2 : f32 to vector<1xf32>
+  vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
+  return
+}
+
+// CHECK-LABEL: @store_broadcast_1d_2d
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: vector<1xf32>)
+func.func @store_broadcast_1d_2d(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index, %arg3: vector<1xf32>) {
+// CHECK:   vector.store %[[ARG3]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<?x?xf32>, vector<1xf32>
+  %0 = vector.broadcast %arg3 : vector<1xf32> to vector<1x1xf32>
+  vector.store %0, %arg0[%arg1, %arg2] : memref<?x?xf32>, vector<1x1xf32>
+  return
+}
+
+// CHECK-LABEL: @negative_store_scalable
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @negative_store_scalable(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
+// CHECK:   %[[RES:.*]] = vector.splat %[[ARG2]] : vector<[1]xf32>
+// CHECK:   vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<[1]xf32>
+  %0 = vector.splat %arg2 : vector<[1]xf32>
+  vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<[1]xf32>
+  return
+}
+
+// CHECK-LABEL: @negative_store_vec_memref
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xvector<1xf32>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @negative_store_vec_memref(%arg0: memref<?xvector<1xf32>>, %arg1: index, %arg2: f32) {
+// CHECK:   %[[RES:.*]] = vector.splat %[[ARG2]] : vector<1xf32>
+// CHECK:   vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xvector<1xf32>>, vector<1xf32>
+  %0 = vector.splat %arg2 : vector<1xf32>
+  vector.store %0, %arg0[%arg1] : memref<?xvector<1xf32>>, vector<1xf32>
+  return
+}
+
+// CHECK-LABEL: @negative_store_non_1
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @negative_store_non_1(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
+// CHECK:   %[[RES:.*]] = vector.splat %[[ARG2]] : vector<4xf32>
+// CHECK:   vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<4xf32>
+  %0 = vector.splat %arg2 : vector<4xf32>
+  vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+  return
+}
+
+// CHECK-LABEL: @negative_store_no_single_use
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @negative_store_no_single_use(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) -> vector<1xf32> {
+// CHECK:   %[[RES:.*]] = vector.splat %[[ARG2]] : vector<1xf32>
+// CHECK:   vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<1xf32>
+// CHECK:   return %[[RES:.*]] : vector<1xf32>
+  %0 = vector.splat %arg2 : vector<1xf32>
+  vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
+  return %0 : vector<1xf32>
+}
