Skip to content

[mlir][spirv] Fix vector reduction lowerings for FP min/max #69025

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 150 additions & 29 deletions mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
Expand All @@ -28,6 +29,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include <cassert>
#include <cstdint>
Expand Down Expand Up @@ -351,15 +353,13 @@ struct VectorInsertStridedSliceOpConvert final
}
};

template <class SPIRVFMaxOp, class SPIRVFMinOp, class SPIRVUMaxOp,
class SPIRVUMinOp, class SPIRVSMaxOp, class SPIRVSMinOp>
struct VectorReductionPattern final
: public OpConversionPattern<vector::ReductionOp> {
template <typename Derived>
struct VectorReductionPatternBase : OpConversionPattern<vector::ReductionOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ConversionPatternRewriter &rewriter) const final {
Type resultType = typeConverter->convertType(reduceOp.getType());
if (!resultType)
return failure();
Expand All @@ -368,9 +368,22 @@ struct VectorReductionPattern final
if (!srcVectorType || srcVectorType.getRank() != 1)
return rewriter.notifyMatchFailure(reduceOp, "not 1-D vector source");

// Extract all elements.
SmallVector<Value> extractedElements =
extractAllElements(reduceOp, adaptor, srcVectorType, rewriter);

const auto &self = static_cast<const Derived &>(*this);

return self.reduceExtracted(reduceOp, extractedElements, resultType,
rewriter);
}

private:
SmallVector<Value>
extractAllElements(vector::ReductionOp reduceOp, OpAdaptor adaptor,
VectorType srcVectorType,
ConversionPatternRewriter &rewriter) const {
int numElements = srcVectorType.getDimSize(0);
SmallVector<Value, 4> values;
SmallVector<Value> values;
values.reserve(numElements + (adaptor.getAcc() != nullptr));
Location loc = reduceOp.getLoc();
for (int i = 0; i < numElements; ++i) {
Expand All @@ -381,9 +394,26 @@ struct VectorReductionPattern final
if (Value acc = adaptor.getAcc())
values.push_back(acc);

// Reduce them.
Value result = values.front();
for (Value next : llvm::ArrayRef(values).drop_front()) {
return values;
}
};

#define VECTOR_REDUCTION_BASE \
VectorReductionPatternBase<VectorReductionPattern<SPIRVUMaxOp, SPIRVUMinOp, \
SPIRVSMaxOp, SPIRVSMinOp>>
template <typename SPIRVUMaxOp, typename SPIRVUMinOp, typename SPIRVSMaxOp,
typename SPIRVSMinOp>
struct VectorReductionPattern final : VECTOR_REDUCTION_BASE {
using Base = VECTOR_REDUCTION_BASE;
using Base::Base;

LogicalResult reduceExtracted(vector::ReductionOp reduceOp,
ArrayRef<Value> extractedElements,
Type resultType,
ConversionPatternRewriter &rewriter) const {
mlir::Location loc = reduceOp->getLoc();
Value result = extractedElements.front();
for (Value next : llvm::ArrayRef(extractedElements).drop_front()) {
switch (reduceOp.getKind()) {

#define INT_AND_FLOAT_CASE(kind, iop, fop) \
Expand All @@ -403,10 +433,6 @@ struct VectorReductionPattern final

INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp);
INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp);
INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp);
INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp);
INT_OR_FLOAT_CASE(MAXF, SPIRVFMaxOp);
INT_OR_FLOAT_CASE(MINF, SPIRVFMinOp);
INT_OR_FLOAT_CASE(MINUI, SPIRVUMinOp);
INT_OR_FLOAT_CASE(MINSI, SPIRVSMinOp);
INT_OR_FLOAT_CASE(MAXUI, SPIRVUMaxOp);
Expand All @@ -416,13 +442,105 @@ struct VectorReductionPattern final
case vector::CombiningKind::OR:
case vector::CombiningKind::XOR:
return rewriter.notifyMatchFailure(reduceOp, "unimplemented");
default:
return rewriter.notifyMatchFailure(reduceOp, "not handled here");
}
}

rewriter.replaceOp(reduceOp, result);
return success();
}
};
#undef VECTOR_REDUCTION_BASE
#undef INT_AND_FLOAT_CASE
#undef INT_OR_FLOAT_CASE

