|
41 | 41 | #include "mlir/Target/LLVMIR/ModuleTranslation.h"
|
42 | 42 | #include "llvm/ADT/ArrayRef.h"
|
43 | 43 | #include "llvm/ADT/TypeSwitch.h"
|
| 44 | +#include <mlir/IR/ValueRange.h> |
44 | 45 |
|
45 | 46 | namespace fir {
|
46 | 47 | #define GEN_PASS_DEF_FIRTOLLVMLOWERING
|
@@ -3512,42 +3513,87 @@ struct MulcOpConversion : public FIROpConversion<fir::MulcOp> {
|
3512 | 3513 | }
|
3513 | 3514 | };
|
3514 | 3515 |
|
3515 |
| -/// Inlined complex division |
| 3516 | +static mlir::LogicalResult getDivc3(fir::DivcOp op, |
| 3517 | + mlir::ConversionPatternRewriter &rewriter, |
| 3518 | + std::string funcName, mlir::Type returnType, |
| 3519 | + llvm::SmallVector<mlir::Type> argType, |
| 3520 | + llvm::SmallVector<mlir::Value> args) { |
| 3521 | + auto module = op->getParentOfType<mlir::ModuleOp>(); |
| 3522 | + auto loc = op.getLoc(); |
| 3523 | + if (mlir::LLVM::LLVMFuncOp divideFunc = |
| 3524 | + module.lookupSymbol<mlir::LLVM::LLVMFuncOp>(funcName)) { |
| 3525 | + auto call = rewriter.create<mlir::LLVM::CallOp>( |
| 3526 | + loc, returnType, mlir::SymbolRefAttr::get(divideFunc), args); |
| 3527 | + rewriter.replaceOp(op, call->getResults()); |
| 3528 | + return mlir::success(); |
| 3529 | + } |
| 3530 | + mlir::OpBuilder moduleBuilder( |
| 3531 | + op->getParentOfType<mlir::ModuleOp>().getBodyRegion()); |
| 3532 | + auto divideFunc = moduleBuilder.create<mlir::LLVM::LLVMFuncOp>( |
| 3533 | + rewriter.getUnknownLoc(), funcName, |
| 3534 | + mlir::LLVM::LLVMFunctionType::get(returnType, argType, |
| 3535 | + /*isVarArg=*/false)); |
| 3536 | + auto call = rewriter.create<mlir::LLVM::CallOp>( |
| 3537 | + loc, returnType, mlir::SymbolRefAttr::get(divideFunc), args); |
| 3538 | + rewriter.replaceOp(op, call->getResults()); |
| 3539 | + return mlir::success(); |
| 3540 | +} |
| 3541 | + |
| 3542 | +/// complex division |
3516 | 3543 | struct DivcOpConversion : public FIROpConversion<fir::DivcOp> {
|
3517 | 3544 | using FIROpConversion::FIROpConversion;
|
3518 | 3545 |
|
3519 | 3546 | mlir::LogicalResult
|
3520 | 3547 | matchAndRewrite(fir::DivcOp divc, OpAdaptor adaptor,
|
3521 | 3548 | mlir::ConversionPatternRewriter &rewriter) const override {
|
3522 |
| - // TODO: Can we use a call to __divdc3 instead? |
3523 |
| - // Just generate inline code for now. |
3524 | 3549 | // given: (x + iy) / (x' + iy')
|
3525 | 3550 | // result: ((xx'+yy')/d) + i((yx'-xy')/d) where d = x'x' + y'y'
|
3526 | 3551 | mlir::Value a = adaptor.getOperands()[0];
|
3527 | 3552 | mlir::Value b = adaptor.getOperands()[1];
|
3528 | 3553 | auto loc = divc.getLoc();
|
3529 | 3554 | mlir::Type eleTy = convertType(getComplexEleTy(divc.getType()));
|
3530 |
| - mlir::Type ty = convertType(divc.getType()); |
| 3555 | + llvm::SmallVector<mlir::Type> argTy = {eleTy, eleTy, eleTy, eleTy}; |
| 3556 | + mlir::Type firReturnTy = divc.getType(); |
| 3557 | + mlir::Type ty = convertType(firReturnTy); |
3531 | 3558 | auto x0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, a, 0);
|
3532 | 3559 | auto y0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, a, 1);
|
3533 | 3560 | auto x1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 0);
|
3534 | 3561 | auto y1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 1);
|
3535 |
| - auto xx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, x1); |
3536 |
| - auto x1x1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x1, x1); |
3537 |
| - auto yx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, x1); |
3538 |
| - auto xy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, y1); |
3539 |
| - auto yy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, y1); |
3540 |
| - auto y1y1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y1, y1); |
3541 |
| - auto d = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, x1x1, y1y1); |
3542 |
| - auto rrn = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, xx, yy); |
3543 |
| - auto rin = rewriter.create<mlir::LLVM::FSubOp>(loc, eleTy, yx, xy); |
3544 |
| - auto rr = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rrn, d); |
3545 |
| - auto ri = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rin, d); |
3546 |
| - auto ra = rewriter.create<mlir::LLVM::UndefOp>(loc, ty); |
3547 |
| - auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ra, rr, 0); |
3548 |
| - auto r0 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, r1, ri, 1); |
3549 |
| - rewriter.replaceOp(divc, r0.getResult()); |
3550 |
| - return mlir::success(); |
| 3562 | + |
| 3563 | + fir::KindTy kind = (firReturnTy.dyn_cast<fir::ComplexType>()).getFKind(); |
| 3564 | + mlir::SmallVector<mlir::Value> args = {x0, y0, x1, y1}; |
| 3565 | + switch (kind) { |
| 3566 | + default: |
| 3567 | + llvm_unreachable("Unsupported complex type"); |
| 3568 | + case 4: |
| 3569 | + return getDivc3(divc, rewriter, "__divsc3", ty, argTy, args); |
| 3570 | + case 8: |
| 3571 | + return getDivc3(divc, rewriter, "__divdc3", ty, argTy, args); |
| 3572 | + case 10: |
| 3573 | + return getDivc3(divc, rewriter, "__divxc3", ty, argTy, args); |
| 3574 | + case 16: |
| 3575 | + return getDivc3(divc, rewriter, "__divtc3", ty, argTy, args); |
| 3576 | + case 3: |
| 3577 | + case 2: |
| 3578 | + // No library function for bfloat or half in compiler_rt, generate |
| 3579 | + // inline instead |
| 3580 | + auto xx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, x1); |
| 3581 | + auto x1x1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x1, x1); |
| 3582 | + auto yx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, x1); |
| 3583 | + auto xy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, y1); |
| 3584 | + auto yy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, y1); |
| 3585 | + auto y1y1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y1, y1); |
| 3586 | + auto d = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, x1x1, y1y1); |
| 3587 | + auto rrn = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, xx, yy); |
| 3588 | + auto rin = rewriter.create<mlir::LLVM::FSubOp>(loc, eleTy, yx, xy); |
| 3589 | + auto rr = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rrn, d); |
| 3590 | + auto ri = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rin, d); |
| 3591 | + auto ra = rewriter.create<mlir::LLVM::UndefOp>(loc, ty); |
| 3592 | + auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ra, rr, 0); |
| 3593 | + auto r0 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, r1, ri, 1); |
| 3594 | + rewriter.replaceOp(divc, r0.getResult()); |
| 3595 | + return mlir::success(); |
| 3596 | + } |
3551 | 3597 | }
|
3552 | 3598 | };
|
3553 | 3599 |
|
|
0 commit comments