20
20
#include " mlir/Dialect/Vector/IR/VectorOps.h"
21
21
#include " mlir/IR/BuiltinAttributes.h"
22
22
#include " mlir/IR/BuiltinTypes.h"
23
+ #include " mlir/IR/Location.h"
23
24
#include " mlir/IR/Matchers.h"
24
25
#include " mlir/IR/PatternMatch.h"
25
26
#include " mlir/IR/TypeUtilities.h"
@@ -351,15 +352,13 @@ struct VectorInsertStridedSliceOpConvert final
351
352
}
352
353
};
353
354
354
- template <class SPIRVFMaxOp , class SPIRVFMinOp , class SPIRVUMaxOp ,
355
- class SPIRVUMinOp , class SPIRVSMaxOp , class SPIRVSMinOp >
356
- struct VectorReductionPattern final
357
- : public OpConversionPattern<vector::ReductionOp> {
355
+ template <typename Derived>
356
+ struct VectorReductionPatternBase : OpConversionPattern<vector::ReductionOp> {
358
357
using OpConversionPattern::OpConversionPattern;
359
358
360
359
LogicalResult
361
360
matchAndRewrite (vector::ReductionOp reduceOp, OpAdaptor adaptor,
362
- ConversionPatternRewriter &rewriter) const override {
361
+ ConversionPatternRewriter &rewriter) const final {
363
362
Type resultType = typeConverter->convertType (reduceOp.getType ());
364
363
if (!resultType)
365
364
return failure ();
@@ -368,9 +367,22 @@ struct VectorReductionPattern final
368
367
if (!srcVectorType || srcVectorType.getRank () != 1 )
369
368
return rewriter.notifyMatchFailure (reduceOp, " not 1-D vector source" );
370
369
371
- // Extract all elements.
370
+ SmallVector<Value> extractedElements =
371
+ extractAllElements (reduceOp, adaptor, srcVectorType, rewriter);
372
+
373
+ const auto &self = static_cast <const Derived &>(*this );
374
+
375
+ return self.reduceExtracted (reduceOp, extractedElements, resultType,
376
+ rewriter);
377
+ }
378
+
379
+ private:
380
+ SmallVector<Value>
381
+ extractAllElements (vector::ReductionOp reduceOp, OpAdaptor adaptor,
382
+ VectorType srcVectorType,
383
+ ConversionPatternRewriter &rewriter) const {
372
384
int numElements = srcVectorType.getDimSize (0 );
373
- SmallVector<Value, 4 > values;
385
+ SmallVector<Value> values;
374
386
values.reserve (numElements + (adaptor.getAcc () != nullptr ));
375
387
Location loc = reduceOp.getLoc ();
376
388
for (int i = 0 ; i < numElements; ++i) {
@@ -381,9 +393,26 @@ struct VectorReductionPattern final
381
393
if (Value acc = adaptor.getAcc ())
382
394
values.push_back (acc);
383
395
384
- // Reduce them.
385
- Value result = values.front ();
386
- for (Value next : llvm::ArrayRef (values).drop_front ()) {
396
+ return values;
397
+ }
398
+ };
399
+
400
+ #define VECTOR_REDUCTION_BASE \
401
+ VectorReductionPatternBase<VectorReductionPattern<SPIRVUMaxOp, SPIRVUMinOp, \
402
+ SPIRVSMaxOp, SPIRVSMinOp>>
403
+ template <typename SPIRVUMaxOp, typename SPIRVUMinOp, typename SPIRVSMaxOp,
404
+ typename SPIRVSMinOp>
405
+ struct VectorReductionPattern final : VECTOR_REDUCTION_BASE {
406
+ using Base = VECTOR_REDUCTION_BASE;
407
+ using Base::Base;
408
+
409
+ LogicalResult reduceExtracted (vector::ReductionOp reduceOp,
410
+ ArrayRef<Value> extractedElements,
411
+ Type resultType,
412
+ ConversionPatternRewriter &rewriter) const {
413
+ mlir::Location loc = reduceOp->getLoc ();
414
+ Value result = extractedElements.front ();
415
+ for (Value next : llvm::ArrayRef (extractedElements).drop_front ()) {
387
416
switch (reduceOp.getKind ()) {
388
417
389
418
#define INT_AND_FLOAT_CASE (kind, iop, fop ) \
@@ -403,10 +432,6 @@ struct VectorReductionPattern final
403
432
404
433
INT_AND_FLOAT_CASE (ADD, IAddOp, FAddOp);
405
434
INT_AND_FLOAT_CASE (MUL, IMulOp, FMulOp);
406
- INT_OR_FLOAT_CASE (MAXIMUMF, SPIRVFMaxOp);
407
- INT_OR_FLOAT_CASE (MINIMUMF, SPIRVFMinOp);
408
- INT_OR_FLOAT_CASE (MAXF, SPIRVFMaxOp);
409
- INT_OR_FLOAT_CASE (MINF, SPIRVFMinOp);
410
435
INT_OR_FLOAT_CASE (MINUI, SPIRVUMinOp);
411
436
INT_OR_FLOAT_CASE (MINSI, SPIRVSMinOp);
412
437
INT_OR_FLOAT_CASE (MAXUI, SPIRVUMaxOp);
@@ -416,13 +441,57 @@ struct VectorReductionPattern final
416
441
case vector::CombiningKind::OR:
417
442
case vector::CombiningKind::XOR:
418
443
return rewriter.notifyMatchFailure (reduceOp, " unimplemented" );
444
+ default :
445
+ return rewriter.notifyMatchFailure (reduceOp, " not handled here" );
419
446
}
420
447
}
421
448
422
449
rewriter.replaceOp (reduceOp, result);
423
450
return success ();
424
451
}
425
452
};
453
+ #undef VECTOR_REDUCTION_BASE
454
+ #undef INT_AND_FLOAT_CASE
455
+ #undef INT_OR_FLOAT_CASE
456
+
457
+ #define MIN_MAX_PATTERN_BASE \
458
+ VectorReductionPatternBase< \
459
+ VectorReductionFloatMinMax<SPIRVFMaxOp, SPIRVFMinOp>>
460
+ template <class SPIRVFMaxOp , class SPIRVFMinOp >
461
+ struct VectorReductionFloatMinMax final : MIN_MAX_PATTERN_BASE {
462
+ using Base = MIN_MAX_PATTERN_BASE;
463
+ using Base::Base;
464
+
465
+ LogicalResult reduceExtracted (vector::ReductionOp reduceOp,
466
+ ArrayRef<Value> extractedElements,
467
+ Type resultType,
468
+ ConversionPatternRewriter &rewriter) const {
469
+ mlir::Location loc = reduceOp->getLoc ();
470
+ Value result = extractedElements.front ();
471
+ for (Value next : llvm::ArrayRef (extractedElements).drop_front ()) {
472
+ switch (reduceOp.getKind ()) {
473
+
474
+ #define INT_OR_FLOAT_CASE (kind, fop ) \
475
+ case vector::CombiningKind::kind: \
476
+ result = rewriter.create <fop>(loc, resultType, result, next); \
477
+ break
478
+
479
+ INT_OR_FLOAT_CASE (MAXIMUMF, SPIRVFMaxOp);
480
+ INT_OR_FLOAT_CASE (MINIMUMF, SPIRVFMinOp);
481
+ INT_OR_FLOAT_CASE (MAXF, SPIRVFMaxOp);
482
+ INT_OR_FLOAT_CASE (MINF, SPIRVFMinOp);
483
+
484
+ default :
485
+ return rewriter.notifyMatchFailure (reduceOp, " not handled here" );
486
+ }
487
+ }
488
+
489
+ rewriter.replaceOp (reduceOp, result);
490
+ return success ();
491
+ }
492
+ };
493
+ #undef MIN_MAX_PATTERN_BASE
494
+ #undef INT_OR_FLOAT_CASE
426
495
427
496
class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
428
497
public:
@@ -604,25 +673,28 @@ struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
604
673
};
605
674
606
675
} // namespace
607
- #define CL_MAX_MIN_OPS \
608
- spirv::CLFMaxOp, spirv::CLFMinOp, spirv::CLUMaxOp, spirv::CLUMinOp, \
609
- spirv::CLSMaxOp, spirv::CLSMinOp
676
+ #define CL_INT_MAX_MIN_OPS \
677
+ spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
678
+
679
+ #define GL_INT_MAX_MIN_OPS \
680
+ spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp
610
681
611
- #define GL_MAX_MIN_OPS \
612
- spirv::GLFMaxOp, spirv::GLFMinOp, spirv::GLUMaxOp, spirv::GLUMinOp, \
613
- spirv::GLSMaxOp, spirv::GLSMinOp
682
+ #define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp
683
+ #define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp
614
684
615
685
void mlir::populateVectorToSPIRVPatterns (SPIRVTypeConverter &typeConverter,
616
686
RewritePatternSet &patterns) {
617
- patterns.add <VectorBitcastConvert, VectorBroadcastConvert,
618
- VectorExtractElementOpConvert, VectorExtractOpConvert,
619
- VectorExtractStridedSliceOpConvert,
620
- VectorFmaOpConvert<spirv::GLFmaOp>,
621
- VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
622
- VectorInsertOpConvert, VectorReductionPattern<GL_MAX_MIN_OPS>,
623
- VectorReductionPattern<CL_MAX_MIN_OPS>, VectorShapeCast,
624
- VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
625
- VectorSplatPattern>(typeConverter, patterns.getContext ());
687
+ patterns.add <
688
+ VectorBitcastConvert, VectorBroadcastConvert,
689
+ VectorExtractElementOpConvert, VectorExtractOpConvert,
690
+ VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
691
+ VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
692
+ VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
693
+ VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
694
+ VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
695
+ VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
696
+ VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
697
+ VectorSplatPattern>(typeConverter, patterns.getContext ());
626
698
}
627
699
628
700
void mlir::populateVectorReductionToSPIRVDotProductPatterns (
0 commit comments