Skip to content

Commit 64ea88d

Browse files
committed
review comments
1 parent bf628cc commit 64ea88d

File tree

5 files changed

+44
-46
lines changed

5 files changed

+44
-46
lines changed

mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@ void transform::ApplyFoldElementwiseToVectorPatternsOp::populatePatterns(
6767
void transform::ApplyVectorReductionToContractPatternsOp::populatePatterns(
6868
RewritePatternSet &patterns) {
6969
vector::populateVectorReductionToContractPatterns(patterns);
70+
71+
// TODO: As we now have a dedicated transform for
72+
// `populateSinkVectorOpsPatterns` we can remove it from here.
7073
vector::populateSinkVectorOpsPatterns(patterns);
7174
}
7275

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1044,7 +1044,7 @@ struct ReorderElementwiseOpsOnBroadcast final
10441044
};
10451045

10461046
/// Pattern to rewrite a ExtractOp(Elementwise) -> Elementwise(ExtractOp).
1047-
/// This may result in more efficient code when we extracting a single value
1047+
/// This may result in cleaner code when we extracting a single value
10481048
/// from multi-element vector and also to help canonicalize 1-element vectors to
10491049
/// scalars.
10501050
/// ```
@@ -1066,14 +1066,17 @@ class ExtractOpFromElementwise final
10661066
PatternRewriter &rewriter) const override {
10671067
Operation *eltwise = op.getVector().getDefiningOp();
10681068

1069-
// Elementwise op with single result and `extract` is single user.
1070-
if (!eltwise || !OpTrait::hasElementwiseMappableTraits(eltwise) ||
1071-
eltwise->getNumResults() != 1 || !eltwise->hasOneUse())
1072-
return rewriter.notifyMatchFailure(op, "not a suitable op");
1069+
if (!eltwise || !OpTrait::hasElementwiseMappableTraits(eltwise))
1070+
return rewriter.notifyMatchFailure(op, "not an elementwise op");
1071+
1072+
if (eltwise->getNumResults() != 1)
1073+
return rewriter.notifyMatchFailure(op, "expected single result");
1074+
1075+
if (!eltwise->hasOneUse())
1076+
return rewriter.notifyMatchFailure(op, "expected single op use");
10731077

1074-
// Arguments types must match.
10751078
if (!llvm::all_equal(eltwise->getOperandTypes()))
1076-
return rewriter.notifyMatchFailure(op, "arg types are different");
1079+
return rewriter.notifyMatchFailure(op, "operand types are different");
10771080

10781081
Type dstType = op.getType();
10791082

mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -59,19 +59,19 @@ func.func @vectorize_nd_tensor_extract_transfer_read_complex(%6: tensor<45x80x16
5959

6060

6161
// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_read_complex(
62-
// CHECK-SAME: %[[VAL_0:.*]]: tensor<45x80x16xf32>,
63-
// CHECK-SAME: %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index, %[[VAL_4:.*]]: index,
64-
// CHECK-SAME: %[[VAL_5:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
62+
// CHECK-SAME: %[[ARG0:.*]]: tensor<45x80x16xf32>,
63+
// CHECK-SAME: %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index,
64+
// CHECK-SAME: %[[ARG5:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
6565

6666
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
6767
// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
6868
// CHECK-DAG: %[[C79:.*]] = arith.constant 79 : index
69-
// CHECK: %[[VAL6:.*]] = arith.addi %[[VAL_1]], %[[VAL_2]] : index
70-
// CHECK: %[[VAL7:.*]] = arith.addi %[[VAL_3]], %[[VAL_4]] : index
69+
// CHECK: %[[ADD1:.*]] = arith.addi %[[ARG1]], %[[ARG2]] : index
70+
// CHECK: %[[ADD2:.*]] = arith.addi %[[ARG3]], %[[ARG4]] : index
7171

72-
// CHECK: %[[VAL_20:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL6]], %[[C79]], %[[VAL7]]], %[[CST]] {in_bounds = [true, true]} : tensor<45x80x16xf32>, vector<1x4xf32>
73-
// CHECK: %[[VAL_21:.*]] = vector.transfer_write %[[VAL_20]], %[[VAL_5]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
74-
// CHECK: return %[[VAL_21]] : tensor<1x4xf32>
72+
// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]]{{\[}}%[[ADD1]], %[[C79]], %[[ADD2]]], %[[CST]] {in_bounds = [true, true]} : tensor<45x80x16xf32>, vector<1x4xf32>
73+
// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[ARG5]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
74+
// CHECK: return %[[WRITE]] : tensor<1x4xf32>
7575
// CHECK: }
7676

7777
// -----
@@ -93,17 +93,17 @@ func.func @vectorize_nd_tensor_extract_with_affine_apply_contiguous(%6: tensor<8
9393
}
9494

9595
// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_with_affine_apply_contiguous(
96-
// CHECK-SAME: %[[VAL_0:.*]]: tensor<80x16xf32>,
97-
// CHECK-SAME: %[[VAL_1:.*]]: index,
98-
// CHECK-SAME: %[[VAL_2:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
96+
// CHECK-SAME: %[[ARG0:.*]]: tensor<80x16xf32>,
97+
// CHECK-SAME: %[[ARG1:.*]]: index,
98+
// CHECK-SAME: %[[ARG2:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
9999

100-
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0.000000e+00 : f32
101-
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index
102-
// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 79 : index
100+
// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
101+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
102+
// CHECK-DAG: %[[C79:.*]] = arith.constant 79 : index
103103

104-
// CHECK: %[[VAL_11:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_7]], %[[VAL_1]]], %[[VAL_5]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
105-
// CHECK: %[[VAL_12:.*]] = vector.transfer_write %[[VAL_11]], %[[VAL_2]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
106-
// CHECK: return %[[VAL_12]] : tensor<1x4xf32>
104+
// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]]{{\[}}%[[C79]], %[[ARG1]]], %[[CST]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
105+
// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[ARG2]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
106+
// CHECK: return %[[WRITE]] : tensor<1x4xf32>
107107
// CHECK: }
108108

109109
// -----
Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
1+
// RUN: mlir-opt %s
22

3-
// This is smoke test for `transform.apply_patterns.vector.sink_ops` the actual
4-
// patterns are tested in `vector-sink.mlir`.
3+
// This is smoke test for `transform.apply_patterns.vector.sink_ops` and this
4+
// file is also used in `vector-sink.mlir`.
55
module attributes {transform.with_named_sequence} {
66
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
77
%func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
@@ -11,16 +11,3 @@ module attributes {transform.with_named_sequence} {
1111
transform.yield
1212
}
1313
}
14-
15-
16-
// CHECK-LABEL: @extract_elementwise_scalar
17-
// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
18-
func.func @extract_elementwise_scalar(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
19-
// CHECK: %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : f32 from vector<4xf32>
20-
// CHECK: %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : f32 from vector<4xf32>
21-
// CHECK: %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : f32
22-
// CHECK: return %[[RES]] : f32
23-
%0 = arith.addf %arg0, %arg1 : vector<4xf32>
24-
%1 = vector.extract %0[1] : f32 from vector<4xf32>
25-
return %1 : f32
26-
}

mlir/test/Dialect/Vector/vector-sink.mlir

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// RUN: mlir-opt %s -test-vector-sink-patterns -split-input-file | FileCheck %s
2+
// RUN: mlir-opt -transform-preload-library='transform-library-paths=%p/vector-sink-transform.mlir' -transform-interpreter -split-input-file %s | FileCheck %s
23

34
//-----------------------------------------------------------------------------
45
// [Pattern: ReorderElementwiseOpsOnBroadcast]
@@ -426,6 +427,10 @@ func.func @transpose_elementwise_diff_map_scalable(%a : vector<[4]x6x3x2xf32>, %
426427

427428
// -----
428429

430+
//-----------------------------------------------------------------------------
431+
// [Pattern: ExtractOpFromElementwise]
432+
//-----------------------------------------------------------------------------
433+
429434
// CHECK-LABEL: @extract_elementwise_scalar
430435
// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
431436
func.func @extract_elementwise_scalar(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
@@ -461,9 +466,9 @@ func.func @extract_elementwise_vec(%arg0: vector<2x4xf32>, %arg1: vector<2x4xf32
461466
return %1 : vector<4xf32>
462467
}
463468

464-
// CHECK-LABEL: @extract_elementwise_no_single_use
469+
// CHECK-LABEL: @negative_extract_elementwise_no_single_use
465470
// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
466-
func.func @extract_elementwise_no_single_use(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> (f32, vector<4xf32>) {
471+
func.func @negative_extract_elementwise_no_single_use(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> (f32, vector<4xf32>) {
467472
// Do not propagate extract, as elementwise has other uses.
468473
// CHECK: %[[ELT:.*]] = arith.addf %[[ARG0]], %[[ARG1]] : vector<4xf32>
469474
// CHECK: %[[EXT:.*]] = vector.extract %[[ELT]][1] : f32 from vector<4xf32>
@@ -473,9 +478,9 @@ func.func @extract_elementwise_no_single_use(%arg0: vector<4xf32>, %arg1: vector
473478
return %1, %0 : f32, vector<4xf32>
474479
}
475480

476-
// CHECK-LABEL: @extract_elementwise_not_one_res
481+
// CHECK-LABEL: @negative_extract_elementwise_not_one_res
477482
// CHECK-SAME: (%[[ARG0:.*]]: vector<4xi32>, %[[ARG1:.*]]: vector<4xi32>)
478-
func.func @extract_elementwise_not_one_res(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> i32 {
483+
func.func @negative_extract_elementwise_not_one_res(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> i32 {
479484
// Do not propagate extract, as elementwise has more than 1 result.
480485
// CHECK: %[[LOW:.*]], %[[HIGH:.*]] = arith.mulsi_extended %[[ARG0]], %[[ARG1]] : vector<4xi32>
481486
// CHECK: %[[EXT:.*]] = vector.extract %[[LOW]][1] : i32 from vector<4xi32>
@@ -485,9 +490,9 @@ func.func @extract_elementwise_not_one_res(%arg0: vector<4xi32>, %arg1: vector<4
485490
return %1 : i32
486491
}
487492

488-
// CHECK-LABEL: @extract_not_elementwise
493+
// CHECK-LABEL: @negative_extract_not_elementwise
489494
// CHECK-SAME: (%[[ARG0:.*]]: vector<4xi64>)
490-
func.func @extract_not_elementwise(%arg0: vector<4xi64>) -> i64 {
495+
func.func @negative_extract_not_elementwise(%arg0: vector<4xi64>) -> i64 {
491496
// `test.increment` is not an elemewise op.
492497
// CHECK: %[[INC:.*]] = test.increment %[[ARG0]] : vector<4xi64>
493498
// CHECK: %[[RES:.*]] = vector.extract %[[INC]][1] : i64 from vector<4xi64>

0 commit comments

Comments
 (0)