Skip to content

Commit d625ea1

Browse files
kuhar and unterumarmung
authored
[mlir][spirv] Split codegen for float min/max reductions and others v2. [NFC] (llvm#73363)
This is llvm#69023 but with cleanups. Reduced complexity by avoiding CRTP and preprocessor defines in favor of free functions.

Original description by @unterumarmung:

---

This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.

There are two types of min/max operations for floating-point numbers: `minf`/`maxf` and `minimumf`/`maximumf`. The code generation for these operations should differ from that of other vector reduction kinds. This difference arises because CL and GL operations for floating-point min and max do not have the same semantics when handling NaNs. Therefore, we must enforce the desired semantics with additional ops.

~~However, since the code generation for floating-point min/max operations shares the same functionality as extracting values for the vector, we have decided to refactor the existing code using the CRTP pattern.~~

This change does not alter the actual behavior of the code and is necessary for future fixes to the codegen for floating-point min/max operations.

---------

Co-authored-by: Daniil Dudkin <[email protected]>
1 parent ddc6ef4 commit d625ea1

File tree

1 file changed

+108
-38
lines changed

1 file changed

+108
-38
lines changed

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

Lines changed: 108 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,15 @@
2020
#include "mlir/Dialect/Vector/IR/VectorOps.h"
2121
#include "mlir/IR/BuiltinAttributes.h"
2222
#include "mlir/IR/BuiltinTypes.h"
23+
#include "mlir/IR/Location.h"
2324
#include "mlir/IR/Matchers.h"
2425
#include "mlir/IR/PatternMatch.h"
2526
#include "mlir/IR/TypeUtilities.h"
2627
#include "mlir/Support/LogicalResult.h"
2728
#include "mlir/Transforms/DialectConversion.h"
2829
#include "llvm/ADT/ArrayRef.h"
2930
#include "llvm/ADT/STLExtras.h"
31+
#include "llvm/ADT/SmallVector.h"
3032
#include "llvm/ADT/SmallVectorExtras.h"
3133
#include "llvm/Support/FormatVariadic.h"
3234
#include <cassert>
@@ -351,39 +353,64 @@ struct VectorInsertStridedSliceOpConvert final
351353
}
352354
};
353355

