Skip to content

[mlir][vector] Improve makeArithReduction expansion #75846

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 18, 2023

Conversation

kuhar
Copy link
Member

@kuhar kuhar commented Dec 18, 2023

Propagate fast math flags.
Distinguish minf/maxf and minimumf/maximumf.

Required for future patterns in #75727.

Propagate fast math flags.
Distinguish `minf`/`maxf` and `minimumf`/`maximumf`.

Required for future patterns in
llvm#75727.
@llvmbot
Copy link
Member

llvmbot commented Dec 18, 2023

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Jakub Kuderski (kuhar)

Changes

Propagate fast math flags.
Distinguish minf/maxf and minimumf/maximumf.

Required for future patterns in #75727.


Full diff: https://github.com/llvm/llvm-project/pull/75846.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.h (+4-2)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+18-7)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp (+2-1)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+2-2)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+12)
  • (modified) mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir (+8-8)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 59d585a77b1e29..a28b27e4e15816 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -123,10 +123,12 @@ bool isDisjointTransferSet(VectorTransferOpInterface transferA,
                            VectorTransferOpInterface transferB,
                            bool testDynamicValueUsingBounds = false);
 
-/// Return the result value of reducing two scalar/vector values with the
+/// Returns the result value of reducing two scalar/vector values with the
 /// corresponding arith operation.
 Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
-                         Value v1, Value acc, Value mask = Value());
+                         Value v1, Value acc,
+                         arith::FastMathFlagsAttr fastmath = nullptr,
+                         Value mask = nullptr);
 
 /// Returns true if `attr` has "parallel" iterator type semantics.
 inline bool isParallelIterator(Attribute attr) {
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 540959b486db9c..9f3e13c90a624d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -507,8 +507,9 @@ struct ElideUnitDimsInMultiDimReduction
                                                 zeroIdx);
     }
 
-    Value result = vector::makeArithReduction(
-        rewriter, loc, reductionOp.getKind(), acc, cast, mask);
+    Value result =
+        vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
+                                   cast, /*fastmath=*/nullptr, mask);
     rewriter.replaceOp(rootOp, result);
     return success();
   }
@@ -650,7 +651,8 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
 
     if (Value acc = reductionOp.getAcc())
       result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
-                                          result, acc, mask);
+                                          result, acc,
+                                          reductionOp.getFastmathAttr(), mask);
 
     rewriter.replaceOp(rootOp, result);
     return success();
@@ -6212,6 +6214,7 @@ bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
 
 Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
                                        CombiningKind kind, Value v1, Value acc,
+                                       arith::FastMathFlagsAttr fastmath,
                                        Value mask) {
   Type t1 = getElementTypeOrSelf(v1.getType());
   Type tAcc = getElementTypeOrSelf(acc.getType());
@@ -6222,7 +6225,7 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
     if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
       result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
     else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
-      result = b.createOrFold<arith::AddFOp>(loc, v1, acc);
+      result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
     else
       llvm_unreachable("invalid value types for ADD reduction");
     break;
@@ -6231,16 +6234,24 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
     result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
     break;
   case CombiningKind::MAXF:
+    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
+           "expected float values");
+    result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
+    break;
   case CombiningKind::MAXIMUMF:
     assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
            "expected float values");
-    result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc);
+    result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
     break;
   case CombiningKind::MINF:
+    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
+           "expected float values");
+    result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
+    break;
   case CombiningKind::MINIMUMF:
     assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
            "expected float values");
-    result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc);
+    result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
     break;
   case CombiningKind::MAXSI:
     assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
@@ -6262,7 +6273,7 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
     if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
       result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
     else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
