Skip to content

Commit b91a25e

Browse files
authored
[flang] add nsw to operations in subscripts (#110060)
This patch adds nsw to operations when lowering subscripts. See also the discussion in the following discourse post. https://discourse.llvm.org/t/rfc-add-nsw-flags-to-arithmetic-integer-operations-using-the-option-fno-wrapv/77584/9
1 parent 688bc95 commit b91a25e

File tree

9 files changed

+206
-10
lines changed

9 files changed

+206
-10
lines changed

flang/docs/Extensions.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -780,6 +780,12 @@ character j
780780
print *, [(j,j=1,10)]
781781
```
782782

783+
* The Fortran standard doesn't mention integer overflow explicitly. In many cases,
784+
however, integer overflow makes programs non-conforming.
785+
F18 follows other widely-used Fortran compilers. Specifically, f18 assumes
786+
integer overflow never occurs in address calculations and increment of
787+
do-variable unless the option `-fwrapv` is enabled.
788+
783789
## De Facto Standard Features
784790

785791
* `EXTENDS_TYPE_OF()` returns `.TRUE.` if both of its arguments have the

flang/include/flang/Lower/LoweringOptions.def

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,14 @@ ENUM_LOWERINGOPT(NoPPCNativeVecElemOrder, unsigned, 1, 0)
3434
/// On by default.
3535
ENUM_LOWERINGOPT(Underscoring, unsigned, 1, 1)
3636

37+
/// If true, assume the behavior of integer overflow is defined
38+
/// (i.e. wraps around as two's complement). On by default.
39+
/// TODO: make the default off
40+
ENUM_LOWERINGOPT(IntegerWrapAround, unsigned, 1, 1)
41+
3742
/// If true, add nsw flags to loop variable increments.
3843
/// Off by default.
44+
/// TODO: integrate this option with the above
3945
ENUM_LOWERINGOPT(NSWOnLoopVarInc, unsigned, 1, 0)
4046

4147
#undef LOWERINGOPT

flang/include/flang/Optimizer/Builder/FIRBuilder.h

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,13 +85,16 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
8585
// The listener self-reference has to be updated in case of copy-construction.
8686
FirOpBuilder(const FirOpBuilder &other)
8787
: OpBuilder(other), OpBuilder::Listener(), kindMap{other.kindMap},
88-
fastMathFlags{other.fastMathFlags}, symbolTable{other.symbolTable} {
88+
fastMathFlags{other.fastMathFlags},
89+
integerOverflowFlags{other.integerOverflowFlags},
90+
symbolTable{other.symbolTable} {
8991
setListener(this);
9092
}
9193

9294
FirOpBuilder(FirOpBuilder &&other)
9395
: OpBuilder(other), OpBuilder::Listener(),
9496
kindMap{std::move(other.kindMap)}, fastMathFlags{other.fastMathFlags},
97+
integerOverflowFlags{other.integerOverflowFlags},
9598
symbolTable{other.symbolTable} {
9699
setListener(this);
97100
}
@@ -521,6 +524,18 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
521524
return fmfString;
522525
}
523526

527+
/// Set default IntegerOverflowFlags value for all operations
528+
/// supporting mlir::arith::IntegerOverflowFlagsAttr that will be created
529+
/// by this builder.
530+
void setIntegerOverflowFlags(mlir::arith::IntegerOverflowFlags flags) {
531+
integerOverflowFlags = flags;
532+
}
533+
534+
/// Get current IntegerOverflowFlags value.
535+
mlir::arith::IntegerOverflowFlags getIntegerOverflowFlags() const {
536+
return integerOverflowFlags;
537+
}
538+
524539
/// Dump the current function. (debug)
525540
LLVM_DUMP_METHOD void dumpFunc();
526541

@@ -547,6 +562,10 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
547562
/// mlir::arith::FastMathAttr.
548563
mlir::arith::FastMathFlags fastMathFlags{};
549564

565+
/// IntegerOverflowFlags that need to be set for operations that support
566+
/// mlir::arith::IntegerOverflowFlagsAttr.
567+
mlir::arith::IntegerOverflowFlags integerOverflowFlags{};
568+
550569
/// fir::GlobalOp and func::FuncOp symbol table to speed-up
551570
/// lookups.
552571
mlir::SymbolTable *symbolTable = nullptr;

flang/lib/Lower/ConvertCall.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2570,9 +2570,26 @@ genIntrinsicRef(const Fortran::evaluate::SpecificIntrinsic *intrinsic,
25702570
hlfir::Entity{*var}, /*isPresent=*/std::nullopt});
25712571
continue;
25722572
}
2573+
// arguments of bitwise comparison functions may not have nsw flag
2574+
// even if -fno-wrapv is enabled
2575+
mlir::arith::IntegerOverflowFlags iofBackup{};
2576+
auto isBitwiseComparison = [](const std::string intrinsicName) -> bool {
2577+
if (intrinsicName == "bge" || intrinsicName == "bgt" ||
2578+
intrinsicName == "ble" || intrinsicName == "blt")
2579+
return true;
2580+
return false;
2581+
};
2582+
if (isBitwiseComparison(callContext.getProcedureName())) {
2583+
iofBackup = callContext.getBuilder().getIntegerOverflowFlags();
2584+
callContext.getBuilder().setIntegerOverflowFlags(
2585+
mlir::arith::IntegerOverflowFlags::none);
2586+
}
25732587
auto loweredActual = Fortran::lower::convertExprToHLFIR(
25742588
loc, callContext.converter, *expr, callContext.symMap,
25752589
callContext.stmtCtx);
2590+
if (isBitwiseComparison(callContext.getProcedureName()))
2591+
callContext.getBuilder().setIntegerOverflowFlags(iofBackup);
2592+
25762593
std::optional<mlir::Value> isPresent;
25772594
if (argLowering) {
25782595
fir::ArgLoweringRule argRules =

flang/lib/Lower/ConvertExprToHLFIR.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1584,9 +1584,14 @@ class HlfirBuilder {
15841584
auto rightVal = hlfir::loadTrivialScalar(l, b, rightElement);
15851585
return binaryOp.gen(l, b, op.derived(), leftVal, rightVal);
15861586
};
1587+
auto iofBackup = builder.getIntegerOverflowFlags();
1588+
// nsw is never added to operations on vector subscripts
1589+
// even if -fno-wrapv is enabled.
1590+
builder.setIntegerOverflowFlags(mlir::arith::IntegerOverflowFlags::none);
15871591
mlir::Value elemental = hlfir::genElementalOp(loc, builder, elementType,
15881592
shape, typeParams, genKernel,
15891593
/*isUnordered=*/true);
1594+
builder.setIntegerOverflowFlags(iofBackup);
15901595
fir::FirOpBuilder *bldr = &builder;
15911596
getStmtCtx().attachCleanup(
15921597
[=]() { bldr->create<hlfir::DestroyOp>(loc, elemental); });
@@ -1899,10 +1904,17 @@ class HlfirBuilder {
18991904
template <typename T>
19001905
hlfir::Entity
19011906
HlfirDesignatorBuilder::genSubscript(const Fortran::evaluate::Expr<T> &expr) {
1907+
fir::FirOpBuilder &builder = getBuilder();
1908+
mlir::arith::IntegerOverflowFlags iofBackup{};
1909+
if (!getConverter().getLoweringOptions().getIntegerWrapAround()) {
1910+
iofBackup = builder.getIntegerOverflowFlags();
1911+
builder.setIntegerOverflowFlags(mlir::arith::IntegerOverflowFlags::nsw);
1912+
}
19021913
auto loweredExpr =
19031914
HlfirBuilder(getLoc(), getConverter(), getSymMap(), getStmtCtx())
19041915
.gen(expr);
1905-
fir::FirOpBuilder &builder = getBuilder();
1916+
if (!getConverter().getLoweringOptions().getIntegerWrapAround())
1917+
builder.setIntegerOverflowFlags(iofBackup);
19061918
// Skip constant conversions that litters designators and makes generated
19071919
// IR harder to read: directly use index constants for constant subscripts.
19081920
mlir::Type idxTy = builder.getIndexType();

flang/lib/Optimizer/Builder/FIRBuilder.cpp

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -768,14 +768,23 @@ mlir::Value fir::FirOpBuilder::genAbsentOp(mlir::Location loc,
768768

769769
void fir::FirOpBuilder::setCommonAttributes(mlir::Operation *op) const {
770770
auto fmi = mlir::dyn_cast<mlir::arith::ArithFastMathInterface>(*op);
771-
if (!fmi)
772-
return;
773-
// TODO: use fmi.setFastMathFlagsAttr() after D137114 is merged.
774-
// For now set the attribute by the name.
775-
llvm::StringRef arithFMFAttrName = fmi.getFastMathAttrName();
776-
if (fastMathFlags != mlir::arith::FastMathFlags::none)
777-
op->setAttr(arithFMFAttrName, mlir::arith::FastMathFlagsAttr::get(
778-
op->getContext(), fastMathFlags));
771+
if (fmi) {
772+
// TODO: use fmi.setFastMathFlagsAttr() after D137114 is merged.
773+
// For now set the attribute by the name.
774+
llvm::StringRef arithFMFAttrName = fmi.getFastMathAttrName();
775+
if (fastMathFlags != mlir::arith::FastMathFlags::none)
776+
op->setAttr(arithFMFAttrName, mlir::arith::FastMathFlagsAttr::get(
777+
op->getContext(), fastMathFlags));
778+
}
779+
auto iofi =
780+
mlir::dyn_cast<mlir::arith::ArithIntegerOverflowFlagsInterface>(*op);
781+
if (iofi) {
782+
llvm::StringRef arithIOFAttrName = iofi.getIntegerOverflowAttrName();
783+
if (integerOverflowFlags != mlir::arith::IntegerOverflowFlags::none)
784+
op->setAttr(arithIOFAttrName,
785+
mlir::arith::IntegerOverflowFlagsAttr::get(
786+
op->getContext(), integerOverflowFlags));
787+
}
779788
}
780789

781790
void fir::FirOpBuilder::setFastMathFlags(

flang/test/Lower/nsw.f90

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
! RUN: bbc -emit-fir %s -o - | FileCheck %s
2+
! RUN: bbc -emit-fir -fwrapv %s -o - | FileCheck %s --check-prefix=NO-NSW
3+
4+
! NO-NSW-NOT: overflow<nsw>
5+
6+
subroutine subscript(a, i, j, k)
7+
integer :: a(:,:,:), i, j, k
8+
a(i+1, j-2, k*3) = 5
9+
end subroutine
10+
! CHECK-LABEL: func.func @_QPsubscript(
11+
! CHECK: %[[VAL_4:.*]] = arith.constant 3 : i32
12+
! CHECK: %[[VAL_5:.*]] = arith.constant 2 : i32
13+
! CHECK: %[[VAL_6:.*]] = arith.constant 1 : i32
14+
! CHECK: %[[VAL_9:.*]] = fir.declare %{{.*}}a"} : (!fir.box<!fir.array<?x?x?xi32>>, !fir.dscope) -> !fir.box<!fir.array<?x?x?xi32>>
15+
! CHECK: %[[VAL_10:.*]] = fir.rebox %[[VAL_9]] : (!fir.box<!fir.array<?x?x?xi32>>) -> !fir.box<!fir.array<?x?x?xi32>>
16+
! CHECK: %[[VAL_11:.*]] = fir.declare %{{.*}}i"} : (!fir.ref<i32>, !fir.dscope) -> !fir.ref<i32>
17+
! CHECK: %[[VAL_12:.*]] = fir.declare %{{.*}}j"} : (!fir.ref<i32>, !fir.dscope) -> !fir.ref<i32>
18+
! CHECK: %[[VAL_13:.*]] = fir.declare %{{.*}}k"} : (!fir.ref<i32>, !fir.dscope) -> !fir.ref<i32>
19+
! CHECK: %[[VAL_14:.*]] = fir.load %[[VAL_11]] : !fir.ref<i32>
20+
! CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_14]], %[[VAL_6]] overflow<nsw> : i32
21+
! CHECK: %[[VAL_16:.*]] = fir.convert %[[VAL_15]] : (i32) -> i64
22+
! CHECK: %[[VAL_17:.*]] = fir.load %[[VAL_12]] : !fir.ref<i32>
23+
! CHECK: %[[VAL_18:.*]] = arith.subi %[[VAL_17]], %[[VAL_5]] overflow<nsw> : i32
24+
! CHECK: %[[VAL_19:.*]] = fir.convert %[[VAL_18]] : (i32) -> i64
25+
! CHECK: %[[VAL_20:.*]] = fir.load %[[VAL_13]] : !fir.ref<i32>
26+
! CHECK: %[[VAL_21:.*]] = arith.muli %[[VAL_20]], %[[VAL_4]] overflow<nsw> : i32
27+
! CHECK: %[[VAL_22:.*]] = fir.convert %[[VAL_21]] : (i32) -> i64
28+
! CHECK: %[[VAL_23:.*]] = fir.array_coor %[[VAL_10]] %[[VAL_16]], %[[VAL_19]], %[[VAL_22]] :
29+
30+
! Test that nsw is never added to arith ops
31+
! on vector subscripts.
32+
subroutine vector_subscript_as_value(x, y, z)
33+
integer :: x(100)
34+
integer(8) :: y(20), z(20)
35+
call bar(x(y+z))
36+
end subroutine
37+
! CHECK-LABEL: func.func @_QPvector_subscript_as_value(
38+
! CHECK-NOT: overflow<nsw>
39+
! CHECK: return
40+
41+
subroutine vector_subscript_lhs(x, vector1, vector2)
42+
integer(8) :: vector1(10), vector2(10)
43+
real :: x(:)
44+
x(vector1+vector2) = 42.
45+
end subroutine
46+
! CHECK-LABEL: func.func @_QPvector_subscript_lhs(
47+
! CHECK-NOT: overflow<nsw>
48+
! CHECK: return
49+
50+
! Test that nsw is never added to arith ops
51+
! on arguments of bitwise comparison intrinsics.
52+
subroutine bitwise_comparison(a, b)
53+
integer :: a, b
54+
print *, bge(a+b, a-b)
55+
print *, bgt(a+b, a-b)
56+
print *, ble(a+b, a-b)
57+
print *, blt(a+b, a-b)
58+
end subroutine
59+
! CHECK-LABEL: func.func @_QPbitwise_comparison(
60+
! CHECK-NOT: overflow<nsw>
61+
! CHECK: return

flang/tools/bbc/bbc.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,12 @@ static llvm::cl::opt<std::string>
228228
llvm::cl::desc("Override host target triple"),
229229
llvm::cl::init(""));
230230

231+
static llvm::cl::opt<bool> integerWrapAround(
232+
"fwrapv",
233+
llvm::cl::desc("Treat signed integer overflow as two's complement"),
234+
llvm::cl::init(false));
235+
236+
// TODO: integrate this option with the above
231237
static llvm::cl::opt<bool>
232238
setNSW("integer-overflow",
233239
llvm::cl::desc("add nsw flag to internal operations"),
@@ -373,6 +379,7 @@ static llvm::LogicalResult convertFortranSourceToMLIR(
373379
Fortran::lower::LoweringOptions loweringOptions{};
374380
loweringOptions.setNoPPCNativeVecElemOrder(enableNoPPCNativeVecElemOrder);
375381
loweringOptions.setLowerToHighLevelFIR(useHLFIR || emitHLFIR);
382+
loweringOptions.setIntegerWrapAround(integerWrapAround);
376383
loweringOptions.setNSWOnLoopVarInc(setNSW);
377384
std::vector<Fortran::lower::EnvironmentDefault> envDefaults = {};
378385
constexpr const char *tuneCPU = "";

flang/unittests/Optimizer/Builder/FIRBuilderTest.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -585,3 +585,62 @@ TEST_F(FIRBuilderTest, genArithFastMath) {
585585
auto op4_fmf = op4_fmi.getFastMathFlagsAttr().getValue();
586586
EXPECT_EQ(op4_fmf, FMF1);
587587
}
588+
589+
TEST_F(FIRBuilderTest, genArithIntegerOverflow) {
590+
auto builder = getBuilder();
591+
auto ctx = builder.getContext();
592+
auto loc = builder.getUnknownLoc();
593+
594+
auto intTy = IntegerType::get(ctx, 32);
595+
auto arg = builder.create<fir::UndefOp>(loc, intTy);
596+
597+
// Test that IntegerOverflowFlags is 'none' by default.
598+
mlir::Operation *op1 = builder.create<mlir::arith::AddIOp>(loc, arg, arg);
599+
auto op1_iofi =
600+
mlir::dyn_cast_or_null<mlir::arith::ArithIntegerOverflowFlagsInterface>(
601+
op1);
602+
EXPECT_TRUE(op1_iofi);
603+
auto op1_ioff = op1_iofi.getOverflowAttr().getValue();
604+
EXPECT_EQ(op1_ioff, arith::IntegerOverflowFlags::none);
605+
606+
// Test that the builder is copied properly.
607+
fir::FirOpBuilder builder_copy(builder);
608+
609+
arith::IntegerOverflowFlags nsw = arith::IntegerOverflowFlags::nsw;
610+
builder.setIntegerOverflowFlags(nsw);
611+
arith::IntegerOverflowFlags nuw = arith::IntegerOverflowFlags::nuw;
612+
builder_copy.setIntegerOverflowFlags(nuw);
613+
614+
// Modifying IntegerOverflowFlags for the copy must not affect the original
615+
// builder.
616+
mlir::Operation *op2 = builder.create<mlir::arith::AddIOp>(loc, arg, arg);
617+
auto op2_iofi =
618+
mlir::dyn_cast_or_null<mlir::arith::ArithIntegerOverflowFlagsInterface>(
619+
op2);
620+
EXPECT_TRUE(op2_iofi);
621+
auto op2_ioff = op2_iofi.getOverflowAttr().getValue();
622+
EXPECT_EQ(op2_ioff, nsw);
623+
624+
// Modifying IntegerOverflowFlags for the original builder must not affect the
625+
// copy.
626+
mlir::Operation *op3 =
627+
builder_copy.create<mlir::arith::AddIOp>(loc, arg, arg);
628+
auto op3_iofi =
629+
mlir::dyn_cast_or_null<mlir::arith::ArithIntegerOverflowFlagsInterface>(
630+
op3);
631+
EXPECT_TRUE(op3_iofi);
632+
auto op3_ioff = op3_iofi.getOverflowAttr().getValue();
633+
EXPECT_EQ(op3_ioff, nuw);
634+
635+
// Test that the builder copy inherits IntegerOverflowFlags from the original.
636+
fir::FirOpBuilder builder_copy2(builder);
637+
638+
mlir::Operation *op4 =
639+
builder_copy2.create<mlir::arith::AddIOp>(loc, arg, arg);
640+
auto op4_iofi =
641+
mlir::dyn_cast_or_null<mlir::arith::ArithIntegerOverflowFlagsInterface>(
642+
op4);
643+
EXPECT_TRUE(op4_iofi);
644+
auto op4_ioff = op4_iofi.getOverflowAttr().getValue();
645+
EXPECT_EQ(op4_ioff, nsw);
646+
}

0 commit comments

Comments
 (0)