Skip to content

Commit 2a247bf

Browse files
committed
move patterns to vector-sink
1 parent ac16b30 commit 2a247bf

File tree

10 files changed

+140
-144
lines changed

10 files changed

+140
-144
lines changed

mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -453,12 +453,24 @@ def ApplyVectorReductionToContractPatternsOp : Op<Transform_Dialect,
453453
let assemblyFormat = "attr-dict";
454454
}
455455

456-
def ApplyVectorPropagateExtractPatternsOp : Op<Transform_Dialect,
457-
"apply_patterns.vector.propagate_extract",
456+
def ApplySinkVectorPatternsOp : Op<Transform_Dialect,
457+
"apply_patterns.vector.sink_ops",
458458
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
459459
let description = [{
460-
Collect a set of patterns for propagating `vector.extract` through the
461-
vector ops.
460+
Patterns that remove redundant Vector Ops by re-ordering them with
461+
e.g. elementwise Ops:
462+
```
463+
%at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
464+
%bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
465+
%r = arith.addf %at, %bt : vector<2x4xf32>
466+
```
467+
gets converted to:
468+
```
469+
%0 = arith.addf %a, %b : vector<4x2xf32>
470+
%r = vector.transpose %0, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
471+
```
472+
At the moment, these patterns are limited to vector.broadcast,
473+
vector.transpose and vector.extract.
462474
}];
463475

464476
let assemblyFormat = "attr-dict";

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -409,9 +409,6 @@ void populateVectorLinearizeShuffleLikeOpsPatterns(
409409
const TypeConverter &typeConverter, RewritePatternSet &patterns,
410410
ConversionTarget &target, unsigned targetBitWidth);
411411

412-
/// Populates patterns for propagating `vector.extract` through the vector ops.
413-
void populateVectorPropagateExtractsPatterns(RewritePatternSet &patterns);
414-
415412
} // namespace vector
416413
} // namespace mlir
417414

mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,9 +204,9 @@ void transform::ApplyTransferToScfPatternsOp::populatePatterns(
204204
populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions);
205205
}
206206

207-
void transform::ApplyVectorPropagateExtractPatternsOp::populatePatterns(
207+
void transform::ApplySinkVectorPatternsOp::populatePatterns(
208208
RewritePatternSet &patterns) {
209-
vector::populateVectorPropagateExtractsPatterns(patterns);
209+
vector::populateSinkVectorOpsPatterns(patterns);
210210
}
211211

212212
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ add_mlir_dialect_library(MLIRVectorTransforms
2424
VectorTransforms.cpp
2525
VectorUnroll.cpp
2626
VectorMaskElimination.cpp
27-
VectorPropagateExtract.cpp
2827

2928
ADDITIONAL_HEADER_DIRS
3029
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Vector/Transforms

mlir/lib/Dialect/Vector/Transforms/VectorPropagateExtract.cpp

Lines changed: 0 additions & 66 deletions
This file was deleted.

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1043,6 +1043,50 @@ struct ReorderElementwiseOpsOnBroadcast final
10431043
}
10441044
};
10451045

1046+
/// Pattern to rewrite a ExtractOp(Elementwise) -> Elementwise(ExtractOp).
1047+
/// This may result in more efficient code when we extracting a single value
1048+
/// from multi-element vector and also to help canonicalize 1-element vectors to
1049+
/// scalars.
1050+
class ExtractOpFromElementwise final
1051+
: public OpRewritePattern<vector::ExtractOp> {
1052+
public:
1053+
using OpRewritePattern::OpRewritePattern;
1054+
1055+
LogicalResult matchAndRewrite(vector::ExtractOp op,
1056+
PatternRewriter &rewriter) const override {
1057+
Operation *eltwise = op.getVector().getDefiningOp();
1058+
1059+
// Elementwise op with single result and `extract` is single user.
1060+
if (!eltwise || !OpTrait::hasElementwiseMappableTraits(eltwise) ||
1061+
eltwise->getNumResults() != 1 || !eltwise->hasOneUse())
1062+
return failure();
1063+
1064+
// Arguments and result types must match.
1065+
if (!llvm::all_equal(eltwise->getOperandTypes()))
1066+
return failure();
1067+
1068+
Type dstType = op.getType();
1069+
1070+
OpBuilder::InsertionGuard g(rewriter);
1071+
rewriter.setInsertionPoint(eltwise);
1072+
1073+
IRMapping mapping;
1074+
Location loc = eltwise->getLoc();
1075+
for (auto &&[i, arg] : llvm::enumerate(eltwise->getOperands())) {
1076+
Value newArg =
1077+
rewriter.create<vector::ExtractOp>(loc, arg, op.getMixedPosition());
1078+
mapping.map(arg, newArg);
1079+
}
1080+
1081+
Operation *newEltwise = rewriter.clone(*eltwise, mapping);
1082+
newEltwise->getResult(0).setType(dstType);
1083+
1084+
rewriter.replaceOp(op, newEltwise);
1085+
rewriter.eraseOp(eltwise);
1086+
return success();
1087+
}
1088+
};
1089+
10461090
// Helper that returns a vector comparison that constructs a mask:
10471091
// mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
10481092
//
@@ -2111,8 +2155,8 @@ void mlir::vector::
21112155
void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
21122156
PatternBenefit benefit) {
21132157
patterns.add<ReorderElementwiseOpsOnTranspose, ReorderCastOpsOnBroadcast,
2114-
ReorderElementwiseOpsOnBroadcast>(patterns.getContext(),
2115-
benefit);
2158+
ReorderElementwiseOpsOnBroadcast, ExtractOpFromElementwise>(
2159+
patterns.getContext(), benefit);
21162160
}
21172161

