20
20
#include " mlir/Dialect/Vector/IR/VectorOps.h"
21
21
#include " mlir/IR/BuiltinAttributes.h"
22
22
#include " mlir/IR/BuiltinTypes.h"
23
+ #include " mlir/IR/Location.h"
23
24
#include " mlir/IR/Matchers.h"
24
25
#include " mlir/IR/PatternMatch.h"
25
26
#include " mlir/IR/TypeUtilities.h"
26
27
#include " mlir/Support/LogicalResult.h"
27
28
#include " mlir/Transforms/DialectConversion.h"
28
29
#include " llvm/ADT/ArrayRef.h"
29
30
#include " llvm/ADT/STLExtras.h"
31
+ #include " llvm/ADT/SmallVector.h"
30
32
#include " llvm/ADT/SmallVectorExtras.h"
31
33
#include " llvm/Support/FormatVariadic.h"
32
34
#include < cassert>
@@ -351,39 +353,64 @@ struct VectorInsertStridedSliceOpConvert final
351
353
}
352
354
};
353
355
354
- template <class SPIRVFMaxOp , class SPIRVFMinOp , class SPIRVUMaxOp ,
355
- class SPIRVUMinOp , class SPIRVSMaxOp , class SPIRVSMinOp >
356
- struct VectorReductionPattern final
357
- : public OpConversionPattern<vector::ReductionOp> {
356
+ static SmallVector<Value> extractAllElements (
357
+ vector::ReductionOp reduceOp, vector::ReductionOp::Adaptor adaptor,
358
+ VectorType srcVectorType, ConversionPatternRewriter &rewriter) {
359
+ int numElements = static_cast <int >(srcVectorType.getDimSize (0 ));
360
+ SmallVector<Value> values;
361
+ values.reserve (numElements + (adaptor.getAcc () ? 1 : 0 ));
362
+ Location loc = reduceOp.getLoc ();
363
+
364
+ for (int i = 0 ; i < numElements; ++i) {
365
+ values.push_back (rewriter.create <spirv::CompositeExtractOp>(
366
+ loc, srcVectorType.getElementType (), adaptor.getVector (),
367
+ rewriter.getI32ArrayAttr ({i})));
368
+ }
369
+ if (Value acc = adaptor.getAcc ())
370
+ values.push_back (acc);
371
+
372
+ return values;
373
+ }
374
+
375
+ struct ReductionRewriteInfo {
376
+ Type resultType;
377
+ SmallVector<Value> extractedElements;
378
+ };
379
+
380
+ FailureOr<ReductionRewriteInfo> static getReductionInfo (
381
+ vector::ReductionOp op, vector::ReductionOp::Adaptor adaptor,
382
+ ConversionPatternRewriter &rewriter, const TypeConverter &typeConverter) {
383
+ Type resultType = typeConverter.convertType (op.getType ());
384
+ if (!resultType)
385
+ return failure ();
386
+
387
+ auto srcVectorType = dyn_cast<VectorType>(adaptor.getVector ().getType ());
388
+ if (!srcVectorType || srcVectorType.getRank () != 1 )
389
+ return rewriter.notifyMatchFailure (op, " not a 1-D vector source" );
390
+
391
+ SmallVector<Value> extractedElements =
392
+ extractAllElements (op, adaptor, srcVectorType, rewriter);
393
+
394
+ return ReductionRewriteInfo{resultType, std::move (extractedElements)};
395
+ }
396
+
397
+ template <typename SPIRVUMaxOp, typename SPIRVUMinOp, typename SPIRVSMaxOp,
398
+ typename SPIRVSMinOp>
399
+ struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
358
400
using OpConversionPattern::OpConversionPattern;
359
401
360
402
LogicalResult
361
403
matchAndRewrite (vector::ReductionOp reduceOp, OpAdaptor adaptor,
362
404
ConversionPatternRewriter &rewriter) const override {
363
- Type resultType = typeConverter->convertType (reduceOp.getType ());
364
- if (!resultType)
405
+ auto reductionInfo =
406
+ getReductionInfo (reduceOp, adaptor, rewriter, *getTypeConverter ());
407
+ if (failed (reductionInfo))
365
408
return failure ();
366
409
367
- auto srcVectorType = dyn_cast<VectorType>(adaptor.getVector ().getType ());
368
- if (!srcVectorType || srcVectorType.getRank () != 1 )
369
- return rewriter.notifyMatchFailure (reduceOp, " not 1-D vector source" );
370
-
371
- // Extract all elements.
372
- int numElements = srcVectorType.getDimSize (0 );
373
- SmallVector<Value, 4 > values;
374
- values.reserve (numElements + (adaptor.getAcc () != nullptr ));
375
- Location loc = reduceOp.getLoc ();
376
- for (int i = 0 ; i < numElements; ++i) {
377
- values.push_back (rewriter.create <spirv::CompositeExtractOp>(
378
- loc, srcVectorType.getElementType (), adaptor.getVector (),
379
- rewriter.getI32ArrayAttr ({i})));
380
- }
381
- if (Value acc = adaptor.getAcc ())
382
- values.push_back (acc);
383
-
384
- // Reduce them.
385
- Value result = values.front ();
386
- for (Value next : llvm::ArrayRef (values).drop_front ()) {
410
+ auto [resultType, extractedElements] = *reductionInfo;
411
+ Location loc = reduceOp->getLoc ();
412
+ Value result = extractedElements.front ();
413
+ for (Value next : llvm::drop_begin (extractedElements)) {
387
414
switch (reduceOp.getKind ()) {
388
415
389
416
#define INT_AND_FLOAT_CASE (kind, iop, fop ) \
@@ -403,10 +430,6 @@ struct VectorReductionPattern final
403
430
404
431
INT_AND_FLOAT_CASE (ADD, IAddOp, FAddOp);
405
432
INT_AND_FLOAT_CASE (MUL, IMulOp, FMulOp);
406
- INT_OR_FLOAT_CASE (MAXIMUMF, SPIRVFMaxOp);
407
- INT_OR_FLOAT_CASE (MINIMUMF, SPIRVFMinOp);
408
- INT_OR_FLOAT_CASE (MAXF, SPIRVFMaxOp);
409
- INT_OR_FLOAT_CASE (MINF, SPIRVFMinOp);
410
433
INT_OR_FLOAT_CASE (MINUI, SPIRVUMinOp);
411
434
INT_OR_FLOAT_CASE (MINSI, SPIRVSMinOp);
412
435
INT_OR_FLOAT_CASE (MAXUI, SPIRVUMaxOp);
@@ -416,7 +439,51 @@ struct VectorReductionPattern final
416
439
case vector::CombiningKind::OR:
417
440
case vector::CombiningKind::XOR:
418
441
return rewriter.notifyMatchFailure (reduceOp, " unimplemented" );
442
+ default :
443
+ return rewriter.notifyMatchFailure (reduceOp, " not handled here" );
419
444
}
445
+ #undef INT_AND_FLOAT_CASE
446
+ #undef INT_OR_FLOAT_CASE
447
+ }
448
+
449
+ rewriter.replaceOp (reduceOp, result);
450
+ return success ();
451
+ }
452
+ };
453
+
454
+ template <typename SPIRVFMaxOp, typename SPIRVFMinOp>
455
+ struct VectorReductionFloatMinMax final
456
+ : OpConversionPattern<vector::ReductionOp> {
457
+ using OpConversionPattern::OpConversionPattern;
458
+
459
+ LogicalResult
460
+ matchAndRewrite (vector::ReductionOp reduceOp, OpAdaptor adaptor,
461
+ ConversionPatternRewriter &rewriter) const override {
462
+ auto reductionInfo =
463
+ getReductionInfo (reduceOp, adaptor, rewriter, *getTypeConverter ());
464
+ if (failed (reductionInfo))
465
+ return failure ();
466
+
467
+ auto [resultType, extractedElements] = *reductionInfo;
468
+ Location loc = reduceOp->getLoc ();
469
+ Value result = extractedElements.front ();
470
+ for (Value next : llvm::drop_begin (extractedElements)) {
471
+ switch (reduceOp.getKind ()) {
472
+
473
+ #define INT_OR_FLOAT_CASE (kind, fop ) \
474
+ case vector::CombiningKind::kind: \
475
+ result = rewriter.create <fop>(loc, resultType, result, next); \
476
+ break
477
+
478
+ INT_OR_FLOAT_CASE (MAXIMUMF, SPIRVFMaxOp);
479
+ INT_OR_FLOAT_CASE (MINIMUMF, SPIRVFMinOp);
480
+ INT_OR_FLOAT_CASE (MAXF, SPIRVFMaxOp);
481
+ INT_OR_FLOAT_CASE (MINF, SPIRVFMinOp);
482
+
483
+ default :
484
+ return rewriter.notifyMatchFailure (reduceOp, " not handled here" );
485
+ }
486
+ #undef INT_OR_FLOAT_CASE
420
487
}
421
488
422
489
rewriter.replaceOp (reduceOp, result);
@@ -674,13 +741,14 @@ struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
674
741
};
675
742
676
743
} // namespace
677
- #define CL_MAX_MIN_OPS \
678
- spirv::CLFMaxOp, spirv::CLFMinOp, spirv::CLUMaxOp, spirv::CLUMinOp, \
679
- spirv::CLSMaxOp, spirv::CLSMinOp
744
// Template argument lists for VectorReductionPattern. The order must match
// its template parameters: <UMax, UMin, SMax, SMin>. One list is provided for
// the CL ops and one for the GL ops.
#define CL_INT_MAX_MIN_OPS                                                     \
  spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp

#define GL_INT_MAX_MIN_OPS                                                     \
  spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp
680
749
681
- #define GL_MAX_MIN_OPS \
682
- spirv::GLFMaxOp, spirv::GLFMinOp, spirv::GLUMaxOp, spirv::GLUMinOp, \
683
- spirv::GLSMaxOp, spirv::GLSMinOp
750
// Template argument lists for VectorReductionFloatMinMax. The order must
// match its template parameters: <FMax, FMin>.
#define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp
#define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp
684
752
685
753
void mlir::populateVectorToSPIRVPatterns (SPIRVTypeConverter &typeConverter,
686
754
RewritePatternSet &patterns) {
@@ -689,8 +757,10 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
689
757
VectorExtractElementOpConvert, VectorExtractOpConvert,
690
758
VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
691
759
VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
692
- VectorInsertOpConvert, VectorReductionPattern<GL_MAX_MIN_OPS>,
693
- VectorReductionPattern<CL_MAX_MIN_OPS>, VectorShapeCast,
760
+ VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
761
+ VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
762
+ VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
763
+ VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
694
764
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
695
765
VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
696
766
typeConverter, patterns.getContext ());
0 commit comments