-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][vector] Improve makeArithReduction expansion
#75846
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Propagate fast math flags. Distinguish `minf`/`maxf` and `minimumf`/`maximumf`. Required for future patterns in llvm#75727.
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Jakub Kuderski (kuhar) Changes: Propagate fast math flags. Required for future patterns in #75727. Full diff: https://github.com/llvm/llvm-project/pull/75846.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 59d585a77b1e29..a28b27e4e15816 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -123,10 +123,12 @@ bool isDisjointTransferSet(VectorTransferOpInterface transferA,
VectorTransferOpInterface transferB,
bool testDynamicValueUsingBounds = false);
-/// Return the result value of reducing two scalar/vector values with the
+/// Returns the result value of reducing two scalar/vector values with the
/// corresponding arith operation.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
- Value v1, Value acc, Value mask = Value());
+ Value v1, Value acc,
+ arith::FastMathFlagsAttr fastmath = nullptr,
+ Value mask = nullptr);
/// Returns true if `attr` has "parallel" iterator type semantics.
inline bool isParallelIterator(Attribute attr) {
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 540959b486db9c..9f3e13c90a624d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -507,8 +507,9 @@ struct ElideUnitDimsInMultiDimReduction
zeroIdx);
}
- Value result = vector::makeArithReduction(
- rewriter, loc, reductionOp.getKind(), acc, cast, mask);
+ Value result =
+ vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
+ cast, /*fastmath=*/nullptr, mask);
rewriter.replaceOp(rootOp, result);
return success();
}
@@ -650,7 +651,8 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
if (Value acc = reductionOp.getAcc())
result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
- result, acc, mask);
+ result, acc,
+ reductionOp.getFastmathAttr(), mask);
rewriter.replaceOp(rootOp, result);
return success();
@@ -6212,6 +6214,7 @@ bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
CombiningKind kind, Value v1, Value acc,
+ arith::FastMathFlagsAttr fastmath,
Value mask) {
Type t1 = getElementTypeOrSelf(v1.getType());
Type tAcc = getElementTypeOrSelf(acc.getType());
@@ -6222,7 +6225,7 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
- result = b.createOrFold<arith::AddFOp>(loc, v1, acc);
+ result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
else
llvm_unreachable("invalid value types for ADD reduction");
break;
@@ -6231,16 +6234,24 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
break;
case CombiningKind::MAXF:
+ assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
+ "expected float values");
+ result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
+ break;
case CombiningKind::MAXIMUMF:
assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
"expected float values");
- result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc);
+ result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
break;
case CombiningKind::MINF:
+ assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
+ "expected float values");
+ result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
+ break;
case CombiningKind::MINIMUMF:
assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
"expected float values");
- result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc);
+ result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
break;
case CombiningKind::MAXSI:
assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
@@ -6262,7 +6273,7 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
- result = b.createOrFold<arith::MulFOp>(loc, v1, acc);
+ result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
else
llvm_unreachable("invalid value types for MUL reduction");
break;
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 6dbe36e605e9a7..41ff0c18fe6258 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -167,7 +167,8 @@ createContractArithOp(Location loc, Value x, Value y, Value acc,
if (!acc)
return std::optional<Value>(mul);
- return makeArithReduction(rewriter, loc, kind, mul, acc, mask);
+ return makeArithReduction(rewriter, loc, kind, mul, acc,
+ /*fastmath=*/nullptr, mask);
}
/// Return the positions of the reductions in the given map.
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 012d30d96799f2..7353d16d79cea0 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -450,7 +450,7 @@ func.func @masked_float_max_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: v
// CHECK-LABEL: func.func @masked_float_max_outerprod(
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<2xf32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xf32> {
// CHECK: %[[VAL_8:.*]] = arith.mulf %[[VAL_0]], %{{.*}} : vector<2xf32>
-// CHECK: %[[VAL_9:.*]] = arith.maximumf %[[VAL_8]], %[[VAL_2]] : vector<2xf32>
+// CHECK: %[[VAL_9:.*]] = arith.maxnumf %[[VAL_8]], %[[VAL_2]] : vector<2xf32>
// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xf32>
// -----
@@ -463,7 +463,7 @@ func.func @masked_float_min_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: v
// CHECK-LABEL: func.func @masked_float_min_outerprod(
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<2xf32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xf32> {
// CHECK: %[[VAL_8:.*]] = arith.mulf %[[VAL_0]], %{{.*}} : vector<2xf32>
-// CHECK: %[[VAL_9:.*]] = arith.minimumf %[[VAL_8]], %[[VAL_2]] : vector<2xf32>
+// CHECK: %[[VAL_9:.*]] = arith.minnumf %[[VAL_8]], %[[VAL_2]] : vector<2xf32>
// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xf32>
// -----
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 1021c73cc57d34..b5164b66817352 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2172,6 +2172,18 @@ func.func @reduce_one_element_vector_addf(%a : vector<1xf32>, %b: f32) -> f32 {
// -----
+// CHECK-LABEL: func @reduce_one_element_vector_addf_fastmath
+// CHECK-SAME: (%[[V:.+]]: vector<1xf32>, %[[B:.+]]: f32)
+// CHECK: %[[A:.+]] = vector.extract %[[V]][0] : f32 from vector<1xf32>
+// CHECK: %[[S:.+]] = arith.addf %[[A]], %arg1 fastmath<nnan,ninf> : f32
+// CHECK: return %[[S]]
+func.func @reduce_one_element_vector_addf_fastmath(%a : vector<1xf32>, %b: f32) -> f32 {
+ %s = vector.reduction <add>, %a, %b fastmath<nnan,ninf> : vector<1xf32> into f32
+ return %s : f32
+}
+
+// -----
+
// CHECK-LABEL: func @masked_reduce_one_element_vector_addf
// CHECK-SAME: %[[VAL_0:.*]]: vector<1xf32>, %[[VAL_1:.*]]: f32,
// CHECK-SAME: %[[VAL_2:.*]]: vector<1xi1>)
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
index 12ea87ffb1413f..614a97fe4d6777 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
@@ -27,13 +27,13 @@ func.func @vector_multi_reduction_min(%arg0: vector<2x4xf32>, %acc: vector<2xf32
// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RV0:.+]] = arith.minimumf %[[V0]], %[[ACC]] : vector<2xf32>
+// CHECK: %[[RV0:.+]] = arith.minnumf %[[V0]], %[[ACC]] : vector<2xf32>
// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RV01:.+]] = arith.minimumf %[[V1]], %[[RV0]] : vector<2xf32>
+// CHECK: %[[RV01:.+]] = arith.minnumf %[[V1]], %[[RV0]] : vector<2xf32>
// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RV012:.+]] = arith.minimumf %[[V2]], %[[RV01]] : vector<2xf32>
+// CHECK: %[[RV012:.+]] = arith.minnumf %[[V2]], %[[RV01]] : vector<2xf32>
// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RESULT_VEC:.+]] = arith.minimumf %[[V3]], %[[RV012]] : vector<2xf32>
+// CHECK: %[[RESULT_VEC:.+]] = arith.minnumf %[[V3]], %[[RV012]] : vector<2xf32>
// CHECK: return %[[RESULT_VEC]] : vector<2xf32>
func.func @vector_multi_reduction_max(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
@@ -45,13 +45,13 @@ func.func @vector_multi_reduction_max(%arg0: vector<2x4xf32>, %acc: vector<2xf32
// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RV0:.+]] = arith.maximumf %[[V0]], %[[ACC]] : vector<2xf32>
+// CHECK: %[[RV0:.+]] = arith.maxnumf %[[V0]], %[[ACC]] : vector<2xf32>
// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RV01:.+]] = arith.maximumf %[[V1]], %[[RV0]] : vector<2xf32>
+// CHECK: %[[RV01:.+]] = arith.maxnumf %[[V1]], %[[RV0]] : vector<2xf32>
// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RV012:.+]] = arith.maximumf %[[V2]], %[[RV01]] : vector<2xf32>
+// CHECK: %[[RV012:.+]] = arith.maxnumf %[[V2]], %[[RV01]] : vector<2xf32>
// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RESULT_VEC:.+]] = arith.maximumf %[[V3]], %[[RV012]] : vector<2xf32>
+// CHECK: %[[RESULT_VEC:.+]] = arith.maxnumf %[[V3]], %[[RV012]] : vector<2xf32>
// CHECK: return %[[RESULT_VEC]] : vector<2xf32>
func.func @vector_multi_reduction_and(%arg0: vector<2x4xi32>, %acc: vector<2xi32>) -> vector<2xi32> {
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not my area of expertise, but looks rather straightforward, thanks! LGTM
// CHECK: %[[RV0:.+]] = arith.minnumf %[[V0]], %[[ACC]] : vector<2xf32> | ||
// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32> | ||
// CHECK: %[[RV01:.+]] = arith.minimumf %[[V1]], %[[RV0]] : vector<2xf32> | ||
// CHECK: %[[RV01:.+]] = arith.minnumf %[[V1]], %[[RV0]] : vector<2xf32> | ||
// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32> | ||
// CHECK: %[[RV012:.+]] = arith.minimumf %[[V2]], %[[RV01]] : vector<2xf32> | ||
// CHECK: %[[RV012:.+]] = arith.minnumf %[[V2]], %[[RV01]] : vector<2xf32> | ||
// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32> | ||
// CHECK: %[[RESULT_VEC:.+]] = arith.minimumf %[[V3]], %[[RV012]] : vector<2xf32> | ||
// CHECK: %[[RESULT_VEC:.+]] = arith.minnumf %[[V3]], %[[RV012]] : vector<2xf32> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was confused because the RFC says that we should map minf/maxf to minimumf/maximumf.
However, this is not the case in the vector dialect. The execution introduces two combiner kinds, minimumf/maximumf, and keeps minf/maxf; see https://reviews.llvm.org/D158618
Not sure who is working on it, but the current mapping is
<minimumf> --> arith.minimumf
<maximumf> --> arith.maximumf
<maxf> --> arith.maxnumf
<minf> --> arith.minnumf
In this context, the change looks okay to me.
The number of vector elements considered 'small' enough to extract is parameterized. This is to avoid going into specialized reduction lowering when a single/couple of arith ops can do. Targets without dedicated reduction intrinsics can use that as an emulation path too. Depends on llvm#75846.
This is to avoid confusion when dealing with reduction/combining kinds. For example, see a recent PR comment: llvm#75846 (comment). Previously, they were picked to mostly mirror the names of the llvm vector reduction intrinsics: https://llvm.org/docs/LangRef.html#llvm-vector-reduce-fmin-intrinsic. In isolation, it was not clear if `<maxf>` has `arith.maxnumf` or `arith.maximumf` semantics. The new reduction kind names map 1:1 to arith ops, which makes it easier to tell/look up their semantics. Because both the vector and the gpu dialect depend on the arith dialect, it is more natural to align names with those in arith than with the lowering to llvm intrinsics. Issue: llvm#72354
…75901) This is to avoid confusion when dealing with reduction/combining kinds. For example, see a recent PR comment: #75846 (comment). Previously, they were picked to mostly mirror the names of the llvm vector reduction intrinsics: https://llvm.org/docs/LangRef.html#llvm-vector-reduce-fmin-intrinsic. In isolation, it was not clear if `<maxf>` has `arith.maxnumf` or `arith.maximumf` semantics. The new reduction kind names map 1:1 to arith ops, which makes it easier to tell/look up their semantics. Because both the vector and the gpu dialect depend on the arith dialect, it's more natural to align names with those in arith than with the lowering to llvm intrinsics. Issue: #72354
…#75901) This is to avoid confusion when dealing with reduction/combining kinds. For example, see a recent PR comment: llvm/llvm-project#75846 (comment). Previously, they were picked to mostly mirror the names of the llvm vector reduction intrinsics: https://llvm.org/docs/LangRef.html#llvm-vector-reduce-fmin-intrinsic. In isolation, it was not clear if `<maxf>` has `arith.maxnumf` or `arith.maximumf` semantics. The new reduction kind names map 1:1 to arith ops, which makes it easier to tell/look up their semantics. Because both the vector and the gpu dialect depend on the arith dialect, it's more natural to align names with those in arith than with the lowering to llvm intrinsics. Issue: llvm/llvm-project#72354
…75727) The number of vector elements considered 'small' enough to extract is parameterized. This is to avoid going into specialized reduction lowering when a single/couple of arith ops can do. Targets without dedicated reduction intrinsics can use that as an emulation path too. Depends on llvm/llvm-project#75846.
Propagate fast math flags.
Distinguish `minf`/`maxf` and `minimumf`/`maximumf`.
Required for future patterns in #75727.