21182162
void mlir::vector::populateChainedVectorReductionFoldingPatterns(

mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -62,20 +62,15 @@ func.func @vectorize_nd_tensor_extract_transfer_read_complex(%6: tensor<45x80x16
6262
// CHECK-SAME: %[[VAL_0:.*]]: tensor<45x80x16xf32>,
6363
// CHECK-SAME: %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index, %[[VAL_4:.*]]: index,
6464
// CHECK-SAME: %[[VAL_5:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
65-
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
66-
// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 0.000000e+00 : f32
67-
// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 0 : index
68-
// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 79 : index
69-
// CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_1]], %[[VAL_2]] : index
70-
// CHECK: %[[VAL_13:.*]] = vector.broadcast %[[VAL_3]] : index to vector<4xindex>
71-
// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_6]] : vector<4xindex>
72-
// CHECK: %[[VAL_15:.*]] = vector.broadcast %[[VAL_4]] : index to vector<4xindex>
73-
// CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_14]], %[[VAL_15]] : vector<4xindex>
74-
75-
// CHECK: %[[VAL_19:.*]] = vector.extract %[[VAL_16]][0] : index from vector<4xindex>
76-
77-
// CHECK: %[[VAL_20:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_11]], %[[VAL_10]], %[[VAL_19]]], %[[VAL_8]] {in_bounds = [true, true]} : tensor<45x80x16xf32>, vector<1x4xf32>
78-
// CHECK: %[[VAL_21:.*]] = vector.transfer_write %[[VAL_20]], %[[VAL_5]]{{\[}}%[[VAL_9]], %[[VAL_9]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
65+
66+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
67+
// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
68+
// CHECK-DAG: %[[C79:.*]] = arith.constant 79 : index
69+
// CHECK: %[[VAL6:.*]] = arith.addi %[[VAL_1]], %[[VAL_2]] : index
70+
// CHECK: %[[VAL7:.*]] = arith.addi %[[VAL_3]], %[[VAL_4]] : index
71+
72+
// CHECK: %[[VAL_20:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL6]], %[[C79]], %[[VAL7]]], %[[CST]] {in_bounds = [true, true]} : tensor<45x80x16xf32>, vector<1x4xf32>
73+
// CHECK: %[[VAL_21:.*]] = vector.transfer_write %[[VAL_20]], %[[VAL_5]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
7974
// CHECK: return %[[VAL_21]] : tensor<1x4xf32>
8075
// CHECK: }
8176

@@ -101,14 +96,12 @@ func.func @vectorize_nd_tensor_extract_with_affine_apply_contiguous(%6: tensor<8
10196
// CHECK-SAME: %[[VAL_0:.*]]: tensor<80x16xf32>,
10297
// CHECK-SAME: %[[VAL_1:.*]]: index,
10398
// CHECK-SAME: %[[VAL_2:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
104-
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
99+
105100
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0.000000e+00 : f32
106101
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index
107102
// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 79 : index
108-
// CHECK: %[[VAL_8:.*]] = vector.broadcast %[[VAL_1]] : index to vector<4xindex>
109-
// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_3]] : vector<4xindex>
110-
// CHECK: %[[VAL_10:.*]] = vector.extract %[[VAL_9]][0] : index from vector<4xindex>
111-
// CHECK: %[[VAL_11:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_7]], %[[VAL_10]]], %[[VAL_5]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
103+
104+
// CHECK: %[[VAL_11:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_7]], %[[VAL_1]]], %[[VAL_5]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
112105
// CHECK: %[[VAL_12:.*]] = vector.transfer_write %[[VAL_11]], %[[VAL_2]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
113106
// CHECK: return %[[VAL_12]] : tensor<1x4xf32>
114107
// CHECK: }

mlir/test/Dialect/Vector/propagate-extracts.mlir

Lines changed: 0 additions & 47 deletions
This file was deleted.
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
2+
3+
// This is a smoke test for `transform.apply_patterns.vector.sink_ops`; the
4+
// actual patterns are tested in `vector-sink.mlir`.
5+
// Transform script run by the interpreter: it matches every `func.func` in the
// payload module and applies the `vector.sink_ops` pattern set to each match.
module attributes {transform.with_named_sequence} {
6+
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
7+
%func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
8+
transform.apply_patterns to %func {
9+
transform.apply_patterns.vector.sink_ops
10+
} : !transform.any_op
11+
transform.yield
12+
}
13+
}
14+
15+
16+
// CHECK-LABEL: @extract_elementwise
17+
// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
18+
// Payload for the smoke test: a scalar is extracted from the result of a
// single-use `arith.addf` on vectors. The FileCheck expectations below require
// the extract to be applied to both operands and the addf to operate on f32.
func.func @extract_elementwise(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
19+
// CHECK: %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : f32 from vector<4xf32>
20+
// CHECK: %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : f32 from vector<4xf32>
21+
// CHECK: %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : f32
22+
// CHECK: return %[[RES]] : f32
23+
%0 = arith.addf %arg0, %arg1 : vector<4xf32>
24+
%1 = vector.extract %0[1] : f32 from vector<4xf32>
25+
return %1 : f32
26+
}

mlir/test/Dialect/Vector/vector-sink.mlir

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,3 +423,41 @@ func.func @transpose_elementwise_diff_map_scalable(%a : vector<[4]x6x3x2xf32>, %
423423
%r = arith.addf %at, %bt : vector<6x[4]x2x3xf32>
424424
return %r : vector<6x[4]x2x3xf32>
425425
}
426+
427+
// -----
428+
429+
// CHECK-LABEL: @extract_elementwise
430+
// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
431+
// Positive case, scalar extract: the `arith.addf` result is used only by the
// `vector.extract`, so the expectations below require one extract per operand
// followed by a scalar addf.
func.func @extract_elementwise(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
432+
// CHECK: %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : f32 from vector<4xf32>
433+
// CHECK: %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : f32 from vector<4xf32>
434+
// CHECK: %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : f32
435+
// CHECK: return %[[RES]] : f32
436+
%0 = arith.addf %arg0, %arg1 : vector<4xf32>
437+
%1 = vector.extract %0[1] : f32 from vector<4xf32>
438+
return %1 : f32
439+
}
440+
441+
// CHECK-LABEL: @extract_vec_elementwise
442+
// CHECK-SAME: (%[[ARG0:.*]]: vector<2x4xf32>, %[[ARG1:.*]]: vector<2x4xf32>)
443+
// Positive case, sub-vector extract: a `vector<4xf32>` slice is extracted from
// a 2-D elementwise result, so the expectations below require the addf to be
// performed on the extracted `vector<4xf32>` slices of both operands.
func.func @extract_vec_elementwise(%arg0: vector<2x4xf32>, %arg1: vector<2x4xf32>) -> vector<4xf32> {
444+
// CHECK: %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : vector<4xf32> from vector<2x4xf32>
445+
// CHECK: %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : vector<4xf32> from vector<2x4xf32>
446+
// CHECK: %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : vector<4xf32>
447+
// CHECK: return %[[RES]] : vector<4xf32>
448+
%0 = arith.addf %arg0, %arg1 : vector<2x4xf32>
449+
%1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
450+
return %1 : vector<4xf32>
451+
}
452+
453+
// CHECK-LABEL: @extract_elementwise_use
454+
// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
455+
// Negative case: the `arith.addf` result is returned as well as extracted
// from, so it has a second use and the IR is expected to stay unchanged.
func.func @extract_elementwise_use(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> (f32, vector<4xf32>) {
456+
// Do not propagate extract, as elementwise has other uses.
457+
// CHECK: %[[ELT:.*]] = arith.addf %[[ARG0]], %[[ARG1]] : vector<4xf32>
458+
// CHECK: %[[EXT:.*]] = vector.extract %[[ELT]][1] : f32 from vector<4xf32>
459+
// CHECK: return %[[EXT]], %[[ELT]] : f32, vector<4xf32>
460+
%0 = arith.addf %arg0, %arg1 : vector<4xf32>
461+
%1 = vector.extract %0[1] : f32 from vector<4xf32>
462+
return %1, %0 : f32, vector<4xf32>
463+
}

0 commit comments

Comments
 (0)