Skip to content

Commit bdacd56

Browse files
authored
[flang][CodeGen] add nsw to address calculations (#74709)
`nsw` is a flag for LLVM arithmetic operations meaning "no signed wrap". If this keyword is present, the result of the operation is a poison value if overflow occurs. Adding this keyword permits LLVM to re-order integer arithmetic more aggressively. In https://discourse.llvm.org/t/rfc-changes-to-fircg-xarray-coor-codegen-to-allow-better-hoisting/75257/16 @vzakhari observed that adding nsw is useful to enable hoisting of address calculations after some loops (or is at least a step in that direction). Classic Flang also adds nsw to address calculations.
1 parent 633fe60 commit bdacd56

File tree

6 files changed

+115
-104
lines changed

6 files changed

+115
-104
lines changed

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2387,6 +2387,9 @@ struct XArrayCoorOpConversion
23872387
const bool baseIsBoxed = coor.getMemref().getType().isa<fir::BaseBoxType>();
23882388
TypePair baseBoxTyPair =
23892389
baseIsBoxed ? getBoxTypePair(coor.getMemref().getType()) : TypePair{};
2390+
mlir::LLVM::IntegerOverflowFlagsAttr nsw =
2391+
mlir::LLVM::IntegerOverflowFlagsAttr::get(
2392+
rewriter.getContext(), mlir::LLVM::IntegerOverflowFlags::nsw);
23902393

23912394
// For each dimension of the array, generate the offset calculation.
23922395
for (unsigned i = 0; i < rank; ++i, ++indexOffset, ++shapeOffset,
@@ -2407,32 +2410,37 @@ struct XArrayCoorOpConversion
24072410
if (normalSlice)
24082411
step = integerCast(loc, rewriter, idxTy, operands[sliceOffset + 2]);
24092412
}
2410-
auto idx = rewriter.create<mlir::LLVM::SubOp>(loc, idxTy, index, lb);
2413+
auto idx = rewriter.create<mlir::LLVM::SubOp>(loc, idxTy, index, lb, nsw);
24112414
mlir::Value diff =
2412-
rewriter.create<mlir::LLVM::MulOp>(loc, idxTy, idx, step);
2415+
rewriter.create<mlir::LLVM::MulOp>(loc, idxTy, idx, step, nsw);
24132416
if (normalSlice) {
24142417
mlir::Value sliceLb =
24152418
integerCast(loc, rewriter, idxTy, operands[sliceOffset]);
2416-
auto adj = rewriter.create<mlir::LLVM::SubOp>(loc, idxTy, sliceLb, lb);
2417-
diff = rewriter.create<mlir::LLVM::AddOp>(loc, idxTy, diff, adj);
2419+
auto adj =
2420+
rewriter.create<mlir::LLVM::SubOp>(loc, idxTy, sliceLb, lb, nsw);
2421+
diff = rewriter.create<mlir::LLVM::AddOp>(loc, idxTy, diff, adj, nsw);
24182422
}
24192423
// Update the offset given the stride and the zero based index `diff`
24202424
// that was just computed.
24212425
if (baseIsBoxed) {
24222426
// Use stride in bytes from the descriptor.
24232427
mlir::Value stride =
24242428
getStrideFromBox(loc, baseBoxTyPair, operands[0], i, rewriter);
2425-
auto sc = rewriter.create<mlir::LLVM::MulOp>(loc, idxTy, diff, stride);
2426-
offset = rewriter.create<mlir::LLVM::AddOp>(loc, idxTy, sc, offset);
2429+
auto sc =
2430+
rewriter.create<mlir::LLVM::MulOp>(loc, idxTy, diff, stride, nsw);
2431+
offset =
2432+
rewriter.create<mlir::LLVM::AddOp>(loc, idxTy, sc, offset, nsw);
24272433
} else {
24282434
// Use stride computed at last iteration.
2429-
auto sc = rewriter.create<mlir::LLVM::MulOp>(loc, idxTy, diff, prevExt);
2430-
offset = rewriter.create<mlir::LLVM::AddOp>(loc, idxTy, sc, offset);
2435+
auto sc =
2436+
rewriter.create<mlir::LLVM::MulOp>(loc, idxTy, diff, prevExt, nsw);
2437+
offset =
2438+
rewriter.create<mlir::LLVM::AddOp>(loc, idxTy, sc, offset, nsw);
24312439
// Compute next stride assuming contiguity of the base array
24322440
// (in element number).
24332441
auto nextExt = integerCast(loc, rewriter, idxTy, operands[shapeOffset]);
2434-
prevExt =
2435-
rewriter.create<mlir::LLVM::MulOp>(loc, idxTy, prevExt, nextExt);
2442+
prevExt = rewriter.create<mlir::LLVM::MulOp>(loc, idxTy, prevExt,
2443+
nextExt, nsw);
24362444
}
24372445
}
24382446

