Skip to content

Commit 9b02222

Browse files
authored
[mlir][vector] Propagate vector.extract through elementwise ops (#131462)
Propagate `Extract(Elementwise(...))` -> `Elemetwise(Extract...)`. Currenly limited to the case when extract is the single use of elementwise to avoid introducing additional elementwise ops.
1 parent 2682a94 commit 9b02222

File tree

6 files changed

+220
-33
lines changed

6 files changed

+220
-33
lines changed

mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,4 +453,27 @@ def ApplyVectorReductionToContractPatternsOp : Op<Transform_Dialect,
453453
let assemblyFormat = "attr-dict";
454454
}
455455

456+
def ApplySinkVectorPatternsOp : Op<Transform_Dialect,
457+
"apply_patterns.vector.sink_ops",
458+
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
459+
let description = [{
460+
Patterns that remove redundant Vector Ops by re-ordering them with
461+
e.g. elementwise Ops:
462+
```
463+
%at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
464+
%bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
465+
%r = arith.addf %at, %bt : vector<2x4xf32>
466+
```
467+
gets converted to:
468+
```
469+
%0 = arith.addf %a, %b : vector<4x2xf32>
470+
%r = vector.transpose %0, [1, 0] : vector<2x4xf32>
471+
```
472+
At the moment, these patterns are limited to vector.broadcast and
473+
vector.transpose.
474+
}];
475+
476+
let assemblyFormat = "attr-dict";
477+
}
478+
456479
#endif // VECTOR_TRANSFORM_OPS

mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@ void transform::ApplyFoldElementwiseToVectorPatternsOp::populatePatterns(
6767
void transform::ApplyVectorReductionToContractPatternsOp::populatePatterns(
6868
RewritePatternSet &patterns) {
6969
vector::populateVectorReductionToContractPatterns(patterns);
70+
71+
// TODO: As we now have a dedicated transform for
72+
// `populateSinkVectorOpsPatterns` we can remove it from here.
7073
vector::populateSinkVectorOpsPatterns(patterns);
7174
}
7275

@@ -204,6 +207,11 @@ void transform::ApplyTransferToScfPatternsOp::populatePatterns(
204207
populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions);
205208
}
206209

210+
void transform::ApplySinkVectorPatternsOp::populatePatterns(
211+
RewritePatternSet &patterns) {
212+
vector::populateSinkVectorOpsPatterns(patterns);
213+
}
214+
207215
//===----------------------------------------------------------------------===//
208216
// Transform op registration
209217
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1043,6 +1043,66 @@ struct ReorderElementwiseOpsOnBroadcast final
10431043
}
10441044
};
10451045

1046+
/// Pattern to rewrite a ExtractOp(Elementwise) -> Elementwise(ExtractOp).
1047+
/// This may result in cleaner code when extracting a single value
1048+
/// from multi-element vector and also to help canonicalize 1-element vectors to
1049+
/// scalars.
1050+
/// ```
1051+
/// %0 = arith.addf %arg0, %arg1 : vector<4xf32>
1052+
/// %1 = vector.extract %0[1] : f32 from vector<4xf32>
1053+
/// ```
1054+
/// Gets converted to:
1055+
/// ```
1056+
/// %0 = vector.extract %arg0[1] : f32 from vector<4xf32>
1057+
/// %1 = vector.extract %arg1[1] : f32 from vector<4xf32>
1058+
/// %2 = arith.addf %0, %1 : f32
1059+
/// ```
1060+
class ExtractOpFromElementwise final
1061+
: public OpRewritePattern<vector::ExtractOp> {
1062+
public:
1063+
using OpRewritePattern::OpRewritePattern;
1064+
1065+
LogicalResult matchAndRewrite(vector::ExtractOp op,
1066+
PatternRewriter &rewriter) const override {
1067+
Operation *eltwise = op.getVector().getDefiningOp();
1068+
1069+
// TODO: vector::FMAOp is not an ElemetwiseMappable even if it claims to be,
1070+
// as it doesn't support scalars.
1071+
if (!eltwise || !OpTrait::hasElementwiseMappableTraits(eltwise) ||
1072+
isa<vector::FMAOp>(eltwise))
1073+
return rewriter.notifyMatchFailure(op, "not an elementwise op");
1074+
1075+
if (eltwise->getNumResults() != 1)
1076+
return rewriter.notifyMatchFailure(op, "expected single result");
1077+
1078+
if (!eltwise->hasOneUse())
1079+
return rewriter.notifyMatchFailure(op, "expected single op use");
1080+
1081+
if (!llvm::all_equal(eltwise->getOperandTypes()))
1082+
return rewriter.notifyMatchFailure(op, "operand types are different");
1083+
1084+
Type dstType = op.getType();
1085+
1086+
OpBuilder::InsertionGuard g(rewriter);
1087+
rewriter.setInsertionPoint(eltwise);
1088+
1089+
IRMapping mapping;
1090+
Location loc = eltwise->getLoc();
1091+
SmallVector<OpFoldResult> pos = op.getMixedPosition();
1092+
for (Value arg : eltwise->getOperands()) {
1093+
Value newArg = rewriter.create<vector::ExtractOp>(loc, arg, pos);
1094+
mapping.map(arg, newArg);
1095+
}
1096+
1097+
Operation *newEltwise = rewriter.clone(*eltwise, mapping);
1098+
newEltwise->getResult(0).setType(dstType);
1099+
1100+
rewriter.replaceOp(op, newEltwise);
1101+
rewriter.eraseOp(eltwise);
1102+
return success();
1103+
}
1104+
};
1105+
10461106
// Helper that returns a vector comparison that constructs a mask:
10471107
// mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
10481108
//
@@ -2111,8 +2171,8 @@ void mlir::vector::
21112171
void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
21122172
PatternBenefit benefit) {
21132173
patterns.add<ReorderElementwiseOpsOnTranspose, ReorderCastOpsOnBroadcast,
2114-
ReorderElementwiseOpsOnBroadcast>(patterns.getContext(),
2115-
benefit);
2174+
ReorderElementwiseOpsOnBroadcast, ExtractOpFromElementwise>(
2175+
patterns.getContext(), benefit);
21162176
}
21172177

