Skip to content

[mlir][spirv] Split codegen for float min/max reductions and others v2. [NFC] #73363

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 24, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 108 additions & 38 deletions mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include <cassert>
Expand Down Expand Up @@ -351,39 +353,64 @@ struct VectorInsertStridedSliceOpConvert final
}
};

template <class SPIRVFMaxOp, class SPIRVFMinOp, class SPIRVUMaxOp,
class SPIRVUMinOp, class SPIRVSMaxOp, class SPIRVSMinOp>
struct VectorReductionPattern final
: public OpConversionPattern<vector::ReductionOp> {
static SmallVector<Value> extractAllElements(
vector::ReductionOp reduceOp, vector::ReductionOp::Adaptor adaptor,
VectorType srcVectorType, ConversionPatternRewriter &rewriter) {
int numElements = static_cast<int>(srcVectorType.getDimSize(0));
SmallVector<Value> values;
values.reserve(numElements + (adaptor.getAcc() ? 1 : 0));
Location loc = reduceOp.getLoc();

for (int i = 0; i < numElements; ++i) {
values.push_back(rewriter.create<spirv::CompositeExtractOp>(
loc, srcVectorType.getElementType(), adaptor.getVector(),
rewriter.getI32ArrayAttr({i})));
}
if (Value acc = adaptor.getAcc())
values.push_back(acc);

return values;
}

struct ReductionRewriteInfo {
Type resultType;
SmallVector<Value> extractedElements;
};

FailureOr<ReductionRewriteInfo> static getReductionInfo(
vector::ReductionOp op, vector::ReductionOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter, const TypeConverter &typeConverter) {
Type resultType = typeConverter.convertType(op.getType());
if (!resultType)
return failure();

auto srcVectorType = dyn_cast<VectorType>(adaptor.getVector().getType());
if (!srcVectorType || srcVectorType.getRank() != 1)
return rewriter.notifyMatchFailure(op, "not a 1-D vector source");

SmallVector<Value> extractedElements =
extractAllElements(op, adaptor, srcVectorType, rewriter);

return ReductionRewriteInfo{resultType, std::move(extractedElements)};
}

template <typename SPIRVUMaxOp, typename SPIRVUMinOp, typename SPIRVSMaxOp,
typename SPIRVSMinOp>
struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type resultType = typeConverter->convertType(reduceOp.getType());
if (!resultType)
auto reductionInfo =
getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
if (failed(reductionInfo))
return failure();

auto srcVectorType = dyn_cast<VectorType>(adaptor.getVector().getType());
if (!srcVectorType || srcVectorType.getRank() != 1)
return rewriter.notifyMatchFailure(reduceOp, "not 1-D vector source");

// Extract all elements.
int numElements = srcVectorType.getDimSize(0);
SmallVector<Value, 4> values;
values.reserve(numElements + (adaptor.getAcc() != nullptr));
Location loc = reduceOp.getLoc();
for (int i = 0; i < numElements; ++i) {
values.push_back(rewriter.create<spirv::CompositeExtractOp>(
loc, srcVectorType.getElementType(), adaptor.getVector(),
rewriter.getI32ArrayAttr({i})));
}
if (Value acc = adaptor.getAcc())
values.push_back(acc);

// Reduce them.
Value result = values.front();
for (Value next : llvm::ArrayRef(values).drop_front()) {
auto [resultType, extractedElements] = *reductionInfo;
Location loc = reduceOp->getLoc();
Value result = extractedElements.front();
for (Value next : llvm::drop_begin(extractedElements)) {
switch (reduceOp.getKind()) {

#define INT_AND_FLOAT_CASE(kind, iop, fop) \
Expand All @@ -403,10 +430,6 @@ struct VectorReductionPattern final

INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp);
INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp);
INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp);
INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp);
INT_OR_FLOAT_CASE(MAXF, SPIRVFMaxOp);
INT_OR_FLOAT_CASE(MINF, SPIRVFMinOp);
INT_OR_FLOAT_CASE(MINUI, SPIRVUMinOp);
INT_OR_FLOAT_CASE(MINSI, SPIRVSMinOp);
INT_OR_FLOAT_CASE(MAXUI, SPIRVUMaxOp);
Expand All @@ -416,7 +439,51 @@ struct VectorReductionPattern final
case vector::CombiningKind::OR:
case vector::CombiningKind::XOR:
return rewriter.notifyMatchFailure(reduceOp, "unimplemented");
default:
return rewriter.notifyMatchFailure(reduceOp, "not handled here");
}
#undef INT_AND_FLOAT_CASE
#undef INT_OR_FLOAT_CASE
}

rewriter.replaceOp(reduceOp, result);
return success();
}
};

template <typename SPIRVFMaxOp, typename SPIRVFMinOp>
struct VectorReductionFloatMinMax final
: OpConversionPattern<vector::ReductionOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto reductionInfo =
getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
if (failed(reductionInfo))
return failure();

auto [resultType, extractedElements] = *reductionInfo;
Location loc = reduceOp->getLoc();
Value result = extractedElements.front();
for (Value next : llvm::drop_begin(extractedElements)) {
switch (reduceOp.getKind()) {

#define INT_OR_FLOAT_CASE(kind, fop) \
case vector::CombiningKind::kind: \
result = rewriter.create<fop>(loc, resultType, result, next); \
break

INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp);
INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp);
INT_OR_FLOAT_CASE(MAXF, SPIRVFMaxOp);
INT_OR_FLOAT_CASE(MINF, SPIRVFMinOp);

default:
return rewriter.notifyMatchFailure(reduceOp, "not handled here");
}
#undef INT_OR_FLOAT_CASE
}

rewriter.replaceOp(reduceOp, result);
Expand Down Expand Up @@ -674,13 +741,14 @@ struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
};

} // namespace
#define CL_MAX_MIN_OPS \
spirv::CLFMaxOp, spirv::CLFMinOp, spirv::CLUMaxOp, spirv::CLUMinOp, \
spirv::CLSMaxOp, spirv::CLSMinOp
#define CL_INT_MAX_MIN_OPS \
spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp

#define GL_INT_MAX_MIN_OPS \
spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp

#define GL_MAX_MIN_OPS \
spirv::GLFMaxOp, spirv::GLFMinOp, spirv::GLUMaxOp, spirv::GLUMinOp, \
spirv::GLSMaxOp, spirv::GLSMinOp
#define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp
#define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp

void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
Expand All @@ -689,8 +757,10 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
VectorExtractElementOpConvert, VectorExtractOpConvert,
VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
VectorInsertOpConvert, VectorReductionPattern<GL_MAX_MIN_OPS>,
VectorReductionPattern<CL_MAX_MIN_OPS>, VectorShapeCast,
VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
typeConverter, patterns.getContext());
Expand Down