Skip to content

Commit 9264252

Browse files
author
git apple-llvm automerger
committed
Merge commit 'a7bb8e273f43' from llvm.org/main into next
2 parents 3c5ff99 + a7bb8e2 commit 9264252

File tree

2 files changed

+100
-34
lines changed

2 files changed

+100
-34
lines changed

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 66 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
4242
#include "llvm/ADT/ArrayRef.h"
4343
#include "llvm/ADT/TypeSwitch.h"
44+
#include <mlir/IR/ValueRange.h>
4445

4546
namespace fir {
4647
#define GEN_PASS_DEF_FIRTOLLVMLOWERING
@@ -3512,42 +3513,87 @@ struct MulcOpConversion : public FIROpConversion<fir::MulcOp> {
35123513
}
35133514
};
35143515

3515-
/// Inlined complex division
3516+
static mlir::LogicalResult getDivc3(fir::DivcOp op,
3517+
mlir::ConversionPatternRewriter &rewriter,
3518+
std::string funcName, mlir::Type returnType,
3519+
llvm::SmallVector<mlir::Type> argType,
3520+
llvm::SmallVector<mlir::Value> args) {
3521+
auto module = op->getParentOfType<mlir::ModuleOp>();
3522+
auto loc = op.getLoc();
3523+
if (mlir::LLVM::LLVMFuncOp divideFunc =
3524+
module.lookupSymbol<mlir::LLVM::LLVMFuncOp>(funcName)) {
3525+
auto call = rewriter.create<mlir::LLVM::CallOp>(
3526+
loc, returnType, mlir::SymbolRefAttr::get(divideFunc), args);
3527+
rewriter.replaceOp(op, call->getResults());
3528+
return mlir::success();
3529+
}
3530+
mlir::OpBuilder moduleBuilder(
3531+
op->getParentOfType<mlir::ModuleOp>().getBodyRegion());
3532+
auto divideFunc = moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
3533+
rewriter.getUnknownLoc(), funcName,
3534+
mlir::LLVM::LLVMFunctionType::get(returnType, argType,
3535+
/*isVarArg=*/false));
3536+
auto call = rewriter.create<mlir::LLVM::CallOp>(
3537+
loc, returnType, mlir::SymbolRefAttr::get(divideFunc), args);
3538+
rewriter.replaceOp(op, call->getResults());
3539+
return mlir::success();
3540+
}
3541+
3542+
/// complex division
35163543
struct DivcOpConversion : public FIROpConversion<fir::DivcOp> {
35173544
using FIROpConversion::FIROpConversion;
35183545

35193546
mlir::LogicalResult
35203547
matchAndRewrite(fir::DivcOp divc, OpAdaptor adaptor,
35213548
mlir::ConversionPatternRewriter &rewriter) const override {
3522-
// TODO: Can we use a call to __divdc3 instead?
3523-
// Just generate inline code for now.
35243549
// given: (x + iy) / (x' + iy')
35253550
// result: ((xx'+yy')/d) + i((yx'-xy')/d) where d = x'x' + y'y'
35263551
mlir::Value a = adaptor.getOperands()[0];
35273552
mlir::Value b = adaptor.getOperands()[1];
35283553
auto loc = divc.getLoc();
35293554
mlir::Type eleTy = convertType(getComplexEleTy(divc.getType()));
3530-
mlir::Type ty = convertType(divc.getType());
3555+
llvm::SmallVector<mlir::Type> argTy = {eleTy, eleTy, eleTy, eleTy};
3556+
mlir::Type firReturnTy = divc.getType();
3557+
mlir::Type ty = convertType(firReturnTy);
35313558
auto x0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, a, 0);
35323559
auto y0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, a, 1);
35333560
auto x1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 0);
35343561
auto y1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 1);
3535-
auto xx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, x1);
3536-
auto x1x1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x1, x1);
3537-
auto yx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, x1);
3538-
auto xy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, y1);
3539-
auto yy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, y1);
3540-
auto y1y1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y1, y1);
3541-
auto d = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, x1x1, y1y1);
3542-
auto rrn = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, xx, yy);
3543-
auto rin = rewriter.create<mlir::LLVM::FSubOp>(loc, eleTy, yx, xy);
3544-
auto rr = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rrn, d);
3545-
auto ri = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rin, d);
3546-
auto ra = rewriter.create<mlir::LLVM::UndefOp>(loc, ty);
3547-
auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ra, rr, 0);
3548-
auto r0 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, r1, ri, 1);
3549-
rewriter.replaceOp(divc, r0.getResult());
3550-
return mlir::success();
3562+
3563+
fir::KindTy kind = (firReturnTy.dyn_cast<fir::ComplexType>()).getFKind();
3564+
mlir::SmallVector<mlir::Value> args = {x0, y0, x1, y1};
3565+
switch (kind) {
3566+
default:
3567+
llvm_unreachable("Unsupported complex type");
3568+
case 4:
3569+
return getDivc3(divc, rewriter, "__divsc3", ty, argTy, args);
3570+
case 8:
3571+
return getDivc3(divc, rewriter, "__divdc3", ty, argTy, args);
3572+
case 10:
3573+
return getDivc3(divc, rewriter, "__divxc3", ty, argTy, args);
3574+
case 16:
3575+
return getDivc3(divc, rewriter, "__divtc3", ty, argTy, args);
3576+
case 3:
3577+
case 2:
3578+
// No library function for bfloat or half in compiler_rt, generate
3579+
// inline instead
3580+
auto xx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, x1);
3581+
auto x1x1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x1, x1);
3582+
auto yx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, x1);
3583+
auto xy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, y1);
3584+
auto yy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, y1);
3585+
auto y1y1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y1, y1);
3586+
auto d = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, x1x1, y1y1);
3587+
auto rrn = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, xx, yy);
3588+
auto rin = rewriter.create<mlir::LLVM::FSubOp>(loc, eleTy, yx, xy);
3589+
auto rr = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rrn, d);
3590+
auto ri = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rin, d);
3591+
auto ra = rewriter.create<mlir::LLVM::UndefOp>(loc, ty);
3592+
auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ra, rr, 0);
3593+
auto r0 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, r1, ri, 1);
3594+
rewriter.replaceOp(divc, r0.getResult());
3595+
return mlir::success();
3596+
}
35513597
}
35523598
};
35533599

