Skip to content

Commit 8f5d519

Browse files
unterumarmung authored and dcaballe committed
[mlir][vector] Implement Workaround Lowerings for Masked fm**imum Reductions
This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.

Within LLVM, there are no masked reduction counterparts for vector reductions such as `fmaximum` and `fminimum`. More information can be found here: #64940 (comment).

To address this issue in MLIR, where we need to generate appropriate lowerings for these cases, we employ regular non-masked intrinsics. However, we modify the input vector using the `arith.select` operation to effectively deactivate undesired elements using a "neutral mask value". The neutral mask value is the smallest possible value for the `fmaximum` reduction and the largest possible value for the `fminimum` reduction.

Depends on D158618.
Reviewed By: dcaballe.
Differential Revision: https://reviews.llvm.org/D158773
1 parent 709b274 commit 8f5d519

File tree

2 files changed

+89
-4
lines changed

2 files changed

+89
-4
lines changed

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,17 @@
1515
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
1616
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1717
#include "mlir/Dialect/MemRef/IR/MemRef.h"
18+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
1819
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
1920
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
2021
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
22+
#include "mlir/IR/BuiltinAttributes.h"
23+
#include "mlir/IR/BuiltinTypeInterfaces.h"
2124
#include "mlir/IR/BuiltinTypes.h"
2225
#include "mlir/IR/TypeUtilities.h"
2326
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
2427
#include "mlir/Transforms/DialectConversion.h"
28+
#include "llvm/ADT/APFloat.h"
2529
#include "llvm/Support/Casting.h"
2630
#include <optional>
2731

@@ -603,6 +607,51 @@ createFPReductionComparisonOpLowering(ConversionPatternRewriter &rewriter,
603607
return result;
604608
}
605609

610+
/// Reduction neutral classes for overloading
611+
class MaskNeutralFMaximum {};
612+
class MaskNeutralFMinimum {};
613+
614+
/// Get the mask neutral floating point maximum value
615+
static llvm::APFloat
616+
getMaskNeutralValue(MaskNeutralFMaximum,
617+
const llvm::fltSemantics &floatSemantics) {
618+
return llvm::APFloat::getSmallest(floatSemantics, /*Negative=*/true);
619+
}
620+
/// Get the mask neutral floating point minimum value
621+
static llvm::APFloat
622+
getMaskNeutralValue(MaskNeutralFMinimum,
623+
const llvm::fltSemantics &floatSemantics) {
624+
return llvm::APFloat::getLargest(floatSemantics, /*Negative=*/false);
625+
}
626+
627+
/// Create the mask neutral floating point MLIR vector constant
628+
template <typename MaskNeutral>
629+
static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter,
630+
Location loc, Type llvmType,
631+
Type vectorType) {
632+
const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics();
633+
auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics);
634+
auto denseValue =
635+
DenseElementsAttr::get(vectorType.cast<ShapedType>(), value);
636+
return rewriter.create<LLVM::ConstantOp>(loc, vectorType, denseValue);
637+
}
638+
639+
/// Lowers masked `fmaximum` and `fminimum` reductions using the non-masked
640+
/// intrinsics. It is a workaround to overcome the lack of masked intrinsics for
641+
/// `fmaximum`/`fminimum`.
642+
/// More information: https://github.com/llvm/llvm-project/issues/64940
643+
template <class LLVMRedIntrinOp, class MaskNeutral>
644+
static Value lowerMaskedReductionWithRegular(
645+
ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
646+
Value vectorOperand, Value accumulator, Value mask) {
647+
const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>(
648+
rewriter, loc, llvmType, vectorOperand.getType());
649+
const Value selectedVectorByMask = rewriter.create<LLVM::SelectOp>(
650+
loc, mask, vectorOperand, vectorMaskNeutral);
651+
return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>(
652+
rewriter, loc, llvmType, selectedVectorByMask, accumulator);
653+
}
654+
606655
/// Overloaded methods to lower a reduction to an llvm instrinsic that requires
607656
/// a start value. This start value format spans across fp reductions without
608657
/// mask and all the masked reduction intrinsics.
@@ -903,10 +952,16 @@ class MaskedReductionOpConversion
903952
ReductionNeutralFPMin>(
904953
rewriter, loc, llvmType, operand, acc, maskOp.getMask());
905954
break;
906-
default:
907-
return rewriter.notifyMatchFailure(
908-
maskOp,
909-
"lowering to LLVM is not implemented for this masked operation");
955+
case CombiningKind::MAXIMUMF:
956+
result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum,
957+
MaskNeutralFMaximum>(
958+
rewriter, loc, llvmType, operand, acc, maskOp.getMask());
959+
break;
960+
case CombiningKind::MINIMUMF:
961+
result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum,
962+
MaskNeutralFMinimum>(
963+
rewriter, loc, llvmType, operand, acc, maskOp.getMask());
964+
break;
910965
}
911966

