Skip to content

[mlir][linalg] Add vectorization support for minnumf/maxnumf reductions. #101092

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

hanhanW
Copy link
Contributor

@hanhanW hanhanW commented Jul 29, 2024

This is a follow-up for https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671

The ops were splitted to two version, and the vectorization support for one of them is missing.

The revision also renames the existing lit tests accordingly, which explicitly puts maximumf/minimumf to the function names.

This is a follow-up for https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671

The ops were splitted to two version, and the vectorization support for
one of them is missing. It renames the existing lit tests accordingly,
which explicitly puts maximumf/minimumf to the function names.
@hanhanW hanhanW changed the title [mlir][linalg] Add vectorization support for minnumf/maxnumf reduction. [mlir][linalg] Add vectorization support for minnumf/maxnumf reductions. Jul 29, 2024
@llvmbot
Copy link
Member

llvmbot commented Jul 29, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Han-Chung Wang (hanhanW)

Changes

This is a follow-up for https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671

The ops were splitted to two version, and the vectorization support for one of them is missing.

The revision also renames the existing lit tests accordingly, which explicitly puts maximumf/minimumf to the function names.


Full diff: https://github.com/llvm/llvm-project/pull/101092.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+2)
  • (modified) mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir (+69-4)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index a4c0508d0d8fa..56f49c5a57a10 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -522,9 +522,11 @@ mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
       .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
       .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
       .Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; })
+      .Case<arith::MaxNumFOp>([&](auto op) { return CombiningKind::MAXNUMF; })
       .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
       .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
       .Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; })
+      .Case<arith::MinNumFOp>([&](auto op) { return CombiningKind::MINNUMF; })
       .Case<arith::MulIOp, arith::MulFOp>(
           [&](auto op) { return CombiningKind::MUL; })
       .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
diff --git a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
index d7ff1ded9d933..3404b73102e6a 100644
--- a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
@@ -1240,8 +1240,8 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-// CHECK-LABEL:   func @red_max_2d(
-func.func @red_max_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
+// CHECK-LABEL:   func @red_maximumf_2d(
+func.func @red_maximumf_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
   // CHECK: %[[CMINF:.+]] = arith.constant dense<-3.402820e+38> : vector<4xf32>
   // CHECK: tensor.empty() : tensor<4xf32>
   // CHECK: vector.multi_reduction <maximumf>, {{.*}}, %[[CMINF]] [1] : vector<4x4xf32> to vector<4xf32>
@@ -1272,8 +1272,40 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-// CHECK-LABEL:   func @red_min_2d(
-func.func @red_min_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
+// CHECK-LABEL:   func @red_maxnumf_2d(
+func.func @red_maxnumf_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
+  // CHECK: %[[CMINF:.+]] = arith.constant dense<-3.402820e+38> : vector<4xf32>
+  // CHECK: tensor.empty() : tensor<4xf32>
+  // CHECK: vector.multi_reduction <maxnumf>, {{.*}}, %[[CMINF]] [1] : vector<4x4xf32> to vector<4xf32>
+  // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
+  %ident = arith.constant -3.40282e+38 : f32
+  %init = tensor.empty() : tensor<4xf32>
+  %fill = linalg.fill ins(%ident : f32) outs(%init : tensor<4xf32>) -> tensor<4xf32>
+  %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                          affine_map<(d0, d1) -> (d0)>],
+                         iterator_types = ["parallel", "reduction"]}
+                         ins(%arg0 : tensor<4x4xf32>) outs(%fill : tensor<4xf32>) {
+  ^bb0(%in0: f32, %out0: f32):
+    %max = arith.maxnumf %in0, %out0 : f32
+    linalg.yield %max : f32
+  } -> tensor<4xf32>
+  return %red : tensor<4xf32>
+}
+
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %4 = transform.get_parent_op %3 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %5 = transform.structured.vectorize_children_and_apply_patterns %4 { vectorize_padding } : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL:   func @red_minimumf_2d(
+func.func @red_minimumf_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
   // CHECK: %[[CMAXF:.+]] = arith.constant dense<3.402820e+38> : vector<4xf32>
   // CHECK: tensor.empty() : tensor<4xf32>
   // CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32>
@@ -1294,6 +1326,39 @@ func.func @red_min_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
 }
 
 
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %4 = transform.get_parent_op %3 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %5 = transform.structured.vectorize_children_and_apply_patterns %4 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL:   func @red_minnumf_2d(
+func.func @red_minnumf_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
+  // CHECK: %[[CMAXF:.+]] = arith.constant dense<3.402820e+38> : vector<4xf32>
+  // CHECK: tensor.empty() : tensor<4xf32>
+  // CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32>
+  // CHECK: vector.multi_reduction <minnumf>, {{.*}}, %[[CMAXF]] [1] : vector<4x4xf32> to vector<4xf32>
+  // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
+  %maxf32 = arith.constant 3.40282e+38 : f32
+  %init = tensor.empty() : tensor<4xf32>
+  %fill = linalg.fill ins(%maxf32 : f32) outs(%init : tensor<4xf32>) -> tensor<4xf32>
+  %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                          affine_map<(d0, d1) -> (d0)>],
+                         iterator_types = ["parallel", "reduction"]}
+                         ins(%arg0 : tensor<4x4xf32>) outs(%fill : tensor<4xf32>) {
+  ^bb0(%in0: f32, %out0: f32):
+    %min = arith.minnumf %out0, %in0 : f32
+    linalg.yield %min : f32
+  } -> tensor<4xf32>
+  return %red : tensor<4xf32>
+}
+
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op

@hanhanW hanhanW merged commit 6d02f62 into llvm:main Jul 30, 2024
10 checks passed
@hanhanW hanhanW deleted the vectorization-support-for-minnumf-maxnumf-reduction branch July 30, 2024 03:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants