@@ -611,11 +611,11 @@ static Value calculateGatherOffset(RewriterBase &rewriter,
 
   const size_t numIndices = extractOp.getIndices().size();
   for (size_t i = 1; i < numIndices; i++) {
+    Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);
+
     auto dimSize = broadcastIfNeeded(
         rewriter,
-        rewriter.create<arith::ConstantIndexOp>(
-            loc,
-            extractOp.getTensor().getType().cast<ShapedType>().getDimSize(i)),
+        rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
         indexVecType.getShape());
 
     offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);
@@ -630,6 +630,143 @@ static Value calculateGatherOffset(RewriterBase &rewriter,
   return offset;
 }
 
+enum VectorMemoryAccessKind {
+  // TODO: ScalarBroadcast,
+  Contiguous,
+  Gather
+};
+
+/// Check whether \p val can be used for calculating an index for a contiguous
+/// load operation, i.e. whether \p val:
+///   * is invariant with respect to \p linalgOp, i.e. it remains constant for
+///   all iterations, or
+///   * increments with the loop iterator (when \p strideZero is false) or is
+///   not affected by the loop indices (when \p strideZero is true).
+static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val, size_t dim,
+                                bool strideZero) {
+  auto *block = linalgOp.getBlock();
+
+  // Bail out if this is a block argument for this linalg.generic Op.
+  // TODO: We could try analysing the corresponding affine map here.
+  if (val.dyn_cast<BlockArgument>())
+    return llvm::all_of(block->getArguments(),
+                        [&val](Value v) { return (v != val); });
+
+  Operation *defOp = val.getDefiningOp();
+  assert(defOp && "This is neither a block argument nor an operation result");
+
+  // Given the assumption on the shape of the target tensor, the index Op is
+  // either:
+  //   * constant (for non-trailing dims), or
+  //   * increments with stride one together with the trailing dimension.
+  // Both cases are fine for contiguous loads.
+  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
+    return strideZero ? (indexOp.getDim() != dim) : (indexOp.getDim() == dim);
+
+  auto *ancestor = block->findAncestorOpInBlock(*defOp);
+
+  // Values defined outside `linalgOp`.
+  if (!ancestor)
+    return true;
+
+  // Values defined inside `linalgOp`, which are constant.
+  if (dyn_cast<arith::ConstantOp>(ancestor))
+    return true;
+
+  bool result = true;
+  for (auto op : ancestor->getOperands())
+    result &= isContiguousLoadIdx(linalgOp, op, dim, strideZero);
+
+  return result;
+}
+
+/// Check whether the calculation of \p val is based on a linalg.index Op with
+/// the dim attribute matching \p dim.
+static bool isBasedOnIndexOp(LinalgOp &linalgOp, Value &val, size_t dim) {
+  auto *block = linalgOp.getBlock();
+  auto targetShape = linalgOp.getStaticLoopRanges();
+
+  if (val.isa<BlockArgument>())
+    return false;
+
+  Operation *defOp = val.getDefiningOp();
+  assert(defOp && "This is neither a block argument nor an operation result");
+
+  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
+    return (indexOp.getDim() == dim);
+
+  auto *ancestor = block->findAncestorOpInBlock(*defOp);
+
+  if (!ancestor)
+    return false;
+
+  bool result = false;
+  for (auto op : ancestor->getOperands())
+    result |= isBasedOnIndexOp(linalgOp, op, dim);
+
+  return result;
+}
+
+/// Check whether \p extractOp would be a gather or a contiguous load Op after
+/// vectorising \p linalgOp. Note that it is always safe to use gather load
+/// operations for contiguous loads (albeit slow), but not vice-versa. When in
+/// doubt, bail out and assume that \p extractOp is a gather load.
+static VectorMemoryAccessKind
+getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
+                                    LinalgOp &linalgOp) {
+
+  auto targetShape = linalgOp.getStaticLoopRanges();
+
+  // Assume that it's a gather load when reading _into_:
+  //   * an n-D vector, like `tensor<1x2x4xi32>` or `tensor<2x1x4xi32>`, or
+  //   * a 1-D vector with the trailing dim equal 1, e.g. `tensor<1x4x1xi32>`.
+  // TODO: Relax these conditions.
+  if ((llvm::count_if(targetShape,
+                      [](int64_t dimSize) { return dimSize > 1; }) != 1) ||
+      targetShape.back() == 1)
+    return VectorMemoryAccessKind::Gather;
+
+  auto inputShape = extractOp.getTensor().getType().cast<ShapedType>();
+
+  // Assume that it's a gather load when reading _from_ a tensor for which the
+  // trailing dimension is 1, e.g. `tensor<1x4x1xi32>`.
+  // TODO: Relax this condition.
+  if (inputShape.getShape().back() == 1)
+    return VectorMemoryAccessKind::Gather;
+
+  bool isContiguous = true;
+
+  // Iterate over all indices. Analyze whether the way each index is calculated
+  // is suitable for contiguous load operations (e.g. loop invariant).
+  auto indices = extractOp.getIndices();
+  for (auto [i, indexVal] : llvm::enumerate(indices)) {
+    if (inputShape.getShape()[i] == 1) {
+      // This extractOp index must be a loop-invariant constant.
+      continue;
+    }
+
+    auto extractOpBottomIdx = indices.size() - 1;
+    auto strideOneDim = targetShape.size() - 1;
+    bool strideZero = (i != extractOpBottomIdx);
+    isContiguous &=
+        isContiguousLoadIdx(linalgOp, indexVal, strideOneDim, strideZero);
+  }
+
+  // The calculation of the trailing index must include the loop index. Given
+  // the assumption on the output tensor (which is defined by the iteration
+  // space), only the trailing dim matters.
+  auto extractOpTrailingIdx = indices.back();
+  isContiguous &=
+      isBasedOnIndexOp(linalgOp, extractOpTrailingIdx, targetShape.size() - 1);
+
+  if (isContiguous) {
+    LDBG("Found contiguous load: " << extractOp);
+    return VectorMemoryAccessKind::Contiguous;
+  }
+
+  return VectorMemoryAccessKind::Gather;
+}
+
 /// Helper function to vectorize the tensor.extract operations. Returns
 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
 /// should map the produced operations. This function is meant to be used as a
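To make the heuristic above concrete, here is a minimal sketch (illustrative only; not taken from this commit or its tests, and the function name and shapes are made up) of a `linalg.generic` + `tensor.extract` pattern that `getTensorExtractMemoryAccessPattern` would classify as contiguous: the static loop ranges (taken from the 1x1x3 output) have exactly one dimension larger than 1, the non-trailing extract indices are loop-invariant constants, and the trailing index comes from `linalg.index` over the innermost loop.

#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

// Hypothetical example. Output tensor<1x1x3xf32> => static loop ranges 1x1x3,
// i.e. only the trailing dimension iterates. The non-trailing extract indices
// (%c0) are loop invariant and the trailing one is linalg.index 2, so the
// analysis treats the access as a stride-one (contiguous) load.
func.func @contiguous_extract(%src: tensor<3x3x3xf32>,
                              %init: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
  %c0 = arith.constant 0 : index
  %res = linalg.generic {
      indexing_maps = [#map],
      iterator_types = ["parallel", "parallel", "parallel"]
    } outs(%init : tensor<1x1x3xf32>) {
  ^bb0(%out: f32):
    %i2 = linalg.index 2 : index
    %v = tensor.extract %src[%c0, %c0, %i2] : tensor<3x3x3xf32>
    linalg.yield %v : f32
  } -> tensor<1x1x3xf32>
  return %res : tensor<1x1x3xf32>
}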
@@ -660,15 +797,64 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
       extractOp.getIndices().size(),
       rewriter.create<arith::ConstantIndexOp>(loc, 0));
 
-  Value offset = calculateGatherOffset(rewriter, extractOp, bvm, targetShape);
+  VectorMemoryAccessKind memAccessKind =
+      getTensorExtractMemoryAccessPattern(extractOp, linalgOp);
+
+  // 1. Handle gather access
+  if (memAccessKind == VectorMemoryAccessKind::Gather) {
+    Value offset = calculateGatherOffset(rewriter, extractOp, bvm, targetShape);
+
+    // Generate the gather load
+    Operation *gatherOp = rewriter.create<vector::GatherOp>(
+        loc, resultType, extractOp.getTensor(), baseIndices, offset,
+        maskConstantOp, passThruConstantOp);
+    gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
+
+    LDBG("Vectorised as gather load: " << extractOp);
+    return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
+  }
+
+  // 2. Handle contiguous access.
+  SmallVector<Value> transferReadIdxs;
+  auto resTrailingDim = resultType.getShape().back();
+  auto zero = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
+
+  // Collect indices for `vector.transfer_read`. At this point, the indices will
+  // either be scalars or would have been broadcast to vectors matching the
+  // result type. For indices that are vectors, there are two options:
+  //   * for non-trailing indices, all elements are identical (contiguous
+  //     loads are identified by looking for non-trailing indices that are
+  //     invariant with respect to the corresponding linalg.generic), or
+  //   * for trailing indices, the index vector will contain values with stride
+  //     one, but for `vector.transfer_read` only the first (i.e. 0th) index is
+  //     needed.
+  // This means that
+  //   * for scalar indices - just re-use it,
+  //   * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
+  //     (0th) element and use that.
+  for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
+    auto idx = bvm.lookup(extractOp.getIndices()[i]);
+    if (idx.getType().isIndex()) {
+      transferReadIdxs.push_back(idx);
+      continue;
+    }
+
+    auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
+        loc, VectorType::get({resTrailingDim}, rewriter.getIndexType()),
+        bvm.lookup(extractOp.getIndices()[i]));
+    transferReadIdxs.push_back(
+        rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
+  }
+
+  // `tensor.extract_element` is always in-bounds, hence the following holds.
+  SmallVector<bool> inBounds(resultType.getRank(), true);
 
-  // Generate the gather load
-  Operation *gatherOp = rewriter.create<vector::GatherOp>(
-      loc, resultType, extractOp.getTensor(), baseIndices, offset,
-      maskConstantOp, passThruConstantOp);
-  gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
+  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
+      loc, resultType, extractOp.getTensor(), transferReadIdxs, inBounds);
 
-  return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
+  LDBG("Vectorised as contiguous load: " << extractOp);
+  return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
 }
 
 /// Emit reduction operations if the shapes of the value to reduce is different
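For reference, here is a hedged sketch of the two lowerings this patch distinguishes (illustrative IR only; the function names, shapes, and operand values are assumptions, not output from the commit's tests). The contiguous path emits a single `vector.transfer_read` anchored at scalar indices, while the gather path (previously the only option) fetches every element through an explicit offset vector.

// Contiguous path (new): one in-bounds transfer_read at the scalar base indices.
func.func @lowered_contiguous(%src: tensor<3x3x3xf32>) -> vector<1x1x3xf32> {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0.0 : f32
  %read = vector.transfer_read %src[%c0, %c0, %c0], %pad
      {in_bounds = [true, true, true]}
      : tensor<3x3x3xf32>, vector<1x1x3xf32>
  return %read : vector<1x1x3xf32>
}

// Gather path (fallback): each element is addressed via the %offsets vector.
func.func @lowered_gather(%src: tensor<3x3x3xf32>,
                          %offsets: vector<1x1x3xindex>) -> vector<1x1x3xf32> {
  %c0 = arith.constant 0 : index
  %mask = arith.constant dense<true> : vector<1x1x3xi1>
  %pass = arith.constant dense<0.0> : vector<1x1x3xf32>
  %g = vector.gather %src[%c0, %c0, %c0] [%offsets], %mask, %pass
      : tensor<3x3x3xf32>, vector<1x1x3xindex>, vector<1x1x3xi1>, vector<1x1x3xf32>
        into vector<1x1x3xf32>
  return %g : vector<1x1x3xf32>
}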