Skip to content

Commit f90e547

Browse files
LewuatheZijunZhaoCCK
authored andcommitted
[mlir][complex] Support fastmath in the binary op conversion. (llvm#65702)
Complex dialect arithmetic operations are now able to recognize the given fastmath flags. This PR lets the conversion from complex to standard keep the fastmath flag passed to arith dialect ops. See: https://discourse.llvm.org/t/rfc-fastmath-flags-support-in-complex-dialect/71981
1 parent d0b9501 commit f90e547

File tree

2 files changed

+39
-4
lines changed

2 files changed

+39
-4
lines changed

mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -137,15 +137,16 @@ struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
137137
auto type = cast<ComplexType>(adaptor.getLhs().getType());
138138
auto elementType = cast<FloatType>(type.getElementType());
139139
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
140+
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
140141

141142
Value realLhs = b.create<complex::ReOp>(elementType, adaptor.getLhs());
142143
Value realRhs = b.create<complex::ReOp>(elementType, adaptor.getRhs());
143-
Value resultReal =
144-
b.create<BinaryStandardOp>(elementType, realLhs, realRhs);
144+
Value resultReal = b.create<BinaryStandardOp>(elementType, realLhs, realRhs,
145+
fmf.getValue());
145146
Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.getLhs());
146147
Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.getRhs());
147-
Value resultImag =
148-
b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs);
148+
Value resultImag = b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs,
149+
fmf.getValue());
149150
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
150151
resultImag);
151152
return success();

mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -723,3 +723,37 @@ func.func @complex_abs_with_fmf(%arg: complex<f32>) -> f32 {
723723
// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[REAL_SQ]], %[[IMAG_SQ]] fastmath<nnan,contract> : f32
724724
// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
725725
// CHECK: return %[[NORM]] : f32
726+
727+
// -----
728+
729+
// CHECK-LABEL: func @complex_add_with_fmf
730+
// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
731+
func.func @complex_add_with_fmf(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
732+
%add = complex.add %lhs, %rhs fastmath<nnan,contract> : complex<f32>
733+
return %add : complex<f32>
734+
}
735+
// CHECK: %[[REAL_LHS:.*]] = complex.re %[[LHS]] : complex<f32>
736+
// CHECK: %[[REAL_RHS:.*]] = complex.re %[[RHS]] : complex<f32>
737+
// CHECK: %[[RESULT_REAL:.*]] = arith.addf %[[REAL_LHS]], %[[REAL_RHS]] fastmath<nnan,contract> : f32
738+
// CHECK: %[[IMAG_LHS:.*]] = complex.im %[[LHS]] : complex<f32>
739+
// CHECK: %[[IMAG_RHS:.*]] = complex.im %[[RHS]] : complex<f32>
740+
// CHECK: %[[RESULT_IMAG:.*]] = arith.addf %[[IMAG_LHS]], %[[IMAG_RHS]] fastmath<nnan,contract> : f32
741+
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
742+
// CHECK: return %[[RESULT]] : complex<f32>
743+
744+
// -----
745+
746+
// CHECK-LABEL: func @complex_sub_with_fmf
747+
// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
748+
func.func @complex_sub_with_fmf(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
749+
%sub = complex.sub %lhs, %rhs fastmath<nnan,contract> : complex<f32>
750+
return %sub : complex<f32>
751+
}
752+
// CHECK: %[[REAL_LHS:.*]] = complex.re %[[LHS]] : complex<f32>
753+
// CHECK: %[[REAL_RHS:.*]] = complex.re %[[RHS]] : complex<f32>
754+
// CHECK: %[[RESULT_REAL:.*]] = arith.subf %[[REAL_LHS]], %[[REAL_RHS]] fastmath<nnan,contract> : f32
755+
// CHECK: %[[IMAG_LHS:.*]] = complex.im %[[LHS]] : complex<f32>
756+
// CHECK: %[[IMAG_RHS:.*]] = complex.im %[[RHS]] : complex<f32>
757+
// CHECK: %[[RESULT_IMAG:.*]] = arith.subf %[[IMAG_LHS]], %[[IMAG_RHS]] fastmath<nnan,contract> : f32
758+
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
759+
// CHECK: return %[[RESULT]] : complex<f32>

0 commit comments

Comments
 (0)