Skip to content

Commit 1f1c284

Browse files
committed
style fixes
1 parent 2a247bf commit 1f1c284

File tree

3 files changed

+20
-10
lines changed

3 files changed

+20
-10
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1047,6 +1047,16 @@ struct ReorderElementwiseOpsOnBroadcast final
10471047
/// This may result in more efficient code when we extracting a single value
10481048
/// from multi-element vector and also to help canonicalize 1-element vectors to
10491049
/// scalars.
1050+
/// ```
1051+
/// %0 = arith.addf %arg0, %arg1 : vector<4xf32>
1052+
/// %1 = vector.extract %0[1] : f32 from vector<4xf32>
1053+
/// ```
1054+
/// Gets converted to:
1055+
/// ```
1056+
/// %0 = vector.extract %arg0[1] : f32 from vector<4xf32>
1057+
/// %1 = vector.extract %arg1[1] : f32 from vector<4xf32>
1058+
/// %2 = arith.addf %0, %1 : f32
1059+
/// ```
10501060
class ExtractOpFromElementwise final
10511061
: public OpRewritePattern<vector::ExtractOp> {
10521062
public:
@@ -1061,7 +1071,7 @@ class ExtractOpFromElementwise final
10611071
eltwise->getNumResults() != 1 || !eltwise->hasOneUse())
10621072
return failure();
10631073

1064-
// Arguments and result types must match.
1074+
// Arguments types must match.
10651075
if (!llvm::all_equal(eltwise->getOperandTypes()))
10661076
return failure();
10671077

@@ -1072,7 +1082,7 @@ class ExtractOpFromElementwise final
10721082

10731083
IRMapping mapping;
10741084
Location loc = eltwise->getLoc();
1075-
for (auto &&[i, arg] : llvm::enumerate(eltwise->getOperands())) {
1085+
for (auto arg : eltwise->getOperands()) {
10761086
Value newArg =
10771087
rewriter.create<vector::ExtractOp>(loc, arg, op.getMixedPosition());
10781088
mapping.map(arg, newArg);

mlir/test/Dialect/Vector/vector-sink-transform.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ module attributes {transform.with_named_sequence} {
1313
}
1414

1515

16-
// CHECK-LABEL: @extract_elementwise
16+
// CHECK-LABEL: @extract_elementwise_scalar
1717
// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
18-
func.func @extract_elementwise(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
18+
func.func @extract_elementwise_scalar(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
1919
// CHECK: %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : f32 from vector<4xf32>
2020
// CHECK: %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : f32 from vector<4xf32>
2121
// CHECK: %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : f32

mlir/test/Dialect/Vector/vector-sink.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -426,9 +426,9 @@ func.func @transpose_elementwise_diff_map_scalable(%a : vector<[4]x6x3x2xf32>, %
426426

427427
// -----
428428

429-
// CHECK-LABEL: @extract_elementwise
429+
// CHECK-LABEL: @extract_elementwise_scalar
430430
// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
431-
func.func @extract_elementwise(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
431+
func.func @extract_elementwise_scalar(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
432432
// CHECK: %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : f32 from vector<4xf32>
433433
// CHECK: %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : f32 from vector<4xf32>
434434
// CHECK: %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : f32
@@ -438,9 +438,9 @@ func.func @extract_elementwise(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f3
438438
return %1 : f32
439439
}
440440

441-
// CHECK-LABEL: @extract_vec_elementwise
441+
// CHECK-LABEL: @extract_elementwise_vec
442442
// CHECK-SAME: (%[[ARG0:.*]]: vector<2x4xf32>, %[[ARG1:.*]]: vector<2x4xf32>)
443-
func.func @extract_vec_elementwise(%arg0: vector<2x4xf32>, %arg1: vector<2x4xf32>) -> vector<4xf32> {
443+
func.func @extract_elementwise_vec(%arg0: vector<2x4xf32>, %arg1: vector<2x4xf32>) -> vector<4xf32> {
444444
// CHECK: %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : vector<4xf32> from vector<2x4xf32>
445445
// CHECK: %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : vector<4xf32> from vector<2x4xf32>
446446
// CHECK: %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : vector<4xf32>
@@ -450,9 +450,9 @@ func.func @extract_vec_elementwise(%arg0: vector<2x4xf32>, %arg1: vector<2x4xf32
450450
return %1 : vector<4xf32>
451451
}
452452

453-
// CHECK-LABEL: @extract_elementwise_use
453+
// CHECK-LABEL: @extract_elementwise_no_single_use
454454
// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
455-
func.func @extract_elementwise_use(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> (f32, vector<4xf32>) {
455+
func.func @extract_elementwise_no_single_use(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> (f32, vector<4xf32>) {
456456
// Do not propagate extract, as elementwise has other uses.
457457
// CHECK: %[[ELT:.*]] = arith.addf %[[ARG0]], %[[ARG1]] : vector<4xf32>
458458
// CHECK: %[[EXT:.*]] = vector.extract %[[ELT]][1] : f32 from vector<4xf32>

0 commit comments

Comments
 (0)