Skip to content

Commit 7242896

Browse files
authored
[Flang] Attempt to fix Nan handling in Minloc/Maxloc intrinsic simplification (#82313)
In certain case "extreme" values like Nan, Inf and 0xffffffff could lead to generating different code via the inline-generated intrinsics vs the versions in the runtimes (and other compilers like gfortran). There are some examples I was using for testing in https://godbolt.org/z/x4EfqEss5. This changes the generation for the intrinsics to be more like the runtimes, using a condition that is similar to: isFirst || (prev != prev && elem == elem) || elem < prev The middle part is only used for floating point operations, and checks if the values are Nan. This should then hopefully make the logic closer to - return the first element with the lowest value, with Nans ignored unless there are only Nans. The initial limit value for floats are also changed from the largest float to Inf, to make sure it is handled correctly. The integer reductions are also changed to use a similar scheme to make sure they work with masked values. This means that the preamble after the loop can be removed.
1 parent 1ff1e82 commit 7242896

File tree

5 files changed

+136
-166
lines changed

5 files changed

+136
-166
lines changed

flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -852,9 +852,8 @@ class MinMaxlocElementalConversion : public mlir::OpRewritePattern<Op> {
852852
mlir::Type elementType) {
853853
if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
854854
const llvm::fltSemantics &sem = ty.getFloatSemantics();
855-
return builder.createRealConstant(
856-
loc, elementType,
857-
llvm::APFloat::getLargest(sem, /*Negative=*/isMax));
855+
llvm::APFloat limit = llvm::APFloat::getInf(sem, /*Negative=*/isMax);
856+
return builder.createRealConstant(loc, elementType, limit);
858857
}
859858
unsigned bits = elementType.getIntOrFloatBitWidth();
860859
int64_t limitInt =
@@ -895,19 +894,30 @@ class MinMaxlocElementalConversion : public mlir::OpRewritePattern<Op> {
895894
// Set flag that mask was true at some point
896895
mlir::Value flagSet = builder.createIntegerConstant(
897896
loc, mlir::cast<fir::ReferenceType>(flagRef.getType()).getEleTy(), 1);
898-
builder.create<fir::StoreOp>(loc, flagSet, flagRef);
897+
mlir::Value isFirst = builder.create<fir::LoadOp>(loc, flagRef);
899898
mlir::Value addr = hlfir::getElementAt(loc, builder, hlfir::Entity{array},
900899
oneBasedIndices);
901900
mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);
902901

903902
// Compare with the max reduction value
904903
mlir::Value cmp;
905904
if (elementType.isa<mlir::FloatType>()) {
905+
// For FP reductions we want the first smallest value to be used, that
906+
// is not NaN. A OGL/OLT condition will usually work for this unless all
907+
// the values are Nan or Inf. This follows the same logic as
908+
// NumericCompare for Minloc/Maxlox in extrema.cpp.
906909
cmp = builder.create<mlir::arith::CmpFOp>(
907910
loc,
908911
isMax ? mlir::arith::CmpFPredicate::OGT
909912
: mlir::arith::CmpFPredicate::OLT,
910913
elem, reduction);
914+
915+
mlir::Value cmpNan = builder.create<mlir::arith::CmpFOp>(
916+
loc, mlir::arith::CmpFPredicate::UNE, reduction, reduction);
917+
mlir::Value cmpNan2 = builder.create<mlir::arith::CmpFOp>(
918+
loc, mlir::arith::CmpFPredicate::OEQ, elem, elem);
919+
cmpNan = builder.create<mlir::arith::AndIOp>(loc, cmpNan, cmpNan2);
920+
cmp = builder.create<mlir::arith::OrIOp>(loc, cmp, cmpNan);
911921
} else if (elementType.isa<mlir::IntegerType>()) {
912922
cmp = builder.create<mlir::arith::CmpIOp>(
913923
loc,
@@ -918,11 +928,18 @@ class MinMaxlocElementalConversion : public mlir::OpRewritePattern<Op> {
918928
llvm_unreachable("unsupported type");
919929
}
920930

931+
// The condition used for the loop is isFirst || <the condition above>.
932+
isFirst = builder.create<fir::ConvertOp>(loc, cmp.getType(), isFirst);
933+
isFirst = builder.create<mlir::arith::XOrIOp>(
934+
loc, isFirst, builder.createIntegerConstant(loc, cmp.getType(), 1));
935+
cmp = builder.create<mlir::arith::OrIOp>(loc, cmp, isFirst);
936+
921937
// Set the new coordinate to the result
922938
fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, cmp,
923939
/*withElseRegion*/ true);
924940

925941
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
942+
builder.create<fir::StoreOp>(loc, flagSet, flagRef);
926943
mlir::Type resultElemTy =
927944
hlfir::getFortranElementType(resultArr.getType());
928945
mlir::Type returnRefTy = builder.getRefType(resultElemTy);

flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp

Lines changed: 20 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -649,42 +649,6 @@ void fir::genMinMaxlocReductionLoop(
649649
reductionVal = ifOp.getResult(0);
650650
}
651651
}
652-
653-
// Check for case where array was full of max values.
654-
// flag will be 0 if mask was never true, 1 if mask was true as some point,
655-
// this is needed to avoid catching cases where we didn't access any elements
656-
// e.g. mask=.FALSE.
657-
mlir::Value flagValue =
658-
builder.create<fir::LoadOp>(loc, resultElemType, flagRef);
659-
mlir::Value flagCmp = builder.create<mlir::arith::CmpIOp>(
660-
loc, mlir::arith::CmpIPredicate::eq, flagValue, flagSet);
661-
fir::IfOp ifMaskTrueOp =
662-
builder.create<fir::IfOp>(loc, flagCmp, /*withElseRegion=*/false);
663-
builder.setInsertionPointToStart(&ifMaskTrueOp.getThenRegion().front());
664-
665-
mlir::Value testInit = initVal(builder, loc, elementType);
666-
fir::IfOp ifMinSetOp;
667-
if (elementType.isa<mlir::FloatType>()) {
668-
mlir::Value cmp = builder.create<mlir::arith::CmpFOp>(
669-
loc, mlir::arith::CmpFPredicate::OEQ, testInit, reductionVal);
670-
ifMinSetOp = builder.create<fir::IfOp>(loc, cmp,
671-
/*withElseRegion*/ false);
672-
} else {
673-
mlir::Value cmp = builder.create<mlir::arith::CmpIOp>(
674-
loc, mlir::arith::CmpIPredicate::eq, testInit, reductionVal);
675-
ifMinSetOp = builder.create<fir::IfOp>(loc, cmp,
676-
/*withElseRegion*/ false);
677-
}
678-
builder.setInsertionPointToStart(&ifMinSetOp.getThenRegion().front());
679-
680-
// Load output array with 1s instead of 0s
681-
for (unsigned int i = 0; i < rank; ++i) {
682-
mlir::Value index = builder.createIntegerConstant(loc, idxTy, i);
683-
mlir::Value resultElemAddr =
684-
getAddrFn(builder, loc, resultElemType, resultArr, index);
685-
builder.create<fir::StoreOp>(loc, flagSet, resultElemAddr);
686-
}
687-
builder.setInsertionPointAfter(ifMaskTrueOp);
688652
}
689653

