Skip to content

Commit 1d89f98

Browse files
committed
[mlir][spirv] Fix vector reduction lowerings for FP min/max
This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671. This commit fixes the vector reduction lowerings for the floating-point min/max kinds by implementing additional generation of operations that propagate semantics. This patch addresses tasks 2.4 and 2.5 of the RFC.
1 parent 4a8b0ea commit 1d89f98

File tree

2 files changed

+194
-14
lines changed

2 files changed

+194
-14
lines changed

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -397,9 +397,12 @@ struct VectorReductionPattern final
397397
break
398398

399399
#define INT_OR_FLOAT_CASE(kind, fop) \
400-
case vector::CombiningKind::kind: \
401-
result = rewriter.create<fop>(loc, resultType, result, next); \
402-
break
400+
case vector::CombiningKind::kind: { \
401+
fop op = rewriter.create<fop>(loc, resultType, result, next); \
402+
result = this->generateActionForOp(rewriter, loc, resultType, op, \
403+
vector::CombiningKind::kind); \
404+
break; \
405+
}
403406

404407
INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp);
405408
INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp);
@@ -422,6 +425,51 @@ struct VectorReductionPattern final
422425
rewriter.replaceOp(reduceOp, result);
423426
return success();
424427
}
428+
429+
private:
430+
enum class Action { Nothing, PropagateNaN, PropagateNonNaN };
431+
432+
template <typename Op>
433+
Action getActionForOp(vector::CombiningKind kind) const {
434+
constexpr bool isCLOp = std::is_same_v<Op, spirv::CLFMaxOp> ||
435+
std::is_same_v<Op, spirv::CLFMinOp>;
436+
switch (kind) {
437+
case vector::CombiningKind::MINIMUMF:
438+
case vector::CombiningKind::MAXIMUMF:
439+
return Action::PropagateNaN;
440+
case vector::CombiningKind::MINF:
441+
case vector::CombiningKind::MAXF:
442+
// CL ops already have the same semantic for NaNs as MINF/MAXF
443+
// GL ops have undefined semantics for NaNs, so we need to explicitly
444+
// propagate the non-NaN values
445+
return isCLOp ? Action::Nothing : Action::PropagateNonNaN;
446+
default:
447+
return Action::Nothing;
448+
}
449+
}
450+
451+
template <typename Op>
452+
Value generateActionForOp(ConversionPatternRewriter &rewriter,
453+
mlir::Location loc, Type resultType, Op op,
454+
vector::CombiningKind kind) const {
455+
Action action = getActionForOp<Op>(kind);
456+
457+
if (action == Action::Nothing) {
458+
return op;
459+
}
460+
461+
Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, op.getLhs());
462+
Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, op.getRhs());
463+
464+
Value select1 = rewriter.create<spirv::SelectOp>(
465+
loc, resultType, lhsIsNan,
466+
action == Action::PropagateNaN ? op.getLhs() : op.getRhs(), op);
467+
Value select2 = rewriter.create<spirv::SelectOp>(
468+
loc, resultType, rhsIsNan,
469+
action == Action::PropagateNaN ? op.getRhs() : op.getLhs(), select1);
470+
471+
return select2;
472+
}
425473
};
426474