@@ -2491,8 +2499,8 @@ struct XArrayCoorOpConversion
24912499
assert(coor.getLenParams().size() == 1);
24922500
auto length = integerCast(loc, rewriter, idxTy,
24932501
operands[coor.lenParamsOffset()]);
2494-
offset =
2495-
rewriter.create<mlir::LLVM::MulOp>(loc, idxTy, offset, length);
2502+
offset = rewriter.create<mlir::LLVM::MulOp>(loc, idxTy, offset,
2503+
length, nsw);
24962504
} else {
24972505
TODO(loc, "compute size of derived type with type parameters");
24982506
}
@@ -2665,6 +2673,9 @@ struct CoordinateOpConversion
26652673
auto cpnTy = fir::dyn_cast_ptrOrBoxEleTy(boxObjTy);
26662674
mlir::Type llvmPtrTy = ::getLlvmPtrType(coor.getContext());
26672675
mlir::Type byteTy = ::getI8Type(coor.getContext());
2676+
mlir::LLVM::IntegerOverflowFlagsAttr nsw =
2677+
mlir::LLVM::IntegerOverflowFlagsAttr::get(
2678+
rewriter.getContext(), mlir::LLVM::IntegerOverflowFlags::nsw);
26682679

26692680
for (unsigned i = 1, last = operands.size(); i < last; ++i) {
26702681
if (auto arrTy = cpnTy.dyn_cast<fir::SequenceType>()) {
@@ -2680,9 +2691,9 @@ struct CoordinateOpConversion
26802691
index < lastIndex; ++index) {
26812692
mlir::Value stride = getStrideFromBox(loc, boxTyPair, operands[0],
26822693
index - i, rewriter);
2683-
auto sc = rewriter.create<mlir::LLVM::MulOp>(loc, idxTy,
2684-
operands[index], stride);
2685-
off = rewriter.create<mlir::LLVM::AddOp>(loc, idxTy, sc, off);
2694+
auto sc = rewriter.create<mlir::LLVM::MulOp>(
2695+
loc, idxTy, operands[index], stride, nsw);
2696+
off = rewriter.create<mlir::LLVM::AddOp>(loc, idxTy, sc, off, nsw);
26862697
}
26872698
resultAddr = rewriter.create<mlir::LLVM::GEPOp>(
26882699
loc, llvmPtrTy, byteTy, resultAddr,

flang/test/Fir/array-coor.fir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@ func.func @array_coor_box_value(%29 : !fir.box<!fir.array<2xf64>>,
99
}
1010

1111
// CHECK-LABEL: define double @array_coor_box_value
12-
// CHECK: %[[t3:.*]] = sub i64 %{{.*}}, 1
13-
// CHECK: %[[t4:.*]] = mul i64 %[[t3]], 1
12+
// CHECK: %[[t3:.*]] = sub nsw i64 %{{.*}}, 1
13+
// CHECK: %[[t4:.*]] = mul nsw i64 %[[t3]], 1
1414
// CHECK: %[[t5:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, ptr %{{.*}}, i32 0, i32 7, i32 0, i32 2
1515
// CHECK: %[[t6:.*]] = load i64, ptr %[[t5]]
16-
// CHECK: %[[t7:.*]] = mul i64 %[[t4]], %[[t6]]
17-
// CHECK: %[[t8:.*]] = add i64 %[[t7]], 0
16+
// CHECK: %[[t7:.*]] = mul nsw i64 %[[t4]], %[[t6]]
17+
// CHECK: %[[t8:.*]] = add nsw i64 %[[t7]], 0
1818
// CHECK: %[[t9:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, ptr %{{.*}}, i32 0, i32 0
1919
// CHECK: %[[t10:.*]] = load ptr, ptr %[[t9]]
2020
// CHECK: %[[t11:.*]] = getelementptr i8, ptr %[[t10]], i64 %[[t8]]
@@ -36,8 +36,8 @@ func.func private @take_int(%arg0: !fir.ref<i32>) -> ()
3636
// CHECK-SAME: ptr %[[VAL_0:.*]])
3737
// CHECK: %[[VAL_1:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]], ptr, [1 x i64] }, ptr %[[VAL_0]], i32 0, i32 7, i32 0, i32 2
3838
// CHECK: %[[VAL_2:.*]] = load i64, ptr %[[VAL_1]]
39-
// CHECK: %[[VAL_3:.*]] = mul i64 1, %[[VAL_2]]
40-
// CHECK: %[[VAL_4:.*]] = add i64 %[[VAL_3]], 0
39+
// CHECK: %[[VAL_3:.*]] = mul nsw i64 1, %[[VAL_2]]
40+
// CHECK: %[[VAL_4:.*]] = add nsw i64 %[[VAL_3]], 0
4141
// CHECK: %[[VAL_5:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]], ptr, [1 x i64] }, ptr %[[VAL_0]], i32 0, i32 0
4242
// CHECK: %[[VAL_6:.*]] = load ptr, ptr %[[VAL_5]]
4343
// CHECK: %[[VAL_7:.*]] = getelementptr i8, ptr %[[VAL_6]], i64 %[[VAL_4]]

flang/test/Fir/arrexp.fir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,8 @@ func.func @f5(%arg0: !fir.box<!fir.array<?xf32>>, %arg1: !fir.box<!fir.array<?xf
114114
%4 = fir.do_loop %arg3 = %c0 to %1 step %c1 iter_args(%arg4 = %2) -> (!fir.array<?xf32>) {
115115
// CHECK: %[[B_STRIDE_GEP:.*]] = getelementptr {{.*}}, ptr %[[B]], i32 0, i32 7, i32 0, i32 2
116116
// CHECK: %[[B_STRIDE:.*]] = load i64, ptr %[[B_STRIDE_GEP]]
117-
// CHECK: %[[B_DIM_OFFSET:.*]] = mul i64 %{{.*}}, %[[B_STRIDE]]
118-
// CHECK: %[[B_OFFSET:.*]] = add i64 %[[B_DIM_OFFSET]], 0
117+
// CHECK: %[[B_DIM_OFFSET:.*]] = mul nsw i64 %{{.*}}, %[[B_STRIDE]]
118+
// CHECK: %[[B_OFFSET:.*]] = add nsw i64 %[[B_DIM_OFFSET]], 0
119119
// CHECK: %[[B_BASE_GEP:.*]] = getelementptr {{.*}}, ptr %{{.*}}, i32 0, i32 0
120120
// CHECK: %[[B_BASE:.*]] = load ptr, ptr %[[B_BASE_GEP]]
121121
// CHECK: %[[B_VOID_ADDR:.*]] = getelementptr i8, ptr %[[B_BASE]], i64 %[[B_OFFSET]]
@@ -172,7 +172,7 @@ func.func @f7(%arg0: !fir.ref<f32>, %arg1: !fir.box<!fir.array<?xf32>>) {
172172
%0 = fir.shift %c4 : (index) -> !fir.shift<1>
173173
// CHECK: %[[STRIDE_GEP:.*]] = getelementptr {{.*}}, ptr %[[Y]], i32 0, i32 7, i32 0, i32 2
174174
// CHECK: %[[STRIDE:.*]] = load i64, ptr %[[STRIDE_GEP]]
175-
// CHECK: mul i64 96, %[[STRIDE]]
175+
// CHECK: mul nsw i64 96, %[[STRIDE]]
176176
%1 = fir.array_coor %arg1(%0) %c100 : (!fir.box<!fir.array<?xf32>>, !fir.shift<1>, index) -> !fir.ref<f32>
177177
%2 = fir.load %1 : !fir.ref<f32>
178178
fir.store %2 to %arg0 : !fir.ref<f32>
@@ -202,7 +202,7 @@ func.func @f8(%a : !fir.ref<!fir.array<2x2x!fir.type<t{i:i32}>>>, %i : i32) {
202202
func.func @f9(%i: i32, %e : i64, %j: i64, %c: !fir.ref<!fir.array<?x?x!fir.char<1,?>>>) -> !fir.ref<!fir.char<1,?>> {
203203
%s = fir.shape %e, %e : (i64, i64) -> !fir.shape<2>
204204
// CHECK: %[[CAST:.*]] = sext i32 %[[I]] to i64
205-
// CHECK: %[[OFFSET:.*]] = mul i64 %{{.*}}, %[[CAST]]
205+
// CHECK: %[[OFFSET:.*]] = mul nsw i64 %{{.*}}, %[[CAST]]
206206
// CHECK: getelementptr i8, ptr %[[C]], i64 %[[OFFSET]]
207207
%a = fir.array_coor %c(%s) %j, %j typeparams %i : (!fir.ref<!fir.array<?x?x!fir.char<1,?>>>, !fir.shape<2>, i64, i64, i32) -> !fir.ref<!fir.char<1,?>>
208208
return %a : !fir.ref<!fir.char<1,?>>

0 commit comments

Comments (0)