354-
template <class SPIRVFMaxOp, class SPIRVFMinOp, class SPIRVUMaxOp,
355-
class SPIRVUMinOp, class SPIRVSMaxOp, class SPIRVSMinOp>
356-
struct VectorReductionPattern final
357-
: public OpConversionPattern<vector::ReductionOp> {
356+
/// Collects the scalar operands of a vector reduction into a flat list.
///
/// Emits one `spirv::CompositeExtractOp` per element along dimension 0 of
/// `srcVectorType`, reading from the adaptor's (converted) vector operand, and
/// appends the adaptor's accumulator operand last when one is present.
///
/// \param reduceOp       reduction op; only its location is used here.
/// \param adaptor        adaptor supplying the vector and optional accumulator.
/// \param srcVectorType  type of the adaptor's vector operand; the caller in
///                       this file checks that it is rank-1 before calling.
/// \param rewriter       rewriter used to create the extract ops.
/// \returns the extracted element values in index order, followed by the
///          accumulator if any.
static SmallVector<Value> extractAllElements(
357+
vector::ReductionOp reduceOp, vector::ReductionOp::Adaptor adaptor,
358+
VectorType srcVectorType, ConversionPatternRewriter &rewriter) {
359+
int numElements = static_cast<int>(srcVectorType.getDimSize(0));
360+
SmallVector<Value> values;
361+
// Reserve one extra slot when an accumulator will be appended below.
values.reserve(numElements + (adaptor.getAcc() ? 1 : 0));
362+
Location loc = reduceOp.getLoc();
363+
364+
// Extract every element of the source vector as a scalar value.
for (int i = 0; i < numElements; ++i) {
365+
values.push_back(rewriter.create<spirv::CompositeExtractOp>(
366+
loc, srcVectorType.getElementType(), adaptor.getVector(),
367+
rewriter.getI32ArrayAttr({i})));
368+
}
369+
// The accumulator, when provided, is treated as one more value to reduce.
if (Value acc = adaptor.getAcc())
370+
values.push_back(acc);
371+
372+
return values;
373+
}
374+
375+
/// Bundle of the data the reduction patterns need before combining values:
/// the converted result type and the scalar values to be reduced.
struct ReductionRewriteInfo {
376+
// Result type of the reduction after type conversion.
Type resultType;
377+
// Scalars extracted from the source vector (plus the accumulator, if any).
SmallVector<Value> extractedElements;
378+
};
379+
380+
/// Performs the checks and element extraction shared by the vector reduction
/// conversion patterns.
///
/// Fails when `typeConverter` cannot convert the op's result type, or when the
/// adaptor's vector operand is not a rank-1 `VectorType` (the latter is
/// reported through `rewriter.notifyMatchFailure`). On success, returns the
/// converted result type together with the values produced by
/// `extractAllElements`.
///
/// NOTE(review): the extract ops are created here, before callers inspect the
/// reduction kind; callers that subsequently report a match failure do so
/// after these ops were emitted. Presumably the conversion driver discards
/// them on failure -- confirm.
FailureOr<ReductionRewriteInfo> static getReductionInfo(
381+
vector::ReductionOp op, vector::ReductionOp::Adaptor adaptor,
382+
ConversionPatternRewriter &rewriter, const TypeConverter &typeConverter) {
383+
Type resultType = typeConverter.convertType(op.getType());
384+
if (!resultType)
385+
return failure();
386+
387+
// Only rank-1 vector sources are supported.
auto srcVectorType = dyn_cast<VectorType>(adaptor.getVector().getType());
388+
if (!srcVectorType || srcVectorType.getRank() != 1)
389+
return rewriter.notifyMatchFailure(op, "not a 1-D vector source");
390+
391+
SmallVector<Value> extractedElements =
392+
extractAllElements(op, adaptor, srcVectorType, rewriter);
393+
394+
return ReductionRewriteInfo{resultType, std::move(extractedElements)};
395+
}
396+
397+
template <typename SPIRVUMaxOp, typename SPIRVUMinOp, typename SPIRVSMaxOp,
398+
typename SPIRVSMinOp>
399+
struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
358400
using OpConversionPattern::OpConversionPattern;
359401

360402
LogicalResult
361403
matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
362404
ConversionPatternRewriter &rewriter) const override {
363-
Type resultType = typeConverter->convertType(reduceOp.getType());
364-
if (!resultType)
405+
auto reductionInfo =
406+
getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
407+
if (failed(reductionInfo))
365408
return failure();
366409

367-
auto srcVectorType = dyn_cast<VectorType>(adaptor.getVector().getType());
368-
if (!srcVectorType || srcVectorType.getRank() != 1)
369-
return rewriter.notifyMatchFailure(reduceOp, "not 1-D vector source");
370-
371-
// Extract all elements.
372-
int numElements = srcVectorType.getDimSize(0);
373-
SmallVector<Value, 4> values;
374-
values.reserve(numElements + (adaptor.getAcc() != nullptr));
375-
Location loc = reduceOp.getLoc();
376-
for (int i = 0; i < numElements; ++i) {
377-
values.push_back(rewriter.create<spirv::CompositeExtractOp>(
378-
loc, srcVectorType.getElementType(), adaptor.getVector(),
379-
rewriter.getI32ArrayAttr({i})));
380-
}
381-
if (Value acc = adaptor.getAcc())
382-
values.push_back(acc);
383-
384-
// Reduce them.
385-
Value result = values.front();
386-
for (Value next : llvm::ArrayRef(values).drop_front()) {
410+
auto [resultType, extractedElements] = *reductionInfo;
411+
Location loc = reduceOp->getLoc();
412+
Value result = extractedElements.front();
413+
for (Value next : llvm::drop_begin(extractedElements)) {
387414
switch (reduceOp.getKind()) {
388415

389416
#define INT_AND_FLOAT_CASE(kind, iop, fop) \
@@ -403,10 +430,6 @@ struct VectorReductionPattern final
403430

404431
INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp);
405432
INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp);
406-
INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp);
407-
INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp);
408-
INT_OR_FLOAT_CASE(MAXF, SPIRVFMaxOp);
409-
INT_OR_FLOAT_CASE(MINF, SPIRVFMinOp);
410433
INT_OR_FLOAT_CASE(MINUI, SPIRVUMinOp);
411434
INT_OR_FLOAT_CASE(MINSI, SPIRVSMinOp);
412435
INT_OR_FLOAT_CASE(MAXUI, SPIRVUMaxOp);
@@ -416,7 +439,51 @@ struct VectorReductionPattern final
416439
case vector::CombiningKind::OR:
417440
case vector::CombiningKind::XOR:
418441
return rewriter.notifyMatchFailure(reduceOp, "unimplemented");
442+
default:
443+
return rewriter.notifyMatchFailure(reduceOp, "not handled here");
419444
}
445+
#undef INT_AND_FLOAT_CASE
446+
#undef INT_OR_FLOAT_CASE
447+
}
448+
449+
rewriter.replaceOp(reduceOp, result);
450+
return success();
451+
}
452+
};
453+
454+
template <typename SPIRVFMaxOp, typename SPIRVFMinOp>
455+
struct VectorReductionFloatMinMax final
456+
: OpConversionPattern<vector::ReductionOp> {
457+
using OpConversionPattern::OpConversionPattern;
458+
459+
LogicalResult
460+
matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
461+
ConversionPatternRewriter &rewriter) const override {
462+
auto reductionInfo =
463+
getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
464+
if (failed(reductionInfo))
465+
return failure();
466+
467+
auto [resultType, extractedElements] = *reductionInfo;
468+
Location loc = reduceOp->getLoc();
469+
Value result = extractedElements.front();
470+
for (Value next : llvm::drop_begin(extractedElements)) {
471+
switch (reduceOp.getKind()) {
472+
473+
#define INT_OR_FLOAT_CASE(kind, fop) \
474+
case vector::CombiningKind::kind: \
475+
result = rewriter.create<fop>(loc, resultType, result, next); \
476+
break
477+
478+
INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp);
479+
INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp);
480+
INT_OR_FLOAT_CASE(MAXF, SPIRVFMaxOp);
481+
INT_OR_FLOAT_CASE(MINF, SPIRVFMinOp);
482+
483+
default:
484+
return rewriter.notifyMatchFailure(reduceOp, "not handled here");
485+
}
486+
#undef INT_OR_FLOAT_CASE
420487
}
421488

422489
rewriter.replaceOp(reduceOp, result);
@@ -674,13 +741,14 @@ struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
674741
};
675742

676743
} // namespace
677-
#define CL_MAX_MIN_OPS \
678-
spirv::CLFMaxOp, spirv::CLFMinOp, spirv::CLUMaxOp, spirv::CLUMinOp, \
679-
spirv::CLSMaxOp, spirv::CLSMinOp
744+
#define CL_INT_MAX_MIN_OPS \
745+
spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
746+
747+
#define GL_INT_MAX_MIN_OPS \
748+
spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp
680749

681-
#define GL_MAX_MIN_OPS \
682-
spirv::GLFMaxOp, spirv::GLFMinOp, spirv::GLUMaxOp, spirv::GLUMinOp, \
683-
spirv::GLSMaxOp, spirv::GLSMinOp
750+
#define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp
751+
#define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp
684752

685753
void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
686754
RewritePatternSet &patterns) {
@@ -689,8 +757,10 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
689757
VectorExtractElementOpConvert, VectorExtractOpConvert,
690758
VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
691759
VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
692-
VectorInsertOpConvert, VectorReductionPattern<GL_MAX_MIN_OPS>,
693-
VectorReductionPattern<CL_MAX_MIN_OPS>, VectorShapeCast,
760+
VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
761+
VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
762+
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
763+
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
694764
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
695765
VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
696766
typeConverter, patterns.getContext());

0 commit comments

Comments
 (0)