-      result = b.createOrFold<arith::MulFOp>(loc, v1, acc);
+      result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
     else
       llvm_unreachable("invalid value types for MUL reduction");
     break;
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 6dbe36e605e9a7..41ff0c18fe6258 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -167,7 +167,8 @@ createContractArithOp(Location loc, Value x, Value y, Value acc,
   if (!acc)
     return std::optional<Value>(mul);
 
-  return makeArithReduction(rewriter, loc, kind, mul, acc, mask);
+  return makeArithReduction(rewriter, loc, kind, mul, acc,
+                            /*fastmath=*/nullptr, mask);
 }
 
 /// Return the positions of the reductions in the given map.
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 012d30d96799f2..7353d16d79cea0 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -450,7 +450,7 @@ func.func @masked_float_max_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: v
 // CHECK-LABEL:   func.func @masked_float_max_outerprod(
 // CHECK-SAME:                                          %[[VAL_0:.*]]: vector<2xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<2xf32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xf32> {
 // CHECK:           %[[VAL_8:.*]] = arith.mulf %[[VAL_0]], %{{.*}} : vector<2xf32>
-// CHECK:           %[[VAL_9:.*]] = arith.maximumf %[[VAL_8]], %[[VAL_2]] : vector<2xf32>
+// CHECK:           %[[VAL_9:.*]] = arith.maxnumf %[[VAL_8]], %[[VAL_2]] : vector<2xf32>
 // CHECK:           %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xf32>
 
 // -----
@@ -463,7 +463,7 @@ func.func @masked_float_min_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: v
 // CHECK-LABEL:   func.func @masked_float_min_outerprod(
 // CHECK-SAME:                                          %[[VAL_0:.*]]: vector<2xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<2xf32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xf32> {
 // CHECK:           %[[VAL_8:.*]] = arith.mulf %[[VAL_0]], %{{.*}} : vector<2xf32>
