Skip to content

[Flang] Attempt to fix Nan handling in Minloc/Maxloc intrinsic simplification #82313

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -852,9 +852,8 @@ class MinMaxlocElementalConversion : public mlir::OpRewritePattern<Op> {
mlir::Type elementType) {
if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
const llvm::fltSemantics &sem = ty.getFloatSemantics();
return builder.createRealConstant(
loc, elementType,
llvm::APFloat::getLargest(sem, /*Negative=*/isMax));
llvm::APFloat limit = llvm::APFloat::getInf(sem, /*Negative=*/isMax);
return builder.createRealConstant(loc, elementType, limit);
}
unsigned bits = elementType.getIntOrFloatBitWidth();
int64_t limitInt =
Expand Down Expand Up @@ -895,19 +894,30 @@ class MinMaxlocElementalConversion : public mlir::OpRewritePattern<Op> {
// Set flag that mask was true at some point
mlir::Value flagSet = builder.createIntegerConstant(
loc, mlir::cast<fir::ReferenceType>(flagRef.getType()).getEleTy(), 1);
builder.create<fir::StoreOp>(loc, flagSet, flagRef);
mlir::Value isFirst = builder.create<fir::LoadOp>(loc, flagRef);
mlir::Value addr = hlfir::getElementAt(loc, builder, hlfir::Entity{array},
oneBasedIndices);
mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);

// Compare with the max reduction value
mlir::Value cmp;
if (elementType.isa<mlir::FloatType>()) {
// For FP reductions we want the first smallest value to be used, that
// is not NaN. A OGL/OLT condition will usually work for this unless all
// the values are Nan or Inf. This follows the same logic as
// NumericCompare for Minloc/Maxlox in extrema.cpp.
cmp = builder.create<mlir::arith::CmpFOp>(
loc,
isMax ? mlir::arith::CmpFPredicate::OGT
: mlir::arith::CmpFPredicate::OLT,
elem, reduction);

mlir::Value cmpNan = builder.create<mlir::arith::CmpFOp>(
loc, mlir::arith::CmpFPredicate::UNE, reduction, reduction);
mlir::Value cmpNan2 = builder.create<mlir::arith::CmpFOp>(
loc, mlir::arith::CmpFPredicate::OEQ, elem, elem);
cmpNan = builder.create<mlir::arith::AndIOp>(loc, cmpNan, cmpNan2);
cmp = builder.create<mlir::arith::OrIOp>(loc, cmp, cmpNan);
} else if (elementType.isa<mlir::IntegerType>()) {
cmp = builder.create<mlir::arith::CmpIOp>(
loc,
Expand All @@ -918,11 +928,18 @@ class MinMaxlocElementalConversion : public mlir::OpRewritePattern<Op> {
llvm_unreachable("unsupported type");
}

// The condition used for the loop is isFirst || <the condition above>.
isFirst = builder.create<fir::ConvertOp>(loc, cmp.getType(), isFirst);
isFirst = builder.create<mlir::arith::XOrIOp>(
loc, isFirst, builder.createIntegerConstant(loc, cmp.getType(), 1));
cmp = builder.create<mlir::arith::OrIOp>(loc, cmp, isFirst);

// Set the new coordinate to the result
fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, cmp,
/*withElseRegion*/ true);

builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
builder.create<fir::StoreOp>(loc, flagSet, flagRef);
mlir::Type resultElemTy =
hlfir::getFortranElementType(resultArr.getType());
mlir::Type returnRefTy = builder.getRefType(resultElemTy);
Expand Down
59 changes: 20 additions & 39 deletions flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -649,42 +649,6 @@ void fir::genMinMaxlocReductionLoop(
reductionVal = ifOp.getResult(0);
}
}

// Check for case where array was full of max values.
// flag will be 0 if mask was never true, 1 if mask was true as some point,
// this is needed to avoid catching cases where we didn't access any elements
// e.g. mask=.FALSE.
mlir::Value flagValue =
builder.create<fir::LoadOp>(loc, resultElemType, flagRef);
mlir::Value flagCmp = builder.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::eq, flagValue, flagSet);
fir::IfOp ifMaskTrueOp =
builder.create<fir::IfOp>(loc, flagCmp, /*withElseRegion=*/false);
builder.setInsertionPointToStart(&ifMaskTrueOp.getThenRegion().front());

