-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][spirv] Fix vector reduction lowerings for FP min/max #69053
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-mlir-spirv @llvm/pr-subscribers-mlir Author: Daniil Dudkin (unterumarmung) ChangesThis patch is part of a larger initiative aimed at fixing floating-point This commit fixes the vector reduction lowerings for the floating-point min/max kinds by implementing additional generation of operations that propagate semantics. This patch addresses tasks 2.4 and 2.5 of the RFC. Full diff: https://github.com/llvm/llvm-project/pull/69053.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 9b29179f3687165..1d46d9503e9760d 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -397,9 +397,12 @@ struct VectorReductionPattern final
break
#define INT_OR_FLOAT_CASE(kind, fop) \
- case vector::CombiningKind::kind: \
- result = rewriter.create<fop>(loc, resultType, result, next); \
- break
+ case vector::CombiningKind::kind: { \
+ fop op = rewriter.create<fop>(loc, resultType, result, next); \
+ result = this->generateActionForOp(rewriter, loc, resultType, op, \
+ vector::CombiningKind::kind); \
+ break; \
+ }
INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp);
INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp);
@@ -422,6 +425,51 @@ struct VectorReductionPattern final
rewriter.replaceOp(reduceOp, result);
return success();
}
+
+private:
+ enum class Action { Nothing, PropagateNaN, PropagateNonNaN };
+
+ template <typename Op>
+ Action getActionForOp(vector::CombiningKind kind) const {
+ constexpr bool isCLOp = std::is_same_v<Op, spirv::CLFMaxOp> ||
+ std::is_same_v<Op, spirv::CLFMinOp>;
+ switch (kind) {
+ case vector::CombiningKind::MINIMUMF:
+ case vector::CombiningKind::MAXIMUMF:
+ return Action::PropagateNaN;
+ case vector::CombiningKind::MINF:
+ case vector::CombiningKind::MAXF:
+ // CL ops already have the same semantic for NaNs as MINF/MAXF
+ // GL ops have undefined semantics for NaNs, so we need to explicitly
+ // propagate the non-NaN values
+ return isCLOp ? Action::Nothing : Action::PropagateNonNaN;
+ default:
+ return Action::Nothing;
+ }
+ }
+
+ template <typename Op>
+ Value generateActionForOp(ConversionPatternRewriter &rewriter,
+ mlir::Location loc, Type resultType, Op op,
+ vector::CombiningKind kind) const {
+ Action action = getActionForOp<Op>(kind);
+
+ if (action == Action::Nothing) {
+ return op;
+ }
+
+ Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, op.getLhs());
+ Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, op.getRhs());
+
+ Value select1 = rewriter.create<spirv::SelectOp>(
+ loc, resultType, lhsIsNan,
+ action == Action::PropagateNaN ? op.getLhs() : op.getRhs(), op);
+ Value select2 = rewriter.create<spirv::SelectOp>(
+ loc, resultType, rhsIsNan,
+ action == Action::PropagateNaN ? op.getRhs() : op.getLhs(), select1);
+
+ return select2;
+ }
};
class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index eba763eab9c292a..91836e556147b8d 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -56,9 +56,21 @@ func.func @cl_fma_size1_vector(%a: vector<1xf32>, %b: vector<1xf32>, %c: vector<
// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
// CHECK: %[[MAX0:.+]] = spirv.CL.fmax %[[S0]], %[[S1]]
-// CHECK: %[[MAX1:.+]] = spirv.CL.fmax %[[MAX0]], %[[S2]]
-// CHECK: %[[MAX2:.+]] = spirv.CL.fmax %[[MAX1]], %[[S]]
-// CHECK: return %[[MAX2]]
+// CHECK: %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32
+// CHECK: %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32
+// CHECK: %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S0]], %[[MAX0]] : i1, f32
+// CHECK: %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S1]], %[[SELECT0]] : i1, f32
+// CHECK: %[[MAX1:.+]] = spirv.CL.fmax %[[SELECT1]], %[[S2]]
+// CHECK: %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32
+// CHECK: %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32
+// CHECK: %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[SELECT1]], %[[MAX1]] : i1, f32
+// CHECK: %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[S2]], %[[SELECT2]] : i1, f32
+// CHECK: %[[MAX2:.+]] = spirv.CL.fmax %[[SELECT3]], %[[S]]
+// CHECK: %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32
+// CHECK: %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32
+// CHECK: %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[SELECT3]], %[[MAX2]] : i1, f32
+// CHECK: %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[S]], %[[SELECT4]] : i1, f32
+// CHECK: return %[[SELECT5]]
func.func @cl_reduction_maximumf(%v : vector<3xf32>, %s: f32) -> f32 {
%reduce = vector.reduction <maximumf>, %v, %s : vector<3xf32> into f32
return %reduce : f32
@@ -70,11 +82,51 @@ func.func @cl_reduction_maximumf(%v : vector<3xf32>, %s: f32) -> f32 {
// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
// CHECK: %[[MIN0:.+]] = spirv.CL.fmin %[[S0]], %[[S1]]
+// CHECK: %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32
+// CHECK: %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32
+// CHECK: %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S0]], %[[MIN0]] : i1, f32
+// CHECK: %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S1]], %[[SELECT0]] : i1, f32
+// CHECK: %[[MIN1:.+]] = spirv.CL.fmin %[[SELECT1]], %[[S2]]
+// CHECK: %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32
+// CHECK: %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32
+// CHECK: %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[SELECT1]], %[[MIN1]] : i1, f32
+// CHECK: %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[S2]], %[[SELECT2]] : i1, f32
+// CHECK: %[[MIN2:.+]] = spirv.CL.fmin %[[SELECT3]], %[[S]]
+// CHECK: %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32
+// CHECK: %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32
+// CHECK: %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[SELECT3]], %[[MIN2]] : i1, f32
+// CHECK: %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[S]], %[[SELECT4]] : i1, f32
+// CHECK: return %[[SELECT5]]
+func.func @cl_reduction_minimumf(%v : vector<3xf32>, %s: f32) -> f32 {
+ %reduce = vector.reduction <minimumf>, %v, %s : vector<3xf32> into f32
+ return %reduce : f32
+}
+
+// CHECK-LABEL: func @cl_reduction_maxf
+// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
+// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
+// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
+// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
+// CHECK: %[[MAX0:.+]] = spirv.CL.fmax %[[S0]], %[[S1]]
+// CHECK: %[[MAX1:.+]] = spirv.CL.fmax %[[MAX0]], %[[S2]]
+// CHECK: %[[MAX2:.+]] = spirv.CL.fmax %[[MAX1]], %[[S]]
+// CHECK: return %[[MAX2]]
+func.func @cl_reduction_maxf(%v : vector<3xf32>, %s: f32) -> f32 {
+ %reduce = vector.reduction <maxf>, %v, %s : vector<3xf32> into f32
+ return %reduce : f32
+}
+
+// CHECK-LABEL: func @cl_reduction_minf
+// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
+// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
+// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
+// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
+// CHECK: %[[MIN0:.+]] = spirv.CL.fmin %[[S0]], %[[S1]]
// CHECK: %[[MIN1:.+]] = spirv.CL.fmin %[[MIN0]], %[[S2]]
// CHECK: %[[MIN2:.+]] = spirv.CL.fmin %[[MIN1]], %[[S]]
// CHECK: return %[[MIN2]]
-func.func @cl_reduction_minimumf(%v : vector<3xf32>, %s: f32) -> f32 {
- %reduce = vector.reduction <minimumf>, %v, %s : vector<3xf32> into f32
+func.func @cl_reduction_minf(%v : vector<3xf32>, %s: f32) -> f32 {
+ %reduce = vector.reduction <minf>, %v, %s : vector<3xf32> into f32
return %reduce : f32
}
@@ -522,9 +574,21 @@ func.func @reduction_mul(%v : vector<3xf32>, %s: f32) -> f32 {
// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
// CHECK: %[[MAX0:.+]] = spirv.GL.FMax %[[S0]], %[[S1]]
-// CHECK: %[[MAX1:.+]] = spirv.GL.FMax %[[MAX0]], %[[S2]]
-// CHECK: %[[MAX2:.+]] = spirv.GL.FMax %[[MAX1]], %[[S]]
-// CHECK: return %[[MAX2]]
+// CHECK: %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32
+// CHECK: %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32
+// CHECK: %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S0]], %[[MAX0]] : i1, f32
+// CHECK: %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S1]], %[[SELECT0]] : i1, f32
+// CHECK: %[[MAX1:.+]] = spirv.GL.FMax %[[SELECT1]], %[[S2]]
+// CHECK: %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32
+// CHECK: %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32
+// CHECK: %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[SELECT1]], %[[MAX1]] : i1, f32
+// CHECK: %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[S2]], %[[SELECT2]] : i1, f32
+// CHECK: %[[MAX2:.+]] = spirv.GL.FMax %[[SELECT3]], %[[S]]
+// CHECK: %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32
+// CHECK: %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32
+// CHECK: %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[SELECT3]], %[[MAX2]] : i1, f32
+// CHECK: %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[S]], %[[SELECT4]] : i1, f32
+// CHECK: return %[[SELECT5]]
func.func @reduction_maximumf(%v : vector<3xf32>, %s: f32) -> f32 {
%reduce = vector.reduction <maximumf>, %v, %s : vector<3xf32> into f32
return %reduce : f32
@@ -532,15 +596,55 @@ func.func @reduction_maximumf(%v : vector<3xf32>, %s: f32) -> f32 {
// -----
+// CHECK-LABEL: func @reduction_maxf
+// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
+// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
+// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
+// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
+// CHECK: %[[MAX0:.+]] = spirv.GL.FMax %[[S0]], %[[S1]]
+// CHECK: %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32
+// CHECK: %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32
+// CHECK: %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S1]], %[[MAX0]] : i1, f32
+// CHECK: %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S0]], %[[SELECT0]] : i1, f32
+// CHECK: %[[MAX1:.+]] = spirv.GL.FMax %[[SELECT1]], %[[S2]]
+// CHECK: %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32
+// CHECK: %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32
+// CHECK: %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[S2]], %[[MAX1]] : i1, f32
+// CHECK: %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[SELECT1]], %[[SELECT2]] : i1, f32
+// CHECK: %[[MAX2:.+]] = spirv.GL.FMax %[[SELECT3]], %[[S]]
+// CHECK: %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32
+// CHECK: %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32
+// CHECK: %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[S]], %[[MAX2]] : i1, f32
+// CHECK: %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[SELECT3]], %[[SELECT4]] : i1, f32
+// CHECK: return %[[SELECT5]]
+func.func @reduction_maxf(%v : vector<3xf32>, %s: f32) -> f32 {
+ %reduce = vector.reduction <maxf>, %v, %s : vector<3xf32> into f32
+ return %reduce : f32
+}
+
+// -----
+
// CHECK-LABEL: func @reduction_minimumf
// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
// CHECK: %[[MIN0:.+]] = spirv.GL.FMin %[[S0]], %[[S1]]
-// CHECK: %[[MIN1:.+]] = spirv.GL.FMin %[[MIN0]], %[[S2]]
-// CHECK: %[[MIN2:.+]] = spirv.GL.FMin %[[MIN1]], %[[S]]
-// CHECK: return %[[MIN2]]
+// CHECK: %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32
+// CHECK: %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32
+// CHECK: %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S0]], %[[MIN0]] : i1, f32
+// CHECK: %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S1]], %[[SELECT0]] : i1, f32
+// CHECK: %[[MIN1:.+]] = spirv.GL.FMin %[[SELECT1]], %[[S2]]
+// CHECK: %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32
+// CHECK: %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32
+// CHECK: %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[SELECT1]], %[[MIN1]] : i1, f32
+// CHECK: %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[S2]], %[[SELECT2]] : i1, f32
+// CHECK: %[[MIN2:.+]] = spirv.GL.FMin %[[SELECT3]], %[[S]]
+// CHECK: %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32
+// CHECK: %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32
+// CHECK: %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[SELECT3]], %[[MIN2]] : i1, f32
+// CHECK: %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[S]], %[[SELECT4]] : i1, f32
+// CHECK: return %[[SELECT5]]
func.func @reduction_minimumf(%v : vector<3xf32>, %s: f32) -> f32 {
%reduce = vector.reduction <minimumf>, %v, %s : vector<3xf32> into f32
return %reduce : f32
@@ -548,6 +652,34 @@ func.func @reduction_minimumf(%v : vector<3xf32>, %s: f32) -> f32 {
// -----
+// CHECK-LABEL: func @reduction_minf
+// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
+// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
+// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
+// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
+// CHECK: %[[MIN0:.+]] = spirv.GL.FMin %[[S0]], %[[S1]]
+// CHECK: %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32
+// CHECK: %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32
+// CHECK: %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S1]], %[[MIN0]] : i1, f32
+// CHECK: %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S0]], %[[SELECT0]] : i1, f32
+// CHECK: %[[MIN1:.+]] = spirv.GL.FMin %[[SELECT1]], %[[S2]]
+// CHECK: %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32
+// CHECK: %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32
+// CHECK: %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[S2]], %[[MIN1]] : i1, f32
+// CHECK: %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[SELECT1]], %[[SELECT2]] : i1, f32
+// CHECK: %[[MIN2:.+]] = spirv.GL.FMin %[[SELECT3]], %[[S]]
+// CHECK: %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32
+// CHECK: %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32
+// CHECK: %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[S]], %[[MIN2]] : i1, f32
+// CHECK: %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[SELECT3]], %[[SELECT4]] : i1, f32
+// CHECK: return %[[SELECT5]]
+func.func @reduction_minf(%v : vector<3xf32>, %s: f32) -> f32 {
+ %reduce = vector.reduction <minf>, %v, %s : vector<3xf32> into f32
+ return %reduce : f32
+}
+
+// -----
+
// CHECK-LABEL: func @reduction_maxsi
// CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32)
// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xi32>
|
This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671. This commit fixes the vector reduction lowerings for the floating-point min/max kinds by implementing additional generation of operations that propagate semantics. This patch addresses tasks 2.4 and 2.5 of the RFC.
4154c1d
to
1d89f98
Compare
This patch is part of a larger initiative aimed at fixing floating-point
max
andmin
operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.This commit fixes the vector reduction lowerings for the floating-point min/max kinds by implementing additional generation of operations that propagate semantics.
This patch addresses tasks 2.4 and 2.5 of the RFC.