427475
class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {

mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Lines changed: 143 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,21 @@ func.func @cl_fma_size1_vector(%a: vector<1xf32>, %b: vector<1xf32>, %c: vector<
5656
// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
5757
// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
5858
// CHECK: %[[MAX0:.+]] = spirv.CL.fmax %[[S0]], %[[S1]]
59-
// CHECK: %[[MAX1:.+]] = spirv.CL.fmax %[[MAX0]], %[[S2]]
60-
// CHECK: %[[MAX2:.+]] = spirv.CL.fmax %[[MAX1]], %[[S]]
61-
// CHECK: return %[[MAX2]]
59+
// CHECK: %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32
60+
// CHECK: %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32
61+
// CHECK: %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S0]], %[[MAX0]] : i1, f32
62+
// CHECK: %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S1]], %[[SELECT0]] : i1, f32
63+
// CHECK: %[[MAX1:.+]] = spirv.CL.fmax %[[SELECT1]], %[[S2]]
64+
// CHECK: %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32
65+
// CHECK: %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32
66+
// CHECK: %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[SELECT1]], %[[MAX1]] : i1, f32
67+
// CHECK: %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[S2]], %[[SELECT2]] : i1, f32
68+
// CHECK: %[[MAX2:.+]] = spirv.CL.fmax %[[SELECT3]], %[[S]]
69+
// CHECK: %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32
70+
// CHECK: %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32
71+
// CHECK: %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[SELECT3]], %[[MAX2]] : i1, f32
72+
// CHECK: %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[S]], %[[SELECT4]] : i1, f32
73+
// CHECK: return %[[SELECT5]]
6274
func.func @cl_reduction_maximumf(%v : vector<3xf32>, %s: f32) -> f32 {
6375
%reduce = vector.reduction <maximumf>, %v, %s : vector<3xf32> into f32
6476
return %reduce : f32
@@ -70,11 +82,51 @@ func.func @cl_reduction_maximumf(%v : vector<3xf32>, %s: f32) -> f32 {
7082
// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
7183
// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
7284
// CHECK: %[[MIN0:.+]] = spirv.CL.fmin %[[S0]], %[[S1]]
85+
// CHECK: %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32
86+
// CHECK: %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32
87+
// CHECK: %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S0]], %[[MIN0]] : i1, f32
88+
// CHECK: %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S1]], %[[SELECT0]] : i1, f32
89+
// CHECK: %[[MIN1:.+]] = spirv.CL.fmin %[[SELECT1]], %[[S2]]
90+
// CHECK: %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32
91+
// CHECK: %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32
92+
// CHECK: %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[SELECT1]], %[[MIN1]] : i1, f32
93+
// CHECK: %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[S2]], %[[SELECT2]] : i1, f32
94+
// CHECK: %[[MIN2:.+]] = spirv.CL.fmin %[[SELECT3]], %[[S]]
95+
// CHECK: %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32
96+
// CHECK: %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32
97+
// CHECK: %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[SELECT3]], %[[MIN2]] : i1, f32
98+
// CHECK: %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[S]], %[[SELECT4]] : i1, f32
99+
// CHECK: return %[[SELECT5]]
100+
func.func @cl_reduction_minimumf(%v : vector<3xf32>, %s: f32) -> f32 {
101+
%reduce = vector.reduction <minimumf>, %v, %s : vector<3xf32> into f32
102+
return %reduce : f32
103+
}
104+
105+
// CHECK-LABEL: func @cl_reduction_maxf
106+
// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
107+
// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
108+
// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
109+
// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
110+
// CHECK: %[[MAX0:.+]] = spirv.CL.fmax %[[S0]], %[[S1]]
111+
// CHECK: %[[MAX1:.+]] = spirv.CL.fmax %[[MAX0]], %[[S2]]
112+
// CHECK: %[[MAX2:.+]] = spirv.CL.fmax %[[MAX1]], %[[S]]
113+
// CHECK: return %[[MAX2]]
114+
func.func @cl_reduction_maxf(%v : vector<3xf32>, %s: f32) -> f32 {
115+
%reduce = vector.reduction <maxf>, %v, %s : vector<3xf32> into f32
116+
return %reduce : f32
117+
}
118+
119+
// CHECK-LABEL: func @cl_reduction_minf
120+
// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
121+
// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
122+
// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
123+
// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
124+
// CHECK: %[[MIN0:.+]] = spirv.CL.fmin %[[S0]], %[[S1]]
73125
// CHECK: %[[MIN1:.+]] = spirv.CL.fmin %[[MIN0]], %[[S2]]
74126
// CHECK: %[[MIN2:.+]] = spirv.CL.fmin %[[MIN1]], %[[S]]
75127
// CHECK: return %[[MIN2]]
76-
func.func @cl_reduction_minimumf(%v : vector<3xf32>, %s: f32) -> f32 {
77-
%reduce = vector.reduction <minimumf>, %v, %s : vector<3xf32> into f32
128+
func.func @cl_reduction_minf(%v : vector<3xf32>, %s: f32) -> f32 {
129+
%reduce = vector.reduction <minf>, %v, %s : vector<3xf32> into f32
78130
return %reduce : f32
79131
}
80132

@@ -522,32 +574,112 @@ func.func @reduction_mul(%v : vector<3xf32>, %s: f32) -> f32 {
522574
// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
523575
// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
524576
// CHECK: %[[MAX0:.+]] = spirv.GL.FMax %[[S0]], %[[S1]]
525-
// CHECK: %[[MAX1:.+]] = spirv.GL.FMax %[[MAX0]], %[[S2]]
526-
// CHECK: %[[MAX2:.+]] = spirv.GL.FMax %[[MAX1]], %[[S]]
527-
// CHECK: return %[[MAX2]]
577+
// CHECK: %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32
578+
// CHECK: %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32
579+
// CHECK: %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S0]], %[[MAX0]] : i1, f32
580+
// CHECK: %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S1]], %[[SELECT0]] : i1, f32
581+
// CHECK: %[[MAX1:.+]] = spirv.GL.FMax %[[SELECT1]], %[[S2]]
582+
// CHECK: %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32
583+
// CHECK: %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32
584+
// CHECK: %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[SELECT1]], %[[MAX1]] : i1, f32
585+
// CHECK: %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[S2]], %[[SELECT2]] : i1, f32
586+
// CHECK: %[[MAX2:.+]] = spirv.GL.FMax %[[SELECT3]], %[[S]]
587+
// CHECK: %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32
588+
// CHECK: %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32
589+
// CHECK: %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[SELECT3]], %[[MAX2]] : i1, f32
590+
// CHECK: %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[S]], %[[SELECT4]] : i1, f32
591+
// CHECK: return %[[SELECT5]]
528592
func.func @reduction_maximumf(%v : vector<3xf32>, %s: f32) -> f32 {
529593
%reduce = vector.reduction <maximumf>, %v, %s : vector<3xf32> into f32
530594
return %reduce : f32
531595
}
532596

533597
// -----
534598

599+
// CHECK-LABEL: func @reduction_maxf
600+
// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
601+
// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
602+
// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
603+
// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
604+
// CHECK: %[[MAX0:.+]] = spirv.GL.FMax %[[S0]], %[[S1]]
605+
// CHECK: %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32
606+
// CHECK: %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32
607+
// CHECK: %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S1]], %[[MAX0]] : i1, f32
608+
// CHECK: %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S0]], %[[SELECT0]] : i1, f32
609+
// CHECK: %[[MAX1:.+]] = spirv.GL.FMax %[[SELECT1]], %[[S2]]
610+
// CHECK: %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32
611+
// CHECK: %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32
612+
// CHECK: %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[S2]], %[[MAX1]] : i1, f32
613+
// CHECK: %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[SELECT1]], %[[SELECT2]] : i1, f32
614+
// CHECK: %[[MAX2:.+]] = spirv.GL.FMax %[[SELECT3]], %[[S]]
615+
// CHECK: %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32
616+
// CHECK: %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32
617+
// CHECK: %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[S]], %[[MAX2]] : i1, f32
618+
// CHECK: %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[SELECT3]], %[[SELECT4]] : i1, f32
619+
// CHECK: return %[[SELECT5]]
620+
func.func @reduction_maxf(%v : vector<3xf32>, %s: f32) -> f32 {
621+
%reduce = vector.reduction <maxf>, %v, %s : vector<3xf32> into f32
622+
return %reduce : f32
623+
}
624+
625+
// -----
626+
535627
// CHECK-LABEL: func @reduction_minimumf
536628
// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
537629
// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
538630
// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
539631
// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
540632
// CHECK: %[[MIN0:.+]] = spirv.GL.FMin %[[S0]], %[[S1]]
541-
// CHECK: %[[MIN1:.+]] = spirv.GL.FMin %[[MIN0]], %[[S2]]
542-
// CHECK: %[[MIN2:.+]] = spirv.GL.FMin %[[MIN1]], %[[S]]
543-
// CHECK: return %[[MIN2]]
633+
// CHECK: %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32
634+
// CHECK: %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32
635+
// CHECK: %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S0]], %[[MIN0]] : i1, f32
636+
// CHECK: %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S1]], %[[SELECT0]] : i1, f32
637+
// CHECK: %[[MIN1:.+]] = spirv.GL.FMin %[[SELECT1]], %[[S2]]
638+
// CHECK: %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32
639+
// CHECK: %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32
640+
// CHECK: %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[SELECT1]], %[[MIN1]] : i1, f32
641+
// CHECK: %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[S2]], %[[SELECT2]] : i1, f32
642+
// CHECK: %[[MIN2:.+]] = spirv.GL.FMin %[[SELECT3]], %[[S]]
643+
// CHECK: %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32
644+
// CHECK: %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32
645+
// CHECK: %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[SELECT3]], %[[MIN2]] : i1, f32
646+
// CHECK: %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[S]], %[[SELECT4]] : i1, f32
647+
// CHECK: return %[[SELECT5]]
544648
func.func @reduction_minimumf(%v : vector<3xf32>, %s: f32) -> f32 {
545649
%reduce = vector.reduction <minimumf>, %v, %s : vector<3xf32> into f32
546650
return %reduce : f32
547651
}
548652

549653
// -----
550654

655+
// CHECK-LABEL: func @reduction_minf
656+
// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
657+
// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
658+
// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
659+
// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
660+
// CHECK: %[[MIN0:.+]] = spirv.GL.FMin %[[S0]], %[[S1]]
661+
// CHECK: %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32
662+
// CHECK: %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32
663+
// CHECK: %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S1]], %[[MIN0]] : i1, f32
664+
// CHECK: %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S0]], %[[SELECT0]] : i1, f32
665+
// CHECK: %[[MIN1:.+]] = spirv.GL.FMin %[[SELECT1]], %[[S2]]
666+
// CHECK: %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32
667+
// CHECK: %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32
668+
// CHECK: %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[S2]], %[[MIN1]] : i1, f32
669+
// CHECK: %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[SELECT1]], %[[SELECT2]] : i1, f32
670+
// CHECK: %[[MIN2:.+]] = spirv.GL.FMin %[[SELECT3]], %[[S]]
671+
// CHECK: %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32
672+
// CHECK: %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32
673+
// CHECK: %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[S]], %[[MIN2]] : i1, f32
674+
// CHECK: %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[SELECT3]], %[[SELECT4]] : i1, f32
675+
// CHECK: return %[[SELECT5]]
676+
func.func @reduction_minf(%v : vector<3xf32>, %s: f32) -> f32 {
677+
%reduce = vector.reduction <minf>, %v, %s : vector<3xf32> into f32
678+
return %reduce : f32
679+
}
680+
681+
// -----
682+
551683
// CHECK-LABEL: func @reduction_maxsi
552684
// CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32)
553685
// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xi32>

0 commit comments

Comments
 (0)