mlir::Value testInit = initVal(builder, loc, elementType);
fir::IfOp ifMinSetOp;
if (elementType.isa<mlir::FloatType>()) {
mlir::Value cmp = builder.create<mlir::arith::CmpFOp>(
loc, mlir::arith::CmpFPredicate::OEQ, testInit, reductionVal);
ifMinSetOp = builder.create<fir::IfOp>(loc, cmp,
/*withElseRegion*/ false);
} else {
mlir::Value cmp = builder.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::eq, testInit, reductionVal);
ifMinSetOp = builder.create<fir::IfOp>(loc, cmp,
/*withElseRegion*/ false);
}
builder.setInsertionPointToStart(&ifMinSetOp.getThenRegion().front());

// Load output array with 1s instead of 0s
for (unsigned int i = 0; i < rank; ++i) {
mlir::Value index = builder.createIntegerConstant(loc, idxTy, i);
mlir::Value resultElemAddr =
getAddrFn(builder, loc, resultElemType, resultArr, index);
builder.create<fir::StoreOp>(loc, flagSet, resultElemAddr);
}
builder.setInsertionPointAfter(ifMaskTrueOp);
}

static void genRuntimeMinMaxlocBody(fir::FirOpBuilder &builder,
Expand All @@ -697,8 +661,8 @@ static void genRuntimeMinMaxlocBody(fir::FirOpBuilder &builder,
mlir::Type elementType) {
if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
const llvm::fltSemantics &sem = ty.getFloatSemantics();
return builder.createRealConstant(
loc, elementType, llvm::APFloat::getLargest(sem, /*Negative=*/isMax));
llvm::APFloat limit = llvm::APFloat::getInf(sem, /*Negative=*/isMax);
return builder.createRealConstant(loc, elementType, limit);
}
unsigned bits = elementType.getIntOrFloatBitWidth();
int64_t initValue = (isMax ? llvm::APInt::getSignedMinValue(bits)
Expand Down Expand Up @@ -770,19 +734,30 @@ static void genRuntimeMinMaxlocBody(fir::FirOpBuilder &builder,
// Set flag that mask was true at some point
mlir::Value flagSet = builder.createIntegerConstant(
loc, mlir::cast<fir::ReferenceType>(flagRef.getType()).getEleTy(), 1);
builder.create<fir::StoreOp>(loc, flagSet, flagRef);
mlir::Value isFirst = builder.create<fir::LoadOp>(loc, flagRef);
mlir::Type eleRefTy = builder.getRefType(elementType);
mlir::Value addr =
builder.create<fir::CoordinateOp>(loc, eleRefTy, array, indices);
mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);

mlir::Value cmp;
if (elementType.isa<mlir::FloatType>()) {
// For FP reductions we want the first smallest value to be used, that
// is not NaN. A OGL/OLT condition will usually work for this unless all
// the values are Nan or Inf. This follows the same logic as
// NumericCompare for Minloc/Maxlox in extrema.cpp.
cmp = builder.create<mlir::arith::CmpFOp>(
loc,
isMax ? mlir::arith::CmpFPredicate::OGT
: mlir::arith::CmpFPredicate::OLT,
elem, reduction);

mlir::Value cmpNan = builder.create<mlir::arith::CmpFOp>(
loc, mlir::arith::CmpFPredicate::UNE, reduction, reduction);
mlir::Value cmpNan2 = builder.create<mlir::arith::CmpFOp>(
loc, mlir::arith::CmpFPredicate::OEQ, elem, elem);
cmpNan = builder.create<mlir::arith::AndIOp>(loc, cmpNan, cmpNan2);
cmp = builder.create<mlir::arith::OrIOp>(loc, cmp, cmpNan);
} else if (elementType.isa<mlir::IntegerType>()) {
cmp = builder.create<mlir::arith::CmpIOp>(
loc,
Expand All @@ -793,10 +768,16 @@ static void genRuntimeMinMaxlocBody(fir::FirOpBuilder &builder,
llvm_unreachable("unsupported type");
}

// The condition used for the loop is isFirst || <the condition above>.
isFirst = builder.create<fir::ConvertOp>(loc, cmp.getType(), isFirst);
isFirst = builder.create<mlir::arith::XOrIOp>(
loc, isFirst, builder.createIntegerConstant(loc, cmp.getType(), 1));
cmp = builder.create<mlir::arith::OrIOp>(loc, cmp, isFirst);
fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, cmp,
/*withElseRegion*/ true);

builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
builder.create<fir::StoreOp>(loc, flagSet, flagRef);
mlir::Type resultElemTy = hlfir::getFortranElementType(resultArr.getType());
mlir::Type returnRefTy = builder.getRefType(resultElemTy);
mlir::IndexType idxTy = builder.getIndexType();
Expand Down
34 changes: 19 additions & 15 deletions flang/test/HLFIR/maxloc-elemental.fir
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ func.func @_QPtest(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "array"}
return
}
// CHECK-LABEL: func.func @_QPtest(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "array"}, %arg1: !fir.ref<i32> {fir.bindc_name = "val"}, %arg2: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "m"}) {
// CHECK-NEXT: %true = arith.constant true
// CHECK-NEXT: %c-2147483648_i32 = arith.constant -2147483648 : i32
// CHECK-NEXT: %c1_i32 = arith.constant 1 : i32
// CHECK-NEXT: %c0 = arith.constant 0 : index
Expand All @@ -45,14 +46,18 @@ func.func @_QPtest(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "array"}
// CHECK-NEXT: %[[V16:.*]] = fir.load %[[V15]] : !fir.ref<i32>
// CHECK-NEXT: %[[V17:.*]] = arith.cmpi sge, %[[V16]], %[[V4]] : i32
// CHECK-NEXT: %[[V18:.*]] = fir.if %[[V17]] -> (i32) {
// CHECK-NEXT: fir.store %c1_i32 to %[[V0]] : !fir.ref<i32>
// CHECK-NEXT: %[[ISFIRST:.*]] = fir.load %[[V0]] : !fir.ref<i32>
// CHECK-NEXT: %[[DIMS:.*]]:3 = fir.box_dims %[[V1]]#0, %c0 : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
// CHECK-NEXT: %[[SUB:.*]] = arith.subi %[[DIMS]]#0, %c1 : index
// CHECK-NEXT: %[[ADD:.*]] = arith.addi %[[V14]], %[[SUB]] : index
// CHECK-NEXT: %[[V19:.*]] = hlfir.designate %[[V1]]#0 (%[[ADD]]) : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
// CHECK-NEXT: %[[V20:.*]] = fir.load %[[V19]] : !fir.ref<i32>
// CHECK-NEXT: %[[V21:.*]] = arith.cmpi sgt, %[[V20]], %arg4 : i32
// CHECK-NEXT: %[[V22:.*]] = fir.if %[[V21]] -> (i32) {
// CHECK-NEXT: %[[ISFIRSTL:.*]] = fir.convert %[[ISFIRST]] : (i32) -> i1
// CHECK-NEXT: %[[ISFIRSTNOT:.*]] = arith.xori %[[ISFIRSTL]], %true : i1
// CHECK-NEXT: %[[ORCOND:.*]] = arith.ori %[[V21]], %[[ISFIRSTNOT]] : i1
// CHECK-NEXT: %[[V22:.*]] = fir.if %[[ORCOND]] -> (i32) {
// CHECK-NEXT: fir.store %c1_i32 to %[[V0]] : !fir.ref<i32>
// CHECK-NEXT: %[[V23:.*]] = hlfir.designate %[[RES]] (%c1) : (!fir.ref<!fir.array<1xi32>>, index) -> !fir.ref<i32>
// CHECK-NEXT: %[[V24:.*]] = fir.convert %[[V14]] : (index) -> i32
// CHECK-NEXT: fir.store %[[V24]] to %[[V23]] : !fir.ref<i32>
Expand All @@ -66,15 +71,6 @@ func.func @_QPtest(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "array"}
// CHECK-NEXT: }
// CHECK-NEXT: fir.result %[[V18]] : i32
// CHECK-NEXT: }
// CHECK-NEXT: %[[V12:.*]] = fir.load %[[V0]] : !fir.ref<i32>
// CHECK-NEXT: %[[V13:.*]] = arith.cmpi eq, %[[V12]], %c1_i32 : i32
// CHECK-NEXT: fir.if %[[V13]] {
// CHECK-NEXT: %[[V14:.*]] = arith.cmpi eq, %[[V11]], %c-2147483648_i32 : i32
// CHECK-NEXT: fir.if %[[V14]] {
// CHECK-NEXT: %[[V15:.*]] = hlfir.designate %[[RES]] (%c1) : (!fir.ref<!fir.array<1xi32>>, index) -> !fir.ref<i32>
// CHECK-NEXT: fir.store %c1_i32 to %[[V15]] : !fir.ref<i32>
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: %[[BD:.*]]:3 = fir.box_dims %[[V2]]#0, %c0 : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
// CHECK-NEXT: fir.do_loop %arg3 = %c1 to %[[BD]]#1 step %c1 unordered {
// CHECK-NEXT: %[[V13:.*]] = hlfir.designate %[[RES]] (%arg3) : (!fir.ref<!fir.array<1xi32>>, index) -> !fir.ref<i32>
Expand Down Expand Up @@ -110,21 +106,29 @@ func.func @_QPtest_float(%arg0: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "a
return
}
// CHECK-LABEL: _QPtest_float
// CHECK: %cst = arith.constant -3.40282347E+38 : f32
// CHECK: %cst = arith.constant 0xFF800000 : f32
// CHECK: %[[V11:.*]] = fir.do_loop %arg3 = %c0 to %[[V10:.*]] step %c1 iter_args(%arg4 = %cst) -> (f32) {
// CHECK-NEXT: %[[V14:.*]] = arith.addi %arg3, %c1 : index
// CHECK-NEXT: %[[V15:.*]] = hlfir.designate %[[V1:.*]]#0 (%[[V14]]) : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
// CHECK-NEXT: %[[V16:.*]] = fir.load %[[V15]] : !fir.ref<f32>
// CHECK-NEXT: %[[V17:.*]] = arith.cmpf oge, %[[V16]], %[[V4:.*]] : f32
// CHECK-NEXT: %[[V18:.*]] = fir.if %[[V17]] -> (f32) {
// CHECK-NEXT: fir.store %c1_i32 to %[[V0:.*]] : !fir.ref<i32>
// CHECK-NEXT: %[[ISFIRST:.*]] = fir.load %[[V0:.*]] : !fir.ref<i32>
// CHECK-NEXT: %[[DIMS:.*]]:3 = fir.box_dims %2#0, %c0 : (!fir.box<!fir.array<?xf32>>, index) -> (index, index, index)
// CHECK-NEXT: %[[SUB:.*]] = arith.subi %[[DIMS]]#0, %c1 : index
// CHECK-NEXT: %[[ADD:.*]] = arith.addi %[[V14]], %[[SUB]] : index
// CHECK-NEXT: %[[V19:.*]] = hlfir.designate %[[V1]]#0 (%[[ADD]]) : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
// CHECK-NEXT: %[[V20:.*]] = fir.load %[[V19]] : !fir.ref<f32>
// CHECK-NEXT: %[[V21:.*]] = arith.cmpf ogt, %[[V20]], %arg4 fastmath<contract> : f32
// CHECK-NEXT: %[[V22:.*]] = fir.if %[[V21]] -> (f32) {
// CHECK-NEXT: %[[NEW_MIN:.*]] = arith.cmpf ogt, %[[V20]], %arg4 fastmath<contract> : f32
// CHECK-NEXT: %[[CONDRED:.*]] = arith.cmpf une, %arg4, %arg4 fastmath<contract> : f32
// CHECK-NEXT: %[[CONDELEM:.*]] = arith.cmpf oeq, %[[V20]], %[[V20]] fastmath<contract> : f32
// CHECK-NEXT: %[[ANDCOND:.*]] = arith.andi %[[CONDRED]], %[[CONDELEM]] : i1
// CHECK-NEXT: %[[NEW_MIN2:.*]] = arith.ori %[[NEW_MIN]], %[[ANDCOND]] : i1
// CHECK-NEXT: %[[ISFIRSTL:.*]] = fir.convert %[[ISFIRST]] : (i32) -> i1
// CHECK-NEXT: %[[ISFIRSTNOT:.*]] = arith.xori %[[ISFIRSTL]], %true : i1
// CHECK-NEXT: %[[ORCOND:.*]] = arith.ori %[[NEW_MIN2]], %[[ISFIRSTNOT]] : i1
// CHECK-NEXT: %[[V22:.*]] = fir.if %[[ORCOND]] -> (f32) {
// CHECK-NEXT: fir.store %c1_i32 to %[[V0]] : !fir.ref<i32>
// CHECK-NEXT: %[[V23:.*]] = hlfir.designate %{{.}} (%c1) : (!fir.ref<!fir.array<1xi32>>, index) -> !fir.ref<i32>
// CHECK-NEXT: %[[V24:.*]] = fir.convert %[[V14]] : (index) -> i32
// CHECK-NEXT: fir.store %[[V24]] to %[[V23]] : !fir.ref<i32>
Expand Down
Loading