Skip to content

Commit fb978f0

Browse files
committed
[mlir][Complex]: Add lowerings for AddOp and SubOp from Complex dialect to
Standard. Differential Revision: https://reviews.llvm.org/D106429
1 parent 80e0bd1 commit fb978f0

File tree

2 files changed

+63
-6
lines changed

2 files changed

+63
-6
lines changed

mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,35 @@ struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
7979
}
8080
};
8181

82+
// Default conversion which applies the BinaryStandardOp separately on the real
83+
// and imaginary parts. Can for example be used for complex::AddOp and
84+
// complex::SubOp.
85+
template <typename BinaryComplexOp, typename BinaryStandardOp>
86+
struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
87+
using OpConversionPattern<BinaryComplexOp>::OpConversionPattern;
88+
89+
LogicalResult
90+
matchAndRewrite(BinaryComplexOp op, ArrayRef<Value> operands,
91+
ConversionPatternRewriter &rewriter) const override {
92+
typename BinaryComplexOp::Adaptor transformed(operands);
93+
auto type = transformed.lhs().getType().template cast<ComplexType>();
94+
auto elementType = type.getElementType().template cast<FloatType>();
95+
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
96+
97+
Value realLhs = b.create<complex::ReOp>(elementType, transformed.lhs());
98+
Value realRhs = b.create<complex::ReOp>(elementType, transformed.rhs());
99+
Value resultReal =
100+
b.create<BinaryStandardOp>(elementType, realLhs, realRhs);
101+
Value imagLhs = b.create<complex::ImOp>(elementType, transformed.lhs());
102+
Value imagRhs = b.create<complex::ImOp>(elementType, transformed.rhs());
103+
Value resultImag =
104+
b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs);
105+
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
106+
resultImag);
107+
return success();
108+
}
109+
};
110+
82111
struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
83112
using OpConversionPattern<complex::DivOp>::OpConversionPattern;
84113

@@ -554,6 +583,8 @@ void mlir::populateComplexToStandardConversionPatterns(
554583
AbsOpConversion,
555584
ComparisonOpConversion<complex::EqualOp, CmpFPredicate::OEQ>,
556585
ComparisonOpConversion<complex::NotEqualOp, CmpFPredicate::UNE>,
586+
BinaryComplexOpConversion<complex::AddOp, AddFOp>,
587+
BinaryComplexOpConversion<complex::SubOp, SubFOp>,
557588
DivOpConversion,
558589
ExpOpConversion,
559590
LogOpConversion,
@@ -578,12 +609,8 @@ void ConvertComplexToStandardPass::runOnFunction() {
578609
populateComplexToStandardConversionPatterns(patterns);
579610

580611
ConversionTarget target(getContext());
581-
target.addLegalDialect<StandardOpsDialect, math::MathDialect,
582-
complex::ComplexDialect>();
583-
target.addIllegalOp<complex::AbsOp, complex::DivOp, complex::EqualOp,
584-
complex::ExpOp, complex::LogOp, complex::Log1pOp,
585-
complex::MulOp, complex::NegOp, complex::NotEqualOp,
586-
complex::SignOp>();
612+
target.addLegalDialect<StandardOpsDialect, math::MathDialect>();
613+
target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
587614
if (failed(applyPartialConversion(function, target, std::move(patterns))))
588615
signalPassFailure();
589616
}

mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,21 @@ func @complex_abs(%arg: complex<f32>) -> f32 {
1414
// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
1515
// CHECK: return %[[NORM]] : f32
1616

17+
// CHECK-LABEL: func @complex_add
18+
// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
19+
func @complex_add(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
20+
%add = complex.add %lhs, %rhs: complex<f32>
21+
return %add : complex<f32>
22+
}
23+
// CHECK: %[[REAL_LHS:.*]] = complex.re %[[LHS]] : complex<f32>
24+
// CHECK: %[[REAL_RHS:.*]] = complex.re %[[RHS]] : complex<f32>
25+
// CHECK: %[[RESULT_REAL:.*]] = addf %[[REAL_LHS]], %[[REAL_RHS]] : f32
26+
// CHECK: %[[IMAG_LHS:.*]] = complex.im %[[LHS]] : complex<f32>
27+
// CHECK: %[[IMAG_RHS:.*]] = complex.im %[[RHS]] : complex<f32>
28+
// CHECK: %[[RESULT_IMAG:.*]] = addf %[[IMAG_LHS]], %[[IMAG_RHS]] : f32
29+
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
30+
// CHECK: return %[[RESULT]] : complex<f32>
31+
1732
// CHECK-LABEL: func @complex_div
1833
// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
1934
func @complex_div(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
@@ -366,3 +381,18 @@ func @complex_sign(%arg: complex<f32>) -> complex<f32> {
366381
// CHECK: %[[SIGN:.*]] = complex.create %[[REAL_SIGN]], %[[IMAG_SIGN]] : complex<f32>
367382
// CHECK: %[[RESULT:.*]] = select %[[IS_ZERO]], %[[ARG]], %[[SIGN]] : complex<f32>
368383
// CHECK: return %[[RESULT]] : complex<f32>
384+
385+
// CHECK-LABEL: func @complex_sub
386+
// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
387+
func @complex_sub(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
388+
%sub = complex.sub %lhs, %rhs: complex<f32>
389+
return %sub : complex<f32>
390+
}
391+
// CHECK: %[[REAL_LHS:.*]] = complex.re %[[LHS]] : complex<f32>
392+
// CHECK: %[[REAL_RHS:.*]] = complex.re %[[RHS]] : complex<f32>
393+
// CHECK: %[[RESULT_REAL:.*]] = subf %[[REAL_LHS]], %[[REAL_RHS]] : f32
394+
// CHECK: %[[IMAG_LHS:.*]] = complex.im %[[LHS]] : complex<f32>
395+
// CHECK: %[[IMAG_RHS:.*]] = complex.im %[[RHS]] : complex<f32>
396+
// CHECK: %[[RESULT_IMAG:.*]] = subf %[[IMAG_LHS]], %[[IMAG_RHS]] : f32
397+
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
398+
// CHECK: return %[[RESULT]] : complex<f32>

0 commit comments

Comments
 (0)