#define MIN_MAX_PATTERN_BASE \
VectorReductionPatternBase< \
VectorReductionFloatMinMax<SPIRVFMaxOp, SPIRVFMinOp>>
template <class SPIRVFMaxOp, class SPIRVFMinOp>
struct VectorReductionFloatMinMax final : MIN_MAX_PATTERN_BASE {
using Base = MIN_MAX_PATTERN_BASE;
using Base::Base;

LogicalResult reduceExtracted(vector::ReductionOp reduceOp,
ArrayRef<Value> extractedElements,
Type resultType,
ConversionPatternRewriter &rewriter) const {
mlir::Location loc = reduceOp->getLoc();
Value result = extractedElements.front();
for (Value next : llvm::ArrayRef(extractedElements).drop_front()) {
switch (reduceOp.getKind()) {

#define INT_OR_FLOAT_CASE(kind, fop) \
case vector::CombiningKind::kind: { \
fop op = rewriter.create<fop>(loc, resultType, result, next); \
result = this->generateActionForOp(rewriter, loc, resultType, op, \
vector::CombiningKind::kind); \
break; \
}

INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp);
INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp);
INT_OR_FLOAT_CASE(MAXF, SPIRVFMaxOp);
INT_OR_FLOAT_CASE(MINF, SPIRVFMinOp);

default:
return rewriter.notifyMatchFailure(reduceOp, "not handled here");
}
}

rewriter.replaceOp(reduceOp, result);
return success();
}

private:
enum class Action { Nothing, PropagateNaN, PropagateNonNaN };

template <typename Op>
Action getActionForOp(vector::CombiningKind kind) const {
constexpr bool isCLOp = std::is_same_v<Op, spirv::CLFMaxOp> ||
std::is_same_v<Op, spirv::CLFMinOp>;
switch (kind) {
case vector::CombiningKind::MINIMUMF:
case vector::CombiningKind::MAXIMUMF:
return Action::PropagateNaN;
case vector::CombiningKind::MINF:
case vector::CombiningKind::MAXF:
// CL ops already have the same semantic for NaNs as MINF/MAXF
// GL ops have undefined semantics for NaNs, so we need to explicitly
// propagate the non-NaN values
return isCLOp ? Action::Nothing : Action::PropagateNonNaN;
default:
llvm_unreachable("Unexpected case for the switch");
}
}

template <typename Op>
Value generateActionForOp(ConversionPatternRewriter &rewriter,
mlir::Location loc, Type resultType, Op op,
vector::CombiningKind kind) const {
Action action = getActionForOp<Op>(kind);

if (action == Action::Nothing) {
return op;
}

Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, op.getLhs());
Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, op.getRhs());

Value select1 = rewriter.create<spirv::SelectOp>(
loc, resultType, lhsIsNan,
action == Action::PropagateNaN ? op.getLhs() : op.getRhs(), op);
Value select2 = rewriter.create<spirv::SelectOp>(
loc, resultType, rhsIsNan,
action == Action::PropagateNaN ? op.getRhs() : op.getLhs(), select1);

return select2;
}
};
#undef MIN_MAX_PATTERN_BASE
#undef INT_OR_FLOAT_CASE

class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
public:
Expand Down Expand Up @@ -604,25 +722,28 @@ struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
};

} // namespace
#define CL_MAX_MIN_OPS \
spirv::CLFMaxOp, spirv::CLFMinOp, spirv::CLUMaxOp, spirv::CLUMinOp, \
spirv::CLSMaxOp, spirv::CLSMinOp
#define CL_INT_MAX_MIN_OPS \
spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp

#define GL_INT_MAX_MIN_OPS \
spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp

#define GL_MAX_MIN_OPS \
spirv::GLFMaxOp, spirv::GLFMinOp, spirv::GLUMaxOp, spirv::GLUMinOp, \
spirv::GLSMaxOp, spirv::GLSMinOp
#define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp
#define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp

void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns.add<VectorBitcastConvert, VectorBroadcastConvert,
VectorExtractElementOpConvert, VectorExtractOpConvert,
VectorExtractStridedSliceOpConvert,
VectorFmaOpConvert<spirv::GLFmaOp>,
VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
VectorInsertOpConvert, VectorReductionPattern<GL_MAX_MIN_OPS>,
VectorReductionPattern<CL_MAX_MIN_OPS>, VectorShapeCast,
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
VectorSplatPattern>(typeConverter, patterns.getContext());
patterns.add<
VectorBitcastConvert, VectorBroadcastConvert,
VectorExtractElementOpConvert, VectorExtractOpConvert,
VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
VectorSplatPattern>(typeConverter, patterns.getContext());
}

void mlir::populateVectorReductionToSPIRVDotProductPatterns(
Expand Down
Loading