
Commit 6d02f62

[mlir][linalg] Add vectorization support for minnumf/maxnumf reductions. (#101092)
This is a follow-up to https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671. The ops were split into two variants, and vectorization support for one of them was missing. The revision also renames the existing lit tests accordingly, putting `maximumf`/`minimumf` explicitly in the function names.
1 parent 9b14831 commit 6d02f62
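
For context on the split (an illustrative sketch, not part of this commit): the two variants differ only in NaN handling. `arith.maximumf`/`arith.minimumf` propagate NaN, while `arith.maxnumf`/`arith.minnumf` return the non-NaN operand when exactly one input is NaN:

// Illustrative only: shows why the ops were split into two variants.
func.func @nan_semantics(%a: f32) -> (f32, f32) {
  %nan = arith.constant 0x7FC00000 : f32  // bit pattern of a quiet f32 NaN
  %p = arith.maximumf %a, %nan : f32      // NaN-propagating: result is NaN
  %q = arith.maxnumf %a, %nan : f32       // number-preferring: result is %a
  return %p, %q : f32, f32
}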

2 files changed: 71 additions (+), 4 deletions (-)

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (2 additions, 0 deletions)

@@ -522,9 +522,11 @@ mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
       .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
       .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
       .Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; })
+      .Case<arith::MaxNumFOp>([&](auto op) { return CombiningKind::MAXNUMF; })
       .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
       .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
       .Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; })
+      .Case<arith::MinNumFOp>([&](auto op) { return CombiningKind::MINNUMF; })
       .Case<arith::MulIOp, arith::MulFOp>(
           [&](auto op) { return CombiningKind::MUL; })
       .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
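
The practical effect of the two new `Case` entries, sketched below (a minimal example matching the CHECK lines in the updated test; the function name is hypothetical): reduction bodies that combine with `arith.maxnumf`/`arith.minnumf` now map to a `CombiningKind`, so the vectorizer can emit `vector.multi_reduction` with the corresponding kind:

// Minimal sketch of what the vectorizer produces for a maxnumf reduction:
// dimension 1 is reduced, with the broadcast fill value as the accumulator.
func.func @maxnumf_reduce_sketch(%input: vector<4x4xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
  %res = vector.multi_reduction <maxnumf>, %input, %acc [1]
           : vector<4x4xf32> to vector<4xf32>
  return %res : vector<4xf32>
}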

mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir (69 additions, 4 deletions)

@@ -1240,8 +1240,8 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-// CHECK-LABEL: func @red_max_2d(
-func.func @red_max_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
+// CHECK-LABEL: func @red_maximumf_2d(
+func.func @red_maximumf_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
   // CHECK: %[[CMINF:.+]] = arith.constant dense<-3.402820e+38> : vector<4xf32>
   // CHECK: tensor.empty() : tensor<4xf32>
   // CHECK: vector.multi_reduction <maximumf>, {{.*}}, %[[CMINF]] [1] : vector<4x4xf32> to vector<4xf32>
@@ -1272,8 +1272,40 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-// CHECK-LABEL: func @red_min_2d(
-func.func @red_min_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
+// CHECK-LABEL: func @red_maxnumf_2d(
+func.func @red_maxnumf_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
+  // CHECK: %[[CMINF:.+]] = arith.constant dense<-3.402820e+38> : vector<4xf32>
+  // CHECK: tensor.empty() : tensor<4xf32>
+  // CHECK: vector.multi_reduction <maxnumf>, {{.*}}, %[[CMINF]] [1] : vector<4x4xf32> to vector<4xf32>
+  // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
+  %ident = arith.constant -3.40282e+38 : f32
+  %init = tensor.empty() : tensor<4xf32>
+  %fill = linalg.fill ins(%ident : f32) outs(%init : tensor<4xf32>) -> tensor<4xf32>
+  %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                          affine_map<(d0, d1) -> (d0)>],
+                         iterator_types = ["parallel", "reduction"]}
+                         ins(%arg0 : tensor<4x4xf32>) outs(%fill : tensor<4xf32>) {
+  ^bb0(%in0: f32, %out0: f32):
+    %max = arith.maxnumf %in0, %out0 : f32
+    linalg.yield %max : f32
+  } -> tensor<4xf32>
+  return %red : tensor<4xf32>
+}
+
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %4 = transform.get_parent_op %3 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %5 = transform.structured.vectorize_children_and_apply_patterns %4 { vectorize_padding } : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func @red_minimumf_2d(
+func.func @red_minimumf_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
   // CHECK: %[[CMAXF:.+]] = arith.constant dense<3.402820e+38> : vector<4xf32>
   // CHECK: tensor.empty() : tensor<4xf32>
   // CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32>
@@ -1294,6 +1326,39 @@ func.func @red_min_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
 }
 
 
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %4 = transform.get_parent_op %3 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %5 = transform.structured.vectorize_children_and_apply_patterns %4 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func @red_minnumf_2d(
+func.func @red_minnumf_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
+  // CHECK: %[[CMAXF:.+]] = arith.constant dense<3.402820e+38> : vector<4xf32>
+  // CHECK: tensor.empty() : tensor<4xf32>
+  // CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32>
+  // CHECK: vector.multi_reduction <minnumf>, {{.*}}, %[[CMAXF]] [1] : vector<4x4xf32> to vector<4xf32>
+  // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
+  %maxf32 = arith.constant 3.40282e+38 : f32
+  %init = tensor.empty() : tensor<4xf32>
+  %fill = linalg.fill ins(%maxf32 : f32) outs(%init : tensor<4xf32>) -> tensor<4xf32>
+  %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                          affine_map<(d0, d1) -> (d0)>],
+                         iterator_types = ["parallel", "reduction"]}
+                         ins(%arg0 : tensor<4x4xf32>) outs(%fill : tensor<4xf32>) {
+  ^bb0(%in0: f32, %out0: f32):
+    %min = arith.minnumf %out0, %in0 : f32
+    linalg.yield %min : f32
+  } -> tensor<4xf32>
+  return %red : tensor<4xf32>
+}
+
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
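
A note on structure: each `// -----`-separated snippet above carries its own `transform.with_named_sequence` module, so the transform interpreter vectorizes one `linalg.generic` at a time. Lit tests of this shape are conventionally driven by a RUN line like the following (the exact flags are an assumption, not quoted from this commit):

// RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s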
