@@ -618,6 +618,74 @@ struct VectorInterleaveOpConvert final
618
618
}
619
619
};
620
620
621
+ struct VectorDeinterleaveOpConvert final
622
+ : public OpConversionPattern<vector::DeinterleaveOp> {
623
+ using OpConversionPattern::OpConversionPattern;
624
+
625
+ LogicalResult
626
+ matchAndRewrite (vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
627
+ ConversionPatternRewriter &rewriter) const override {
628
+
629
+ // Check the result vector type.
630
+ VectorType oldResultType = deinterleaveOp.getResultVectorType ();
631
+ Type newResultType = getTypeConverter ()->convertType (oldResultType);
632
+ if (!newResultType)
633
+ return rewriter.notifyMatchFailure (deinterleaveOp,
634
+ " unsupported result vector type" );
635
+
636
+ // Get location.
637
+ Location loc = deinterleaveOp->getLoc ();
638
+
639
+ // Deinterleave the indices.
640
+ VectorType sourceType = deinterleaveOp.getSourceVectorType ();
641
+ int n = sourceType.getNumElements ();
642
+
643
+ // Output vectors of size 1 are converted to scalars by the type converter.
644
+ // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
645
+ // use `spirv::CompositeExtractOp`.
646
+ if (n == 2 ) {
647
+ spirv::CompositeExtractOp compositeExtractZero =
648
+ rewriter.create <spirv::CompositeExtractOp>(
649
+ loc, newResultType, adaptor.getSource (),
650
+ rewriter.getI32ArrayAttr ({0 }));
651
+
652
+ spirv::CompositeExtractOp compositeExtractOne =
653
+ rewriter.create <spirv::CompositeExtractOp>(
654
+ loc, newResultType, adaptor.getSource (),
655
+ rewriter.getI32ArrayAttr ({1 }));
656
+
657
+ rewriter.replaceOp (deinterleaveOp,
658
+ {compositeExtractZero, compositeExtractOne});
659
+ return success ();
660
+ }
661
+
662
+ // Indices for `res1`.
663
+ auto seqEven = llvm::seq<int64_t >(n / 2 );
664
+ auto indicesEven =
665
+ llvm::map_to_vector (seqEven, [](int i) { return i * 2 ; });
666
+
667
+ // Indices for `res2`.
668
+ auto seqOdd = llvm::seq<int64_t >(n / 2 );
669
+ auto indicesOdd =
670
+ llvm::map_to_vector (seqOdd, [](int i) { return i * 2 + 1 ; });
671
+
672
+ // Create two SPIR-V shuffles.
673
+ spirv::VectorShuffleOp shuffleEven =
674
+ rewriter.create <spirv::VectorShuffleOp>(
675
+ loc, newResultType, adaptor.getSource (), adaptor.getSource (),
676
+ rewriter.getI32ArrayAttr (indicesEven));
677
+
678
+ spirv::VectorShuffleOp shuffleOdd = rewriter.create <spirv::VectorShuffleOp>(
679
+ loc, newResultType, adaptor.getSource (), adaptor.getSource (),
680
+ rewriter.getI32ArrayAttr (indicesOdd));
681
+
682
+ // Replace deinterleaveOp with SPIR-V shuffles.
683
+ rewriter.replaceOp (deinterleaveOp, {shuffleEven, shuffleOdd});
684
+
685
+ return success ();
686
+ }
687
+ };
688
+
621
689
struct VectorLoadOpConverter final
622
690
: public OpConversionPattern<vector::LoadOp> {
623
691
using OpConversionPattern::OpConversionPattern;
@@ -862,9 +930,9 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
862
930
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
863
931
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
864
932
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
865
- VectorInterleaveOpConvert, VectorSplatPattern, VectorLoadOpConverter ,
866
- VectorStoreOpConverter>(typeConverter, patterns. getContext (),
867
- PatternBenefit (1 ));
933
+ VectorInterleaveOpConvert, VectorDeinterleaveOpConvert ,
934
+ VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
935
+ typeConverter, patterns. getContext (), PatternBenefit (1 ));
868
936
869
937
// Make sure that the more specialized dot product pattern has higher benefit
870
938
// than the generic one that extracts all elements.
0 commit comments