@@ -79,6 +79,35 @@ struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
79
79
}
80
80
};
81
81
82
+ // Default conversion which applies the BinaryStandardOp separately on the real
83
+ // and imaginary parts. Can for example be used for complex::AddOp and
84
+ // complex::SubOp.
85
+ template <typename BinaryComplexOp, typename BinaryStandardOp>
86
+ struct BinaryComplexOpConversion : public OpConversionPattern <BinaryComplexOp> {
87
+ using OpConversionPattern<BinaryComplexOp>::OpConversionPattern;
88
+
89
+ LogicalResult
90
+ matchAndRewrite (BinaryComplexOp op, ArrayRef<Value> operands,
91
+ ConversionPatternRewriter &rewriter) const override {
92
+ typename BinaryComplexOp::Adaptor transformed (operands);
93
+ auto type = transformed.lhs ().getType ().template cast <ComplexType>();
94
+ auto elementType = type.getElementType ().template cast <FloatType>();
95
+ mlir::ImplicitLocOpBuilder b (op.getLoc (), rewriter);
96
+
97
+ Value realLhs = b.create <complex::ReOp>(elementType, transformed.lhs ());
98
+ Value realRhs = b.create <complex::ReOp>(elementType, transformed.rhs ());
99
+ Value resultReal =
100
+ b.create <BinaryStandardOp>(elementType, realLhs, realRhs);
101
+ Value imagLhs = b.create <complex::ImOp>(elementType, transformed.lhs ());
102
+ Value imagRhs = b.create <complex::ImOp>(elementType, transformed.rhs ());
103
+ Value resultImag =
104
+ b.create <BinaryStandardOp>(elementType, imagLhs, imagRhs);
105
+ rewriter.replaceOpWithNewOp <complex::CreateOp>(op, type, resultReal,
106
+ resultImag);
107
+ return success ();
108
+ }
109
+ };
110
+
82
111
struct DivOpConversion : public OpConversionPattern <complex::DivOp> {
83
112
using OpConversionPattern<complex::DivOp>::OpConversionPattern;
84
113
@@ -554,6 +583,8 @@ void mlir::populateComplexToStandardConversionPatterns(
554
583
AbsOpConversion,
555
584
ComparisonOpConversion<complex::EqualOp, CmpFPredicate::OEQ>,
556
585
ComparisonOpConversion<complex::NotEqualOp, CmpFPredicate::UNE>,
586
+ BinaryComplexOpConversion<complex::AddOp, AddFOp>,
587
+ BinaryComplexOpConversion<complex::SubOp, SubFOp>,
557
588
DivOpConversion,
558
589
ExpOpConversion,
559
590
LogOpConversion,
@@ -578,12 +609,8 @@ void ConvertComplexToStandardPass::runOnFunction() {
578
609
populateComplexToStandardConversionPatterns (patterns);
579
610
580
611
ConversionTarget target (getContext ());
581
- target.addLegalDialect <StandardOpsDialect, math::MathDialect,
582
- complex::ComplexDialect>();
583
- target.addIllegalOp <complex::AbsOp, complex::DivOp, complex::EqualOp,
584
- complex::ExpOp, complex::LogOp, complex::Log1pOp,
585
- complex::MulOp, complex::NegOp, complex::NotEqualOp,
586
- complex::SignOp>();
612
+ target.addLegalDialect <StandardOpsDialect, math::MathDialect>();
613
+ target.addLegalOp <complex::CreateOp, complex::ImOp, complex::ReOp>();
587
614
if (failed (applyPartialConversion (function, target, std::move (patterns))))
588
615
signalPassFailure ();
589
616
}
0 commit comments