690654
static void genRuntimeMinMaxlocBody(fir::FirOpBuilder &builder,
@@ -697,8 +661,8 @@ static void genRuntimeMinMaxlocBody(fir::FirOpBuilder &builder,
697661
mlir::Type elementType) {
698662
if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
699663
const llvm::fltSemantics &sem = ty.getFloatSemantics();
700-
return builder.createRealConstant(
701-
loc, elementType, llvm::APFloat::getLargest(sem, /*Negative=*/isMax));
664+
llvm::APFloat limit = llvm::APFloat::getInf(sem, /*Negative=*/isMax);
665+
return builder.createRealConstant(loc, elementType, limit);
702666
}
703667
unsigned bits = elementType.getIntOrFloatBitWidth();
704668
int64_t initValue = (isMax ? llvm::APInt::getSignedMinValue(bits)
@@ -770,19 +734,30 @@ static void genRuntimeMinMaxlocBody(fir::FirOpBuilder &builder,
770734
// Set flag that mask was true at some point
771735
mlir::Value flagSet = builder.createIntegerConstant(
772736
loc, mlir::cast<fir::ReferenceType>(flagRef.getType()).getEleTy(), 1);
773-
builder.create<fir::StoreOp>(loc, flagSet, flagRef);
737+
mlir::Value isFirst = builder.create<fir::LoadOp>(loc, flagRef);
774738
mlir::Type eleRefTy = builder.getRefType(elementType);
775739
mlir::Value addr =
776740
builder.create<fir::CoordinateOp>(loc, eleRefTy, array, indices);
777741
mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);
778742

779743
mlir::Value cmp;
780744
if (elementType.isa<mlir::FloatType>()) {
745+
// For FP reductions we want the first smallest value to be used, that
746+
// is not NaN. A OGL/OLT condition will usually work for this unless all
747+
// the values are Nan or Inf. This follows the same logic as
748+
// NumericCompare for Minloc/Maxlox in extrema.cpp.
781749
cmp = builder.create<mlir::arith::CmpFOp>(
782750
loc,
783751
isMax ? mlir::arith::CmpFPredicate::OGT
784752
: mlir::arith::CmpFPredicate::OLT,
785753
elem, reduction);
754+
755+
mlir::Value cmpNan = builder.create<mlir::arith::CmpFOp>(
756+
loc, mlir::arith::CmpFPredicate::UNE, reduction, reduction);
757+
mlir::Value cmpNan2 = builder.create<mlir::arith::CmpFOp>(
758+
loc, mlir::arith::CmpFPredicate::OEQ, elem, elem);
759+
cmpNan = builder.create<mlir::arith::AndIOp>(loc, cmpNan, cmpNan2);
760+
cmp = builder.create<mlir::arith::OrIOp>(loc, cmp, cmpNan);
786761
} else if (elementType.isa<mlir::IntegerType>()) {
787762
cmp = builder.create<mlir::arith::CmpIOp>(
788763
loc,
@@ -793,10 +768,16 @@ static void genRuntimeMinMaxlocBody(fir::FirOpBuilder &builder,
793768
llvm_unreachable("unsupported type");
794769
}
795770

771+
// The condition used for the loop is isFirst || <the condition above>.
772+
isFirst = builder.create<fir::ConvertOp>(loc, cmp.getType(), isFirst);
773+
isFirst = builder.create<mlir::arith::XOrIOp>(
774+
loc, isFirst, builder.createIntegerConstant(loc, cmp.getType(), 1));
775+
cmp = builder.create<mlir::arith::OrIOp>(loc, cmp, isFirst);
796776
fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, cmp,
797777
/*withElseRegion*/ true);
798778

799779
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
780+
builder.create<fir::StoreOp>(loc, flagSet, flagRef);
800781
mlir::Type resultElemTy = hlfir::getFortranElementType(resultArr.getType());
801782
mlir::Type returnRefTy = builder.getRefType(resultElemTy);
802783
mlir::IndexType idxTy = builder.getIndexType();

flang/test/HLFIR/maxloc-elemental.fir

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ func.func @_QPtest(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "array"}
2323
return
2424
}
2525
// CHECK-LABEL: func.func @_QPtest(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "array"}, %arg1: !fir.ref<i32> {fir.bindc_name = "val"}, %arg2: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "m"}) {
26+
// CHECK-NEXT: %true = arith.constant true
2627
// CHECK-NEXT: %c-2147483648_i32 = arith.constant -2147483648 : i32
2728
// CHECK-NEXT: %c1_i32 = arith.constant 1 : i32
2829
// CHECK-NEXT: %c0 = arith.constant 0 : index
@@ -45,14 +46,18 @@ func.func @_QPtest(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "array"}
4546
// CHECK-NEXT: %[[V16:.*]] = fir.load %[[V15]] : !fir.ref<i32>
4647
// CHECK-NEXT: %[[V17:.*]] = arith.cmpi sge, %[[V16]], %[[V4]] : i32
4748
// CHECK-NEXT: %[[V18:.*]] = fir.if %[[V17]] -> (i32) {
48-
// CHECK-NEXT: fir.store %c1_i32 to %[[V0]] : !fir.ref<i32>
49+
// CHECK-NEXT: %[[ISFIRST:.*]] = fir.load %[[V0]] : !fir.ref<i32>
4950
// CHECK-NEXT: %[[DIMS:.*]]:3 = fir.box_dims %[[V1]]#0, %c0 : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
5051
// CHECK-NEXT: %[[SUB:.*]] = arith.subi %[[DIMS]]#0, %c1 : index
5152
// CHECK-NEXT: %[[ADD:.*]] = arith.addi %[[V14]], %[[SUB]] : index
5253
// CHECK-NEXT: %[[V19:.*]] = hlfir.designate %[[V1]]#0 (%[[ADD]]) : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
5354
// CHECK-NEXT: %[[V20:.*]] = fir.load %[[V19]] : !fir.ref<i32>
5455
// CHECK-NEXT: %[[V21:.*]] = arith.cmpi sgt, %[[V20]], %arg4 : i32
55-
// CHECK-NEXT: %[[V22:.*]] = fir.if %[[V21]] -> (i32) {
56+
// CHECK-NEXT: %[[ISFIRSTL:.*]] = fir.convert %[[ISFIRST]] : (i32) -> i1
57+
// CHECK-NEXT: %[[ISFIRSTNOT:.*]] = arith.xori %[[ISFIRSTL]], %true : i1
58+
// CHECK-NEXT: %[[ORCOND:.*]] = arith.ori %[[V21]], %[[ISFIRSTNOT]] : i1
59+
// CHECK-NEXT: %[[V22:.*]] = fir.if %[[ORCOND]] -> (i32) {
60+
// CHECK-NEXT: fir.store %c1_i32 to %[[V0]] : !fir.ref<i32>
5661
// CHECK-NEXT: %[[V23:.*]] = hlfir.designate %[[RES]] (%c1) : (!fir.ref<!fir.array<1xi32>>, index) -> !fir.ref<i32>
5762
// CHECK-NEXT: %[[V24:.*]] = fir.convert %[[V14]] : (index) -> i32
5863
// CHECK-NEXT: fir.store %[[V24]] to %[[V23]] : !fir.ref<i32>
@@ -66,15 +71,6 @@ func.func @_QPtest(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "array"}
6671
// CHECK-NEXT: }
6772
// CHECK-NEXT: fir.result %[[V18]] : i32
6873
// CHECK-NEXT: }
69-
// CHECK-NEXT: %[[V12:.*]] = fir.load %[[V0]] : !fir.ref<i32>
70-
// CHECK-NEXT: %[[V13:.*]] = arith.cmpi eq, %[[V12]], %c1_i32 : i32
71-
// CHECK-NEXT: fir.if %[[V13]] {
72-
// CHECK-NEXT: %[[V14:.*]] = arith.cmpi eq, %[[V11]], %c-2147483648_i32 : i32
73-
// CHECK-NEXT: fir.if %[[V14]] {
74-
// CHECK-NEXT: %[[V15:.*]] = hlfir.designate %[[RES]] (%c1) : (!fir.ref<!fir.array<1xi32>>, index) -> !fir.ref<i32>
75-
// CHECK-NEXT: fir.store %c1_i32 to %[[V15]] : !fir.ref<i32>
76-
// CHECK-NEXT: }
77-
// CHECK-NEXT: }
7874
// CHECK-NEXT: %[[BD:.*]]:3 = fir.box_dims %[[V2]]#0, %c0 : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
7975
// CHECK-NEXT: fir.do_loop %arg3 = %c1 to %[[BD]]#1 step %c1 unordered {
8076
// CHECK-NEXT: %[[V13:.*]] = hlfir.designate %[[RES]] (%arg3) : (!fir.ref<!fir.array<1xi32>>, index) -> !fir.ref<i32>
@@ -110,21 +106,29 @@ func.func @_QPtest_float(%arg0: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "a
110106
return
111107
}
112108
// CHECK-LABEL: _QPtest_float
113-
// CHECK: %cst = arith.constant -3.40282347E+38 : f32
109+
// CHECK: %cst = arith.constant 0xFF800000 : f32
114110
// CHECK: %[[V11:.*]] = fir.do_loop %arg3 = %c0 to %[[V10:.*]] step %c1 iter_args(%arg4 = %cst) -> (f32) {
115111
// CHECK-NEXT: %[[V14:.*]] = arith.addi %arg3, %c1 : index
116112
// CHECK-NEXT: %[[V15:.*]] = hlfir.designate %[[V1:.*]]#0 (%[[V14]]) : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
117113
// CHECK-NEXT: %[[V16:.*]] = fir.load %[[V15]] : !fir.ref<f32>
118114
// CHECK-NEXT: %[[V17:.*]] = arith.cmpf oge, %[[V16]], %[[V4:.*]] : f32
119115
// CHECK-NEXT: %[[V18:.*]] = fir.if %[[V17]] -> (f32) {
120-
// CHECK-NEXT: fir.store %c1_i32 to %[[V0:.*]] : !fir.ref<i32>
116+
// CHECK-NEXT: %[[ISFIRST:.*]] = fir.load %[[V0:.*]] : !fir.ref<i32>
121117
// CHECK-NEXT: %[[DIMS:.*]]:3 = fir.box_dims %2#0, %c0 : (!fir.box<!fir.array<?xf32>>, index) -> (index, index, index)
122118
// CHECK-NEXT: %[[SUB:.*]] = arith.subi %[[DIMS]]#0, %c1 : index
123119
// CHECK-NEXT: %[[ADD:.*]] = arith.addi %[[V14]], %[[SUB]] : index
124120
// CHECK-NEXT: %[[V19:.*]] = hlfir.designate %[[V1]]#0 (%[[ADD]]) : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
125121
// CHECK-NEXT: %[[V20:.*]] = fir.load %[[V19]] : !fir.ref<f32>
126-
// CHECK-NEXT: %[[V21:.*]] = arith.cmpf ogt, %[[V20]], %arg4 fastmath<contract> : f32
127-
// CHECK-NEXT: %[[V22:.*]] = fir.if %[[V21]] -> (f32) {
122+
// CHECK-NEXT: %[[NEW_MIN:.*]] = arith.cmpf ogt, %[[V20]], %arg4 fastmath<contract> : f32
123+
// CHECK-NEXT: %[[CONDRED:.*]] = arith.cmpf une, %arg4, %arg4 fastmath<contract> : f32
124+
// CHECK-NEXT: %[[CONDELEM:.*]] = arith.cmpf oeq, %[[V20]], %[[V20]] fastmath<contract> : f32
125+
// CHECK-NEXT: %[[ANDCOND:.*]] = arith.andi %[[CONDRED]], %[[CONDELEM]] : i1
126+
// CHECK-NEXT: %[[NEW_MIN2:.*]] = arith.ori %[[NEW_MIN]], %[[ANDCOND]] : i1
127+
// CHECK-NEXT: %[[ISFIRSTL:.*]] = fir.convert %[[ISFIRST]] : (i32) -> i1
128+
// CHECK-NEXT: %[[ISFIRSTNOT:.*]] = arith.xori %[[ISFIRSTL]], %true : i1
129+
// CHECK-NEXT: %[[ORCOND:.*]] = arith.ori %[[NEW_MIN2]], %[[ISFIRSTNOT]] : i1
130+
// CHECK-NEXT: %[[V22:.*]] = fir.if %[[ORCOND]] -> (f32) {
131+
// CHECK-NEXT: fir.store %c1_i32 to %[[V0]] : !fir.ref<i32>
128132
// CHECK-NEXT: %[[V23:.*]] = hlfir.designate %{{.}} (%c1) : (!fir.ref<!fir.array<1xi32>>, index) -> !fir.ref<i32>
129133
// CHECK-NEXT: %[[V24:.*]] = fir.convert %[[V14]] : (index) -> i32
130134
// CHECK-NEXT: fir.store %[[V24]] to %[[V23]] : !fir.ref<i32>

0 commit comments

Comments
 (0)