912967
// Replace `vector.mask` operation altogether.

mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,36 @@ func.func @masked_reduce_maxf_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>)
101101

102102
// -----
103103

104+
// Test input: a `maximumf` reduction of a 16 x f32 vector wrapped in
// `vector.mask`. The expectations after this function pin the lowering to a
// dense constant, an `llvm.select` under the mask, and the unmasked
// `llvm.intr.vector.reduce.fmaximum`.
// NOTE(review): the expected constant, -1.401300e-45, is the negative f32
// closest to zero, so it compares greater than ordinary negative inputs;
// confirm it is really the intended neutral value for a maximum.
func.func @masked_reduce_maximumf_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -> f32 {
105+
%0 = vector.mask %mask { vector.reduction <maximumf>, %arg0 : vector<16xf32> into f32 } : vector<16xi1> -> f32
106+
return %0 : f32
107+
}
108+
109+
// CHECK-LABEL: func.func @masked_reduce_maximumf_f32(
110+
// CHECK-SAME: %[[INPUT:.*]]: vector<16xf32>,
111+
// CHECK-SAME: %[[MASK:.*]]: vector<16xi1>) -> f32 {
112+
// CHECK: %[[MASK_NEUTRAL:.*]] = llvm.mlir.constant(dense<-1.401300e-45> : vector<16xf32>) : vector<16xf32>
113+
// CHECK: %[[MASKED:.*]] = llvm.select %[[MASK]], %[[INPUT]], %[[MASK_NEUTRAL]] : vector<16xi1>, vector<16xf32>
114+
// CHECK: %[[RESULT:.*]] = llvm.intr.vector.reduce.fmaximum(%[[MASKED]]) : (vector<16xf32>) -> f32
115+
// CHECK: return %[[RESULT]]
116+
117+
// -----
118+
119+
// Test input: a `minimumf` reduction of a 16 x f32 vector wrapped in
// `vector.mask`. The expectations after this function pin the lowering to a
// dense constant (3.40282347E+38), an `llvm.select` under the mask, and the
// unmasked `llvm.intr.vector.reduce.fminimum`.
func.func @masked_reduce_minimumf_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -> f32 {
120+
%0 = vector.mask %mask { vector.reduction <minimumf>, %arg0 : vector<16xf32> into f32 } : vector<16xi1> -> f32
121+
return %0 : f32
122+
}
123+
124+
// CHECK-LABEL: func.func @masked_reduce_minimumf_f32(
125+
// CHECK-SAME: %[[INPUT:.*]]: vector<16xf32>,
126+
// CHECK-SAME: %[[MASK:.*]]: vector<16xi1>) -> f32 {
127+
// CHECK: %[[MASK_NEUTRAL:.*]] = llvm.mlir.constant(dense<3.40282347E+38> : vector<16xf32>) : vector<16xf32>
128+
// CHECK: %[[MASKED:.*]] = llvm.select %[[MASK]], %[[INPUT]], %[[MASK_NEUTRAL]] : vector<16xi1>, vector<16xf32>
129+
// CHECK: %[[RESULT:.*]] = llvm.intr.vector.reduce.fminimum(%[[MASKED]]) : (vector<16xf32>) -> f32
130+
// CHECK: return %[[RESULT]]
131+
132+
// -----
133+
104134
func.func @masked_reduce_add_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
105135
%0 = vector.mask %mask { vector.reduction <add>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
106136
return %0 : i8

0 commit comments

Comments
 (0)