-// CHECK:           %[[VAL_9:.*]] = arith.minimumf %[[VAL_8]], %[[VAL_2]] : vector<2xf32>
+// CHECK:           %[[VAL_9:.*]] = arith.minnumf %[[VAL_8]], %[[VAL_2]] : vector<2xf32>
 // CHECK:           %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xf32>
 
 // -----
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 1021c73cc57d34..b5164b66817352 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2172,6 +2172,18 @@ func.func @reduce_one_element_vector_addf(%a : vector<1xf32>, %b: f32) -> f32 {
 
 // -----
 
+// CHECK-LABEL: func @reduce_one_element_vector_addf_fastmath
+//  CHECK-SAME: (%[[V:.+]]: vector<1xf32>, %[[B:.+]]: f32)
+//       CHECK:   %[[A:.+]] = vector.extract %[[V]][0] : f32 from vector<1xf32>
+//       CHECK:   %[[S:.+]] = arith.addf %[[A]], %arg1 fastmath<nnan,ninf> : f32
+//       CHECK:   return %[[S]]
+func.func @reduce_one_element_vector_addf_fastmath(%a : vector<1xf32>, %b: f32) -> f32 {
+  %s = vector.reduction <add>, %a, %b fastmath<nnan,ninf> : vector<1xf32> into f32
+  return %s : f32
+}
+
+// -----
+
 // CHECK-LABEL: func @masked_reduce_one_element_vector_addf
 //  CHECK-SAME: %[[VAL_0:.*]]: vector<1xf32>, %[[VAL_1:.*]]: f32,
 //  CHECK-SAME: %[[VAL_2:.*]]: vector<1xi1>)
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
index 12ea87ffb1413f..614a97fe4d6777 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
@@ -27,13 +27,13 @@ func.func @vector_multi_reduction_min(%arg0: vector<2x4xf32>, %acc: vector<2xf32
 //  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
 //       CHECK:   %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
 //       CHECK:   %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
-//       CHECK:   %[[RV0:.+]] = arith.minimumf %[[V0]], %[[ACC]] : vector<2xf32>
+//       CHECK:   %[[RV0:.+]] = arith.minnumf %[[V0]], %[[ACC]] : vector<2xf32>
 //       CHECK:   %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
-//       CHECK:   %[[RV01:.+]] = arith.minimumf %[[V1]], %[[RV0]] : vector<2xf32>
+//       CHECK:   %[[RV01:.+]] = arith.minnumf %[[V1]], %[[RV0]] : vector<2xf32>
 //       CHECK:   %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
-//       CHECK:   %[[RV012:.+]] = arith.minimumf %[[V2]], %[[RV01]] : vector<2xf32>
+//       CHECK:   %[[RV012:.+]] = arith.minnumf %[[V2]], %[[RV01]] : vector<2xf32>
 //       CHECK:   %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
-//       CHECK:   %[[RESULT_VEC:.+]] = arith.minimumf %[[V3]], %[[RV012]] : vector<2xf32>
+//       CHECK:   %[[RESULT_VEC:.+]] = arith.minnumf %[[V3]], %[[RV012]] : vector<2xf32>
 //       CHECK:   return %[[RESULT_VEC]] : vector<2xf32>
 
 func.func @vector_multi_reduction_max(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
@@ -45,13 +45,13 @@ func.func @vector_multi_reduction_max(%arg0: vector<2x4xf32>, %acc: vector<2xf32
 //  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
 //       CHECK:   %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
 //       CHECK:   %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
-//       CHECK:   %[[RV0:.+]] = arith.maximumf %[[V0]], %[[ACC]] : vector<2xf32>
+//       CHECK:   %[[RV0:.+]] = arith.maxnumf %[[V0]], %[[ACC]] : vector<2xf32>
 //       CHECK:   %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
-//       CHECK:   %[[RV01:.+]] = arith.maximumf %[[V1]], %[[RV0]] : vector<2xf32>
+//       CHECK:   %[[RV01:.+]] = arith.maxnumf %[[V1]], %[[RV0]] : vector<2xf32>
 //       CHECK:   %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
-//       CHECK:   %[[RV012:.+]] = arith.maximumf %[[V2]], %[[RV01]] : vector<2xf32>
+//       CHECK:   %[[RV012:.+]] = arith.maxnumf %[[V2]], %[[RV01]] : vector<2xf32>
 //       CHECK:   %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
-//       CHECK:   %[[RESULT_VEC:.+]] = arith.maximumf %[[V3]], %[[RV012]] : vector<2xf32>
+//       CHECK:   %[[RESULT_VEC:.+]] = arith.maxnumf %[[V3]], %[[RV012]] : vector<2xf32>
 //       CHECK:   return %[[RESULT_VEC]] : vector<2xf32>
 
 func.func @vector_multi_reduction_and(%arg0: vector<2x4xi32>, %acc: vector<2xi32>) -> vector<2xi32> {

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not my area of expertise, but looks rather straightforward, thanks! LGTM

Comment on lines +30 to +36
// CHECK: %[[RV0:.+]] = arith.minnumf %[[V0]], %[[ACC]] : vector<2xf32>
// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
// CHECK: %[[RV01:.+]] = arith.minimumf %[[V1]], %[[RV0]] : vector<2xf32>
// CHECK: %[[RV01:.+]] = arith.minnumf %[[V1]], %[[RV0]] : vector<2xf32>
// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
// CHECK: %[[RV012:.+]] = arith.minimumf %[[V2]], %[[RV01]] : vector<2xf32>
// CHECK: %[[RV012:.+]] = arith.minnumf %[[V2]], %[[RV01]] : vector<2xf32>
// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
// CHECK: %[[RESULT_VEC:.+]] = arith.minimumf %[[V3]], %[[RV012]] : vector<2xf32>
// CHECK: %[[RESULT_VEC:.+]] = arith.minnumf %[[V3]], %[[RV012]] : vector<2xf32>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was confused because the RFC says that we should map minf/maxf to minimumf/minimumf.

https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671#proposed-solution-2

However, this is not the case in vector dialect. The execution introduces two combiner kinds minimumf/minimumf and keeps minf/maxf in https://reviews.llvm.org/D158618

Not sure who is working on it, but the current mapping is

<minimumf> --> arith.minimumf
<maximumf> --> arith.maximumf
<maxf> --> arith.maxnumf
<minf> --> arith.minnumf

In this context, the change looks okay to me.

@kuhar kuhar merged commit a528cee into llvm:main Dec 18, 2023
kuhar added a commit to kuhar/llvm-project that referenced this pull request Dec 18, 2023
The number of vector elements considered 'small' enough to extract is
parameterized.

This is to avoid going into specialized reduction lowering when a
single/couple of arith ops can do. Targets without dedicated reduction
intrinsics can use that as an emulation path too.

Depends on llvm#75846.

 Please enter the commit message for your changes. Lines starting
kuhar added a commit that referenced this pull request Dec 18, 2023
…75727)

The number of vector elements considered 'small' enough to extract is
parameterized.                                                   
                                                                 
This is to avoid going into specialized reduction lowering when a
single/couple of arith ops can do. Targets without dedicated reduction  
intrinsics can use that as an emulation path too.                  
                                                                   
Depends on #75846.
kuhar added a commit to kuhar/llvm-project that referenced this pull request Dec 20, 2023
This is to avoid confusion when dealing with reduction/combining kinds.
For example, see a recent PR comment:
llvm#75846 (comment).

Previously, they were picked to mostly mirror the names of the llvm vector
reduction intrinsics:
https://llvm.org/docs/LangRef.html#llvm-vector-reduce-fmin-intrinsic.
In isolation, it was not clear if `<maxf>` has `arith.maxnumf` or
`arith.maximumf` semantics. The new reduction kind names map 1:1 to
arith ops, which makes it easier to tell/look up their semantics.

Because both the vector and the gpu dialect depend on the arith dialect,
it more natural to align names with those in arith than with the
lowering to llvm intrinsics.

Issue: llvm#72354
kuhar added a commit that referenced this pull request Dec 20, 2023
…75901)

This is to avoid confusion when dealing with reduction/combining kinds.
For example, see a recent PR comment:
#75846 (comment).

Previously, they were picked to mostly mirror the names of the llvm
vector reduction intrinsics:
https://llvm.org/docs/LangRef.html#llvm-vector-reduce-fmin-intrinsic. In
isolation, it was not clear if `<maxf>` has `arith.maxnumf` or
`arith.maximumf` semantics. The new reduction kind names map 1:1 to
arith ops, which makes it easier to tell/look up their semantics.

Because both the vector and the gpu dialect depend on the arith dialect,
it's more natural to align names with those in arith than with the
lowering to llvm intrinsics.

Issue: #72354
qihangkong pushed a commit to rvgpu/rvgpu-llvm that referenced this pull request Apr 23, 2024
…#75901)

This is to avoid confusion when dealing with reduction/combining kinds.
For example, see a recent PR comment:
llvm/llvm-project#75846 (comment).

Previously, they were picked to mostly mirror the names of the llvm
vector reduction intrinsics:
https://llvm.org/docs/LangRef.html#llvm-vector-reduce-fmin-intrinsic. In
isolation, it was not clear if `<maxf>` has `arith.maxnumf` or
`arith.maximumf` semantics. The new reduction kind names map 1:1 to
arith ops, which makes it easier to tell/look up their semantics.

Because both the vector and the gpu dialect depend on the arith dialect,
it's more natural to align names with those in arith than with the
lowering to llvm intrinsics.

Issue: llvm/llvm-project#72354
qihangkong pushed a commit to rvgpu/rvgpu-llvm that referenced this pull request Apr 23, 2024
…75727)

The number of vector elements considered 'small' enough to extract is
parameterized.                                                   
                                                                 
This is to avoid going into specialized reduction lowering when a
single/couple of arith ops can do. Targets without dedicated reduction  
intrinsics can use that as an emulation path too.                  
                                                                   
Depends on llvm/llvm-project#75846.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants