Skip to content

Commit a736e6b

Browse files
committed
[mlir][spirv] Split codegen for float min/max reductions and others (NFC)
This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671. There are two kinds of min/max operations for floating-point numbers: `minf`/`maxf` and `minimumf`/`maximumf`. The code generation for these operations should differ from that of the other vector reduction kinds, because the CL and GL operations for floating-point min and max do not have the same semantics when handling NaNs; the desired semantics must therefore be enforced with additional ops. However, the code generation for floating-point min/max operations still shares one step with the other reduction kinds, namely extracting the elements from the source vector, so we have decided to refactor the existing code using the CRTP pattern: a common base extracts the elements, and each derived pattern performs its own reduction. This change does not alter the actual behavior of the code and is necessary for future fixes to the codegen for floating-point min/max operations.
1 parent 475e154 commit a736e6b

File tree

1 file changed

+101
-29
lines changed

1 file changed

+101
-29
lines changed

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

Lines changed: 101 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "mlir/Dialect/Vector/IR/VectorOps.h"
2121
#include "mlir/IR/BuiltinAttributes.h"
2222
#include "mlir/IR/BuiltinTypes.h"
23+
#include "mlir/IR/Location.h"
2324
#include "mlir/IR/Matchers.h"
2425
#include "mlir/IR/PatternMatch.h"
2526
#include "mlir/IR/TypeUtilities.h"
@@ -351,15 +352,13 @@ struct VectorInsertStridedSliceOpConvert final
351352
}
352353
};
353354

354-
template <class SPIRVFMaxOp, class SPIRVFMinOp, class SPIRVUMaxOp,
355-
class SPIRVUMinOp, class SPIRVSMaxOp, class SPIRVSMinOp>
356-
struct VectorReductionPattern final
357-
: public OpConversionPattern<vector::ReductionOp> {
355+
template <typename Derived>
356+
struct VectorReductionPatternBase : OpConversionPattern<vector::ReductionOp> {
358357
using OpConversionPattern::OpConversionPattern;
359358

360359
LogicalResult
361360
matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
362-
ConversionPatternRewriter &rewriter) const override {
361+
ConversionPatternRewriter &rewriter) const final {
363362
Type resultType = typeConverter->convertType(reduceOp.getType());
364363
if (!resultType)
365364
return failure();
@@ -368,9 +367,22 @@ struct VectorReductionPattern final
368367
if (!srcVectorType || srcVectorType.getRank() != 1)
369368
return rewriter.notifyMatchFailure(reduceOp, "not 1-D vector source");
370369

371-
// Extract all elements.
370+
SmallVector<Value> extractedElements =
371+
extractAllElements(reduceOp, adaptor, srcVectorType, rewriter);
372+
373+
const auto &self = static_cast<const Derived &>(*this);
374+
375+
return self.reduceExtracted(reduceOp, extractedElements, resultType,
376+
rewriter);
377+
}
378+
379+
private:
380+
SmallVector<Value>
381+
extractAllElements(vector::ReductionOp reduceOp, OpAdaptor adaptor,
382+
VectorType srcVectorType,
383+
ConversionPatternRewriter &rewriter) const {
372384
int numElements = srcVectorType.getDimSize(0);
373-
SmallVector<Value, 4> values;
385+
SmallVector<Value> values;
374386
values.reserve(numElements + (adaptor.getAcc() != nullptr));
375387
Location loc = reduceOp.getLoc();
376388
for (int i = 0; i < numElements; ++i) {
@@ -381,9 +393,26 @@ struct VectorReductionPattern final
381393
if (Value acc = adaptor.getAcc())
382394
values.push_back(acc);
383395

384-
// Reduce them.
385-
Value result = values.front();
386-
for (Value next : llvm::ArrayRef(values).drop_front()) {
396+
return values;
397+
}
398+
};
399+
400+
#define VECTOR_REDUCTION_BASE \
401+
VectorReductionPatternBase<VectorReductionPattern<SPIRVUMaxOp, SPIRVUMinOp, \
402+
SPIRVSMaxOp, SPIRVSMinOp>>
403+
template <typename SPIRVUMaxOp, typename SPIRVUMinOp, typename SPIRVSMaxOp,
404+
typename SPIRVSMinOp>
405+
struct VectorReductionPattern final : VECTOR_REDUCTION_BASE {
406+
using Base = VECTOR_REDUCTION_BASE;
407+
using Base::Base;
408+
409+
LogicalResult reduceExtracted(vector::ReductionOp reduceOp,
410+
ArrayRef<Value> extractedElements,
411+
Type resultType,
412+
ConversionPatternRewriter &rewriter) const {
413+
mlir::Location loc = reduceOp->getLoc();
414+
Value result = extractedElements.front();
415+
for (Value next : llvm::ArrayRef(extractedElements).drop_front()) {
387416
switch (reduceOp.getKind()) {
388417

389418
#define INT_AND_FLOAT_CASE(kind, iop, fop) \
@@ -403,10 +432,6 @@ struct VectorReductionPattern final
403432

404433
INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp);
405434
INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp);
406-
INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp);
407-
INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp);
408-
INT_OR_FLOAT_CASE(MAXF, SPIRVFMaxOp);
409-
INT_OR_FLOAT_CASE(MINF, SPIRVFMinOp);
410435
INT_OR_FLOAT_CASE(MINUI, SPIRVUMinOp);
411436
INT_OR_FLOAT_CASE(MINSI, SPIRVSMinOp);
412437
INT_OR_FLOAT_CASE(MAXUI, SPIRVUMaxOp);
@@ -416,13 +441,57 @@ struct VectorReductionPattern final
416441
case vector::CombiningKind::OR:
417442
case vector::CombiningKind::XOR:
418443
return rewriter.notifyMatchFailure(reduceOp, "unimplemented");
444+
default:
445+
return rewriter.notifyMatchFailure(reduceOp, "not handled here");
419446
}
420447
}
421448

422449
rewriter.replaceOp(reduceOp, result);
423450
return success();
424451
}
425452
};
453+
#undef VECTOR_REDUCTION_BASE
454+
#undef INT_AND_FLOAT_CASE
455+
#undef INT_OR_FLOAT_CASE
456+
457+
// Shorthand for the CRTP base of VectorReductionFloatMinMax: the shared
// VectorReductionPatternBase instantiated with the derived pattern itself.
// Used for the inheritance list and the `Base` alias below, and #undef'd right
// after the struct.
#define MIN_MAX_PATTERN_BASE \
458+
VectorReductionPatternBase< \
459+
VectorReductionFloatMinMax<SPIRVFMaxOp, SPIRVFMinOp>>
460+
/// Conversion pattern for the floating-point min/max kinds of
/// vector::ReductionOp (MAXIMUMF, MINIMUMF, MAXF, MINF).
///
/// SPIRVFMaxOp / SPIRVFMinOp are the SPIR-V float max/min ops to emit; the
/// pattern is instantiated once with the CL ops and once with the GL ops.
/// The base class's matchAndRewrite extracts the elements and then calls
/// reduceExtracted() on this class through a static_cast to the derived type.
template <class SPIRVFMaxOp, class SPIRVFMinOp>
461+
struct VectorReductionFloatMinMax final : MIN_MAX_PATTERN_BASE {
462+
using Base = MIN_MAX_PATTERN_BASE;
463+
using Base::Base;
464+
465+
/// Folds `extractedElements` left to right into a single value using the
/// SPIR-V float max/min op that matches the reduction kind, then replaces
/// `reduceOp` with that value.
///
/// \param reduceOp          the reduction being converted; supplies the
///                          location and the combining kind.
/// \param extractedElements the values to combine, in order.
/// \param resultType        the type given to every created op.
/// \param rewriter          used to create ops, replace `reduceOp`, and
///                          report match failures.
/// \returns success() after replacing `reduceOp`; a match failure
///          ("not handled here") for any kind other than the four float
///          min/max kinds.
LogicalResult reduceExtracted(vector::ReductionOp reduceOp,
466+
ArrayRef<Value> extractedElements,
467+
Type resultType,
468+
ConversionPatternRewriter &rewriter) const {
469+
mlir::Location loc = reduceOp->getLoc();
470+
// Seed the accumulator with the first element.
// NOTE(review): front() assumes `extractedElements` is non-empty — confirm
// the caller guarantees at least one element.
Value result = extractedElements.front();
471+
for (Value next : llvm::ArrayRef(extractedElements).drop_front()) {
472+
switch (reduceOp.getKind()) {
473+
474+
// Expands to one `case` that combines the accumulator with `next` using `fop`.
#define INT_OR_FLOAT_CASE(kind, fop) \
475+
case vector::CombiningKind::kind: \
476+
result = rewriter.create<fop>(loc, resultType, result, next); \
477+
break
478+
479+
// MAXIMUMF/MAXF both map to SPIRVFMaxOp and MINIMUMF/MINF both map to
// SPIRVFMinOp here; no extra ops are emitted to distinguish the two
// flavours.
INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp);
480+
INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp);
481+
INT_OR_FLOAT_CASE(MAXF, SPIRVFMaxOp);
482+
INT_OR_FLOAT_CASE(MINF, SPIRVFMinOp);
483+
484+
// Every other combining kind is rejected by this pattern.
default:
485+
return rewriter.notifyMatchFailure(reduceOp, "not handled here");
486+
}
487+
}
488+
489+
// Replace the original reduction with the final accumulated value.
rewriter.replaceOp(reduceOp, result);
490+
return success();
491+
}
492+
};
493+
#undef MIN_MAX_PATTERN_BASE
494+
#undef INT_OR_FLOAT_CASE
426495

427496
class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
428497
public:
@@ -604,25 +673,28 @@ struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
604673
};
605674

606675
} // namespace
607-
#define CL_MAX_MIN_OPS \
608-
spirv::CLFMaxOp, spirv::CLFMinOp, spirv::CLUMaxOp, spirv::CLUMinOp, \
609-
spirv::CLSMaxOp, spirv::CLSMinOp
676+
#define CL_INT_MAX_MIN_OPS \
677+
spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
678+
679+
#define GL_INT_MAX_MIN_OPS \
680+
spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp
610681

611-
#define GL_MAX_MIN_OPS \
612-
spirv::GLFMaxOp, spirv::GLFMinOp, spirv::GLUMaxOp, spirv::GLUMinOp, \
613-
spirv::GLSMaxOp, spirv::GLSMinOp
682+
#define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp
683+
#define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp
614684

615685
void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
616686
RewritePatternSet &patterns) {
617-
patterns.add<VectorBitcastConvert, VectorBroadcastConvert,
618-
VectorExtractElementOpConvert, VectorExtractOpConvert,
619-
VectorExtractStridedSliceOpConvert,
620-
VectorFmaOpConvert<spirv::GLFmaOp>,
621-
VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
622-
VectorInsertOpConvert, VectorReductionPattern<GL_MAX_MIN_OPS>,
623-
VectorReductionPattern<CL_MAX_MIN_OPS>, VectorShapeCast,
624-
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
625-
VectorSplatPattern>(typeConverter, patterns.getContext());
687+
patterns.add<
688+
VectorBitcastConvert, VectorBroadcastConvert,
689+
VectorExtractElementOpConvert, VectorExtractOpConvert,
690+
VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
691+
VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
692+
VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
693+
VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
694+
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
695+
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
696+
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
697+
VectorSplatPattern>(typeConverter, patterns.getContext());
626698
}
627699

628700
void mlir::populateVectorReductionToSPIRVDotProductPatterns(

0 commit comments

Comments
 (0)