|
15 | 15 | #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
|
16 | 16 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
17 | 17 | #include "mlir/Dialect/MemRef/IR/MemRef.h"
|
| 18 | +#include "mlir/Dialect/Vector/IR/VectorOps.h" |
18 | 19 | #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
|
19 | 20 | #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
|
20 | 21 | #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
|
| 22 | +#include "mlir/IR/BuiltinAttributes.h" |
| 23 | +#include "mlir/IR/BuiltinTypeInterfaces.h" |
21 | 24 | #include "mlir/IR/BuiltinTypes.h"
|
22 | 25 | #include "mlir/IR/TypeUtilities.h"
|
23 | 26 | #include "mlir/Target/LLVMIR/TypeToLLVM.h"
|
24 | 27 | #include "mlir/Transforms/DialectConversion.h"
|
| 28 | +#include "llvm/ADT/APFloat.h" |
25 | 29 | #include "llvm/Support/Casting.h"
|
26 | 30 | #include <optional>
|
27 | 31 |
|
@@ -603,6 +607,51 @@ createFPReductionComparisonOpLowering(ConversionPatternRewriter &rewriter,
|
603 | 607 | return result;
|
604 | 608 | }
|
605 | 609 |
|
| 610 | +/// Reduction neutral classes for overloading |
| 611 | +class MaskNeutralFMaximum {}; |
| 612 | +class MaskNeutralFMinimum {}; |
| 613 | + |
| 614 | +/// Get the mask neutral floating point maximum value |
| 615 | +static llvm::APFloat |
| 616 | +getMaskNeutralValue(MaskNeutralFMaximum, |
| 617 | + const llvm::fltSemantics &floatSemantics) { |
| 618 | + return llvm::APFloat::getSmallest(floatSemantics, /*Negative=*/true); |
| 619 | +} |
| 620 | +/// Get the mask neutral floating point minimum value |
| 621 | +static llvm::APFloat |
| 622 | +getMaskNeutralValue(MaskNeutralFMinimum, |
| 623 | + const llvm::fltSemantics &floatSemantics) { |
| 624 | + return llvm::APFloat::getLargest(floatSemantics, /*Negative=*/false); |
| 625 | +} |
| 626 | + |
| 627 | +/// Create the mask neutral floating point MLIR vector constant |
| 628 | +template <typename MaskNeutral> |
| 629 | +static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter, |
| 630 | + Location loc, Type llvmType, |
| 631 | + Type vectorType) { |
| 632 | + const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics(); |
| 633 | + auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics); |
| 634 | + auto denseValue = |
| 635 | + DenseElementsAttr::get(vectorType.cast<ShapedType>(), value); |
| 636 | + return rewriter.create<LLVM::ConstantOp>(loc, vectorType, denseValue); |
| 637 | +} |
| 638 | + |
| 639 | +/// Lowers masked `fmaximum` and `fminimum` reductions using the non-masked |
| 640 | +/// intrinsics. It is a workaround to overcome the lack of masked intrinsics for |
| 641 | +/// `fmaximum`/`fminimum`. |
| 642 | +/// More information: https://github.com/llvm/llvm-project/issues/64940 |
| 643 | +template <class LLVMRedIntrinOp, class MaskNeutral> |
| 644 | +static Value lowerMaskedReductionWithRegular( |
| 645 | + ConversionPatternRewriter &rewriter, Location loc, Type llvmType, |
| 646 | + Value vectorOperand, Value accumulator, Value mask) { |
| 647 | + const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>( |
| 648 | + rewriter, loc, llvmType, vectorOperand.getType()); |
| 649 | + const Value selectedVectorByMask = rewriter.create<LLVM::SelectOp>( |
| 650 | + loc, mask, vectorOperand, vectorMaskNeutral); |
| 651 | + return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>( |
| 652 | + rewriter, loc, llvmType, selectedVectorByMask, accumulator); |
| 653 | +} |
| 654 | + |
606 | 655 | /// Overloaded methods to lower a reduction to an llvm instrinsic that requires
|
607 | 656 | /// a start value. This start value format spans across fp reductions without
|
608 | 657 | /// mask and all the masked reduction intrinsics.
|
@@ -903,10 +952,16 @@ class MaskedReductionOpConversion
|
903 | 952 | ReductionNeutralFPMin>(
|
904 | 953 | rewriter, loc, llvmType, operand, acc, maskOp.getMask());
|
905 | 954 | break;
|
906 |
| - default: |
907 |
| - return rewriter.notifyMatchFailure( |
908 |
| - maskOp, |
909 |
| - "lowering to LLVM is not implemented for this masked operation"); |
| 955 | + case CombiningKind::MAXIMUMF: |
| 956 | + result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum, |
| 957 | + MaskNeutralFMaximum>( |
| 958 | + rewriter, loc, llvmType, operand, acc, maskOp.getMask()); |
| 959 | + break; |
| 960 | + case CombiningKind::MINIMUMF: |
| 961 | + result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum, |
| 962 | + MaskNeutralFMinimum>( |
| 963 | + rewriter, loc, llvmType, operand, acc, maskOp.getMask()); |
| 964 | + break; |
910 | 965 | }
|
911 | 966 |
|
912 | 967 | // Replace `vector.mask` operation altogether.
|
|
0 commit comments