flang/test/Fir/convert-to-llvm.fir

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -586,22 +586,42 @@ func.func @fir_complex_div(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.c
586586
// CHECK: %[[Y0:.*]] = llvm.extractvalue %[[ARG0]][1] : !llvm.struct<(f128, f128)>
587587
// CHECK: %[[X1:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.struct<(f128, f128)>
588588
// CHECK: %[[Y1:.*]] = llvm.extractvalue %[[ARG1]][1] : !llvm.struct<(f128, f128)>
589-
// CHECK: %[[MUL_X0_X1:.*]] = llvm.fmul %[[X0]], %[[X1]] : f128
590-
// CHECK: %[[MUL_X1_X1:.*]] = llvm.fmul %[[X1]], %[[X1]] : f128
591-
// CHECK: %[[MUL_Y0_X1:.*]] = llvm.fmul %[[Y0]], %[[X1]] : f128
592-
// CHECK: %[[MUL_X0_Y1:.*]] = llvm.fmul %[[X0]], %[[Y1]] : f128
593-
// CHECK: %[[MUL_Y0_Y1:.*]] = llvm.fmul %[[Y0]], %[[Y1]] : f128
594-
// CHECK: %[[MUL_Y1_Y1:.*]] = llvm.fmul %[[Y1]], %[[Y1]] : f128
595-
// CHECK: %[[ADD_X1X1_Y1Y1:.*]] = llvm.fadd %[[MUL_X1_X1]], %[[MUL_Y1_Y1]] : f128
596-
// CHECK: %[[ADD_X0X1_Y0Y1:.*]] = llvm.fadd %[[MUL_X0_X1]], %[[MUL_Y0_Y1]] : f128
597-
// CHECK: %[[SUB_Y0X1_X0Y1:.*]] = llvm.fsub %[[MUL_Y0_X1]], %[[MUL_X0_Y1]] : f128
598-
// CHECK: %[[DIV0:.*]] = llvm.fdiv %[[ADD_X0X1_Y0Y1]], %[[ADD_X1X1_Y1Y1]] : f128
599-
// CHECK: %[[DIV1:.*]] = llvm.fdiv %[[SUB_Y0X1_X0Y1]], %[[ADD_X1X1_Y1Y1]] : f128
600-
// CHECK: %{{.*}} = llvm.mlir.undef : !llvm.struct<(f128, f128)>
601-
// CHECK: %{{.*}} = llvm.insertvalue %[[DIV0]], %{{.*}}[0] : !llvm.struct<(f128, f128)>
602-
// CHECK: %{{.*}} = llvm.insertvalue %[[DIV1]], %{{.*}}[1] : !llvm.struct<(f128, f128)>
589+
// CHECK: %[[CALL:.*]] = llvm.call @__divtc3(%[[X0]], %[[Y0]], %[[X1]], %[[Y1]]) : (f128, f128, f128, f128) -> !llvm.struct<(f128, f128)>
603590
// CHECK: llvm.return %{{.*}} : !llvm.struct<(f128, f128)>
604591

592+
// -----
593+
594+
// Test FIR complex division inlines for KIND=3
595+
596+
func.func @fir_complex_div(%a: !fir.complex<3>, %b: !fir.complex<3>) -> !fir.complex<3> {
597+
%c = fir.divc %a, %b : !fir.complex<3>
598+
return %c : !fir.complex<3>
599+
}
600+
601+
// CHECK-LABEL: llvm.func @fir_complex_div(
602+
// CHECK-SAME: %[[ARG0:.*]]: !llvm.struct<(bf16, bf16)>,
603+
// CHECK-SAME: %[[ARG1:.*]]: !llvm.struct<(bf16, bf16)>) -> !llvm.struct<(bf16, bf16)> {
604+
// CHECK: %[[X0:.*]] = llvm.extractvalue %[[ARG0]][0] : !llvm.struct<(bf16, bf16)>
605+
// CHECK: %[[Y0:.*]] = llvm.extractvalue %[[ARG0]][1] : !llvm.struct<(bf16, bf16)>
606+
// CHECK: %[[X1:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.struct<(bf16, bf16)>
607+
// CHECK: %[[Y1:.*]] = llvm.extractvalue %[[ARG1]][1] : !llvm.struct<(bf16, bf16)>
608+
// CHECK: %[[MUL_X0_X1:.*]] = llvm.fmul %[[X0]], %[[X1]] : bf16
609+
// CHECK: %[[MUL_X1_X1:.*]] = llvm.fmul %[[X1]], %[[X1]] : bf16
610+
// CHECK: %[[MUL_Y0_X1:.*]] = llvm.fmul %[[Y0]], %[[X1]] : bf16
611+
// CHECK: %[[MUL_X0_Y1:.*]] = llvm.fmul %[[X0]], %[[Y1]] : bf16
612+
// CHECK: %[[MUL_Y0_Y1:.*]] = llvm.fmul %[[Y0]], %[[Y1]] : bf16
613+
// CHECK: %[[MUL_Y1_Y1:.*]] = llvm.fmul %[[Y1]], %[[Y1]] : bf16
614+
// CHECK: %[[ADD_X1X1_Y1Y1:.*]] = llvm.fadd %[[MUL_X1_X1]], %[[MUL_Y1_Y1]] : bf16
615+
// CHECK: %[[ADD_X0X1_Y0Y1:.*]] = llvm.fadd %[[MUL_X0_X1]], %[[MUL_Y0_Y1]] : bf16
616+
// CHECK: %[[SUB_Y0X1_X0Y1:.*]] = llvm.fsub %[[MUL_Y0_X1]], %[[MUL_X0_Y1]] : bf16
617+
// CHECK: %[[DIV0:.*]] = llvm.fdiv %[[ADD_X0X1_Y0Y1]], %[[ADD_X1X1_Y1Y1]] : bf16
618+
// CHECK: %[[DIV1:.*]] = llvm.fdiv %[[SUB_Y0X1_X0Y1]], %[[ADD_X1X1_Y1Y1]] : bf16
619+
// CHECK: %{{.*}} = llvm.mlir.undef : !llvm.struct<(bf16, bf16)>
620+
// CHECK: %{{.*}} = llvm.insertvalue %[[DIV0]], %{{.*}}[0] : !llvm.struct<(bf16, bf16)>
621+
// CHECK: %{{.*}} = llvm.insertvalue %[[DIV1]], %{{.*}}[1] : !llvm.struct<(bf16, bf16)>
622+
// CHECK: llvm.return %{{.*}} : !llvm.struct<(bf16, bf16)>
623+
624+
605625
// -----
606626

607627
// Test FIR complex negation conversion

0 commit comments

Comments
 (0)