Skip to content

Commit 754d896

Browse files
authored
[flang] Propagate fast-math flags to FIROpBuilder. (#126316)
One constructor was missing to propagate fast-math flags from an operation to the builder. It is fixed now. And the builder creation in one opt-bufferization case should take the rewriter, I think.
1 parent 1932ed0 commit 754d896

File tree

4 files changed

+14
-10
lines changed

4 files changed

+14
-10
lines changed

flang/include/flang/Optimizer/Builder/FIRBuilder.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,13 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
5757
explicit FirOpBuilder(mlir::Operation *op, fir::KindMapping kindMap,
5858
mlir::SymbolTable *symbolTable = nullptr)
5959
: OpBuilder{op, /*listener=*/this}, kindMap{std::move(kindMap)},
60-
symbolTable{symbolTable} {}
60+
symbolTable{symbolTable} {
61+
auto fmi = mlir::dyn_cast<mlir::arith::ArithFastMathInterface>(*op);
62+
if (fmi) {
63+
// Set the builder with FastMathFlags attached to the operation.
64+
setFastMathFlags(fmi.getFastMathFlagsAttr().getValue());
65+
}
66+
}
6167
explicit FirOpBuilder(mlir::OpBuilder &builder, fir::KindMapping kindMap,
6268
mlir::SymbolTable *symbolTable = nullptr)
6369
: OpBuilder(builder), OpBuilder::Listener(), kindMap{std::move(kindMap)},

flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -930,9 +930,7 @@ class ReductionConversion : public mlir::OpRewritePattern<Op> {
930930
llvm_unreachable("unsupported type");
931931
};
932932

933-
fir::KindMapping kindMap =
934-
fir::getKindMapping(op->template getParentOfType<mlir::ModuleOp>());
935-
fir::FirOpBuilder builder{op, kindMap};
933+
fir::FirOpBuilder builder{rewriter, op.getOperation()};
936934

937935
mlir::Value init;
938936
GenBodyFn genBodyFn;

flang/test/HLFIR/maxval-elemental.fir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,9 @@ func.func @_QPtest_float(%arg0: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "a
8282
// CHECK-NEXT: %[[V6:.*]] = fir.do_loop %arg1 = %c1 to %c4 step %c1 iter_args(%arg2 = %cst) -> (f32) {
8383
// CHECK-NEXT: %[[V7:.*]] = hlfir.designate %[[V5]] (%arg1) : (!fir.box<!fir.array<4xf32>>, index) -> !fir.ref<f32>
8484
// CHECK-NEXT: %[[V8:.*]] = fir.load %[[V7]] : !fir.ref<f32>
85-
// CHECK-NEXT: %[[V9:.*]] = arith.cmpf ogt, %[[V8]], %arg2 : f32
86-
// CHECK-NEXT: %[[V10:.*]] = arith.cmpf une, %arg2, %arg2 : f32
87-
// CHECK-NEXT: %[[V11:.*]] = arith.cmpf oeq, %[[V8]], %[[V8]] : f32
85+
// CHECK-NEXT: %[[V9:.*]] = arith.cmpf ogt, %[[V8]], %arg2 fastmath<contract> : f32
86+
// CHECK-NEXT: %[[V10:.*]] = arith.cmpf une, %arg2, %arg2 fastmath<contract> : f32
87+
// CHECK-NEXT: %[[V11:.*]] = arith.cmpf oeq, %[[V8]], %[[V8]] fastmath<contract> : f32
8888
// CHECK-NEXT: %[[V12:.*]] = arith.andi %[[V10]], %[[V11]] : i1
8989
// CHECK-NEXT: %[[V13:.*]] = arith.ori %[[V9]], %[[V12]] : i1
9090
// CHECK-NEXT: %[[V14:.*]] = arith.select %[[V13]], %[[V8]], %arg2 : f32

flang/test/HLFIR/minval-elemental.fir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,9 @@ func.func @_QPtest_float(%arg0: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "a
8282
// CHECK-NEXT: %[[V6:.*]] = fir.do_loop %arg1 = %c1 to %c4 step %c1 iter_args(%arg2 = %cst) -> (f32) {
8383
// CHECK-NEXT: %[[V7:.*]] = hlfir.designate %[[V5]] (%arg1) : (!fir.box<!fir.array<4xf32>>, index) -> !fir.ref<f32>
8484
// CHECK-NEXT: %[[V8:.*]] = fir.load %[[V7]] : !fir.ref<f32>
85-
// CHECK-NEXT: %[[V9:.*]] = arith.cmpf olt, %[[V8]], %arg2 : f32
86-
// CHECK-NEXT: %[[V10:.*]] = arith.cmpf une, %arg2, %arg2 : f32
87-
// CHECK-NEXT: %[[V11:.*]] = arith.cmpf oeq, %[[V8]], %[[V8]] : f32
85+
// CHECK-NEXT: %[[V9:.*]] = arith.cmpf olt, %[[V8]], %arg2 fastmath<contract> : f32
86+
// CHECK-NEXT: %[[V10:.*]] = arith.cmpf une, %arg2, %arg2 fastmath<contract> : f32
87+
// CHECK-NEXT: %[[V11:.*]] = arith.cmpf oeq, %[[V8]], %[[V8]] fastmath<contract> : f32
8888
// CHECK-NEXT: %[[V12:.*]] = arith.andi %[[V10]], %[[V11]] : i1
8989
// CHECK-NEXT: %[[V13:.*]] = arith.ori %[[V9]], %[[V12]] : i1
9090
// CHECK-NEXT: %[[V14:.*]] = arith.select %[[V13]], %[[V8]], %arg2 : f32

0 commit comments

Comments
 (0)