21182178
void mlir::vector::populateChainedVectorReductionFoldingPatterns(

mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir

Lines changed: 24 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -59,24 +59,19 @@ func.func @vectorize_nd_tensor_extract_transfer_read_complex(%6: tensor<45x80x16
5959

6060

6161
// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_read_complex(
62-
// CHECK-SAME: %[[VAL_0:.*]]: tensor<45x80x16xf32>,
63-
// CHECK-SAME: %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index, %[[VAL_4:.*]]: index,
64-
// CHECK-SAME: %[[VAL_5:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
65-
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
66-
// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 0.000000e+00 : f32
67-
// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 0 : index
68-
// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 79 : index
69-
// CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_1]], %[[VAL_2]] : index
70-
// CHECK: %[[VAL_13:.*]] = vector.broadcast %[[VAL_3]] : index to vector<4xindex>
71-
// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_6]] : vector<4xindex>
72-
// CHECK: %[[VAL_15:.*]] = vector.broadcast %[[VAL_4]] : index to vector<4xindex>
73-
// CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_14]], %[[VAL_15]] : vector<4xindex>
74-
75-
// CHECK: %[[VAL_19:.*]] = vector.extract %[[VAL_16]][0] : index from vector<4xindex>
76-
77-
// CHECK: %[[VAL_20:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_11]], %[[VAL_10]], %[[VAL_19]]], %[[VAL_8]] {in_bounds = [true, true]} : tensor<45x80x16xf32>, vector<1x4xf32>
78-
// CHECK: %[[VAL_21:.*]] = vector.transfer_write %[[VAL_20]], %[[VAL_5]]{{\[}}%[[VAL_9]], %[[VAL_9]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
79-
// CHECK: return %[[VAL_21]] : tensor<1x4xf32>
62+
// CHECK-SAME: %[[ARG0:.*]]: tensor<45x80x16xf32>,
63+
// CHECK-SAME: %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index,
64+
// CHECK-SAME: %[[ARG5:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
65+
66+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
67+
// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
68+
// CHECK-DAG: %[[C79:.*]] = arith.constant 79 : index
69+
// CHECK: %[[ADD1:.*]] = arith.addi %[[ARG1]], %[[ARG2]] : index
70+
// CHECK: %[[ADD2:.*]] = arith.addi %[[ARG3]], %[[ARG4]] : index
71+
72+
// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]]{{\[}}%[[ADD1]], %[[C79]], %[[ADD2]]], %[[CST]] {in_bounds = [true, true]} : tensor<45x80x16xf32>, vector<1x4xf32>
73+
// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[ARG5]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
74+
// CHECK: return %[[WRITE]] : tensor<1x4xf32>
8075
// CHECK: }
8176

8277
// -----
@@ -98,19 +93,17 @@ func.func @vectorize_nd_tensor_extract_with_affine_apply_contiguous(%6: tensor<8
9893
}
9994

10095
// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_with_affine_apply_contiguous(
101-
// CHECK-SAME: %[[VAL_0:.*]]: tensor<80x16xf32>,
102-
// CHECK-SAME: %[[VAL_1:.*]]: index,
103-
// CHECK-SAME: %[[VAL_2:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
104-
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
105-
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0.000000e+00 : f32
106-
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index
107-
// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 79 : index
108-
// CHECK: %[[VAL_8:.*]] = vector.broadcast %[[VAL_1]] : index to vector<4xindex>
109-
// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_3]] : vector<4xindex>
110-
// CHECK: %[[VAL_10:.*]] = vector.extract %[[VAL_9]][0] : index from vector<4xindex>
111-
// CHECK: %[[VAL_11:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_7]], %[[VAL_10]]], %[[VAL_5]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
112-
// CHECK: %[[VAL_12:.*]] = vector.transfer_write %[[VAL_11]], %[[VAL_2]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
113-
// CHECK: return %[[VAL_12]] : tensor<1x4xf32>
96+
// CHECK-SAME: %[[ARG0:.*]]: tensor<80x16xf32>,
97+
// CHECK-SAME: %[[ARG1:.*]]: index,
98+
// CHECK-SAME: %[[ARG2:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
99+
100+
// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
101+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
102+
// CHECK-DAG: %[[C79:.*]] = arith.constant 79 : index
103+
104+
// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]]{{\[}}%[[C79]], %[[ARG1]]], %[[CST]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
105+
// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[ARG2]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
106+
// CHECK: return %[[WRITE]] : tensor<1x4xf32>
114107
// CHECK: }
115108

116109
// -----
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// RUN: mlir-opt %s
2+
3+
// This is smoke test for `transform.apply_patterns.vector.sink_ops` and this
4+
// file is also used in `vector-sink.mlir`.
5+
module attributes {transform.with_named_sequence} {
6+
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
7+
%func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
8+
transform.apply_patterns to %func {
9+
transform.apply_patterns.vector.sink_ops
10+
} : !transform.any_op
11+
transform.yield
12+
}
13+
}

mlir/test/Dialect/Vector/vector-sink.mlir

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// RUN: mlir-opt %s -test-vector-sink-patterns -split-input-file | FileCheck %s
2+
// RUN: mlir-opt -transform-preload-library='transform-library-paths=%p/vector-sink-transform.mlir' -transform-interpreter -split-input-file %s | FileCheck %s
23

34
//-----------------------------------------------------------------------------
45
// [Pattern: ReorderElementwiseOpsOnBroadcast]
@@ -423,3 +424,92 @@ func.func @transpose_elementwise_diff_map_scalable(%a : vector<[4]x6x3x2xf32>, %
423424
%r = arith.addf %at, %bt : vector<6x[4]x2x3xf32>
424425
return %r : vector<6x[4]x2x3xf32>
425426
}
427+
428+
// -----
429+
430+
//-----------------------------------------------------------------------------
431+
// [Pattern: ExtractOpFromElementwise]
432+
//-----------------------------------------------------------------------------
433+
434+
// CHECK-LABEL: @extract_elementwise_scalar
435+
// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
436+
func.func @extract_elementwise_scalar(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
437+
// CHECK: %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : f32 from vector<4xf32>
438+
// CHECK: %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : f32 from vector<4xf32>
439+
// CHECK: %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : f32
440+
// CHECK: return %[[RES]] : f32
441+
%0 = arith.addf %arg0, %arg1 : vector<4xf32>
442+
%1 = vector.extract %0[1] : f32 from vector<4xf32>
443+
return %1 : f32
444+
}
445+
446+
// CHECK-LABEL: @extract_elementwise_arg_res_different_types
447+
// CHECK-SAME: (%[[ARG0:.*]]: vector<4xindex>)
448+
func.func @extract_elementwise_arg_res_different_types(%arg0: vector<4xindex>) -> i64 {
449+
// CHECK: %[[EXT:.*]] = vector.extract %[[ARG0]][1] : index from vector<4xindex>
450+
// CHECK: %[[RES:.*]] = arith.index_cast %[[EXT]] : index to i64
451+
// CHECK: return %[[RES]] : i64
452+
%0 = arith.index_cast %arg0: vector<4xindex> to vector<4xi64>
453+
%1 = vector.extract %0[1] : i64 from vector<4xi64>
454+
return %1 : i64
455+
}
456+
457+
// CHECK-LABEL: @extract_elementwise_vec
458+
// CHECK-SAME: (%[[ARG0:.*]]: vector<2x4xf32>, %[[ARG1:.*]]: vector<2x4xf32>)
459+
func.func @extract_elementwise_vec(%arg0: vector<2x4xf32>, %arg1: vector<2x4xf32>) -> vector<4xf32> {
460+
// CHECK: %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : vector<4xf32> from vector<2x4xf32>
461+
// CHECK: %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : vector<4xf32> from vector<2x4xf32>
462+
// CHECK: %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : vector<4xf32>
463+
// CHECK: return %[[RES]] : vector<4xf32>
464+
%0 = arith.addf %arg0, %arg1 : vector<2x4xf32>
465+
%1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
466+
return %1 : vector<4xf32>
467+
}
468+
469+
// CHECK-LABEL: @negative_extract_elementwise_no_single_use
470+
// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
471+
func.func @negative_extract_elementwise_no_single_use(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> (f32, vector<4xf32>) {
472+
// Do not propagate extract, as elementwise has other uses.
473+
// CHECK: %[[ELT:.*]] = arith.addf %[[ARG0]], %[[ARG1]] : vector<4xf32>
474+
// CHECK: %[[EXT:.*]] = vector.extract %[[ELT]][1] : f32 from vector<4xf32>
475+
// CHECK: return %[[EXT]], %[[ELT]] : f32, vector<4xf32>
476+
%0 = arith.addf %arg0, %arg1 : vector<4xf32>
477+
%1 = vector.extract %0[1] : f32 from vector<4xf32>
478+
return %1, %0 : f32, vector<4xf32>
479+
}
480+
481+
// CHECK-LABEL: @negative_extract_elementwise_not_one_res
482+
// CHECK-SAME: (%[[ARG0:.*]]: vector<4xi32>, %[[ARG1:.*]]: vector<4xi32>)
483+
func.func @negative_extract_elementwise_not_one_res(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> i32 {
484+
// Do not propagate extract, as elementwise has more than 1 result.
485+
// CHECK: %[[LOW:.*]], %[[HIGH:.*]] = arith.mulsi_extended %[[ARG0]], %[[ARG1]] : vector<4xi32>
486+
// CHECK: %[[EXT:.*]] = vector.extract %[[LOW]][1] : i32 from vector<4xi32>
487+
// CHECK: return %[[EXT]] : i32
488+
%low, %hi = arith.mulsi_extended %arg0, %arg1 : vector<4xi32>
489+
%1 = vector.extract %low[1] : i32 from vector<4xi32>
490+
return %1 : i32
491+
}
492+
493+
// CHECK-LABEL: @negative_extract_not_elementwise
494+
// CHECK-SAME: (%[[ARG0:.*]]: vector<4xi64>)
495+
func.func @negative_extract_not_elementwise(%arg0: vector<4xi64>) -> i64 {
496+
// `test.increment` is not an elemewise op.
497+
// CHECK: %[[INC:.*]] = test.increment %[[ARG0]] : vector<4xi64>
498+
// CHECK: %[[RES:.*]] = vector.extract %[[INC]][1] : i64 from vector<4xi64>
499+
// CHECK: return %[[RES]] : i64
500+
%0 = test.increment %arg0: vector<4xi64>
501+
%1 = vector.extract %0[1] : i64 from vector<4xi64>
502+
return %1 : i64
503+
}
504+
505+
// CHECK-LABEL: @negative_extract_vec_fma
506+
// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>, %[[ARG2:.*]]: vector<4xf32>)
507+
func.func @negative_extract_vec_fma(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> f32 {
508+
// `vector.fma` doesn't suppport scalars.
509+
// CHECK: %[[FMA:.*]] = vector.fma %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<4xf32>
510+
// CHECK: %[[RES:.*]] = vector.extract %[[FMA]][1] : f32 from vector<4xf32>
511+
// CHECK: return %[[RES]] : f32
512+
%0 = vector.fma %arg0, %arg1, %arg2: vector<4xf32>
513+
%1 = vector.extract %0[1] : f32 from vector<4xf32>
514+
return %1 : f32
515+
}

0 commit comments

Comments
 (0)