@@ -715,150 +715,9 @@ MemRefType MemRefType::canonicalizeStridedLayout() {
715
715
return MemRefType::Builder (*this ).setLayout ({});
716
716
}
717
717
718
- // Fallback cases for terminal dim/sym/cst that are not part of a binary op (
719
- // i.e. single term). Accumulate the AffineExpr into the existing one.
720
- static void extractStridesFromTerm (AffineExpr e,
721
- AffineExpr multiplicativeFactor,
722
- MutableArrayRef<AffineExpr> strides,
723
- AffineExpr &offset) {
724
- if (auto dim = dyn_cast<AffineDimExpr>(e))
725
- strides[dim.getPosition ()] =
726
- strides[dim.getPosition ()] + multiplicativeFactor;
727
- else
728
- offset = offset + e * multiplicativeFactor;
729
- }
730
-
731
- // / Takes a single AffineExpr `e` and populates the `strides` array with the
732
- // / strides expressions for each dim position.
733
- // / The convention is that the strides for dimensions d0, .. dn appear in
734
- // / order to make indexing intuitive into the result.
735
- static LogicalResult extractStrides (AffineExpr e,
736
- AffineExpr multiplicativeFactor,
737
- MutableArrayRef<AffineExpr> strides,
738
- AffineExpr &offset) {
739
- auto bin = dyn_cast<AffineBinaryOpExpr>(e);
740
- if (!bin) {
741
- extractStridesFromTerm (e, multiplicativeFactor, strides, offset);
742
- return success ();
743
- }
744
-
745
- if (bin.getKind () == AffineExprKind::CeilDiv ||
746
- bin.getKind () == AffineExprKind::FloorDiv ||
747
- bin.getKind () == AffineExprKind::Mod)
748
- return failure ();
749
-
750
- if (bin.getKind () == AffineExprKind::Mul) {
751
- auto dim = dyn_cast<AffineDimExpr>(bin.getLHS ());
752
- if (dim) {
753
- strides[dim.getPosition ()] =
754
- strides[dim.getPosition ()] + bin.getRHS () * multiplicativeFactor;
755
- return success ();
756
- }
757
- // LHS and RHS may both contain complex expressions of dims. Try one path
758
- // and if it fails try the other. This is guaranteed to succeed because
759
- // only one path may have a `dim`, otherwise this is not an AffineExpr in
760
- // the first place.
761
- if (bin.getLHS ().isSymbolicOrConstant ())
762
- return extractStrides (bin.getRHS (), multiplicativeFactor * bin.getLHS (),
763
- strides, offset);
764
- return extractStrides (bin.getLHS (), multiplicativeFactor * bin.getRHS (),
765
- strides, offset);
766
- }
767
-
768
- if (bin.getKind () == AffineExprKind::Add) {
769
- auto res1 =
770
- extractStrides (bin.getLHS (), multiplicativeFactor, strides, offset);
771
- auto res2 =
772
- extractStrides (bin.getRHS (), multiplicativeFactor, strides, offset);
773
- return success (succeeded (res1) && succeeded (res2));
774
- }
775
-
776
- llvm_unreachable (" unexpected binary operation" );
777
- }
778
-
779
- // / A stride specification is a list of integer values that are either static
780
- // / or dynamic (encoded with ShapedType::kDynamic). Strides encode
781
- // / the distance in the number of elements between successive entries along a
782
- // / particular dimension.
783
- // /
784
- // / For example, `memref<42x16xf32, (64 * d0 + d1)>` specifies a view into a
785
- // / non-contiguous memory region of `42` by `16` `f32` elements in which the
786
- // / distance between two consecutive elements along the outer dimension is `1`
787
- // / and the distance between two consecutive elements along the inner dimension
788
- // / is `64`.
789
- // /
790
- // / The convention is that the strides for dimensions d0, .. dn appear in
791
- // / order to make indexing intuitive into the result.
792
- static LogicalResult getStridesAndOffset (MemRefType t,
793
- SmallVectorImpl<AffineExpr> &strides,
794
- AffineExpr &offset) {
795
- AffineMap m = t.getLayout ().getAffineMap ();
796
-
797
- if (m.getNumResults () != 1 && !m.isIdentity ())
798
- return failure ();
799
-
800
- auto zero = getAffineConstantExpr (0 , t.getContext ());
801
- auto one = getAffineConstantExpr (1 , t.getContext ());
802
- offset = zero;
803
- strides.assign (t.getRank (), zero);
804
-
805
- // Canonical case for empty map.
806
- if (m.isIdentity ()) {
807
- // 0-D corner case, offset is already 0.
808
- if (t.getRank () == 0 )
809
- return success ();
810
- auto stridedExpr =
811
- makeCanonicalStridedLayoutExpr (t.getShape (), t.getContext ());
812
- if (succeeded (extractStrides (stridedExpr, one, strides, offset)))
813
- return success ();
814
- assert (false && " unexpected failure: extract strides in canonical layout" );
815
- }
816
-
817
- // Non-canonical case requires more work.
818
- auto stridedExpr =
819
- simplifyAffineExpr (m.getResult (0 ), m.getNumDims (), m.getNumSymbols ());
820
- if (failed (extractStrides (stridedExpr, one, strides, offset))) {
821
- offset = AffineExpr ();
822
- strides.clear ();
823
- return failure ();
824
- }
825
-
826
- // Simplify results to allow folding to constants and simple checks.
827
- unsigned numDims = m.getNumDims ();
828
- unsigned numSymbols = m.getNumSymbols ();
829
- offset = simplifyAffineExpr (offset, numDims, numSymbols);
830
- for (auto &stride : strides)
831
- stride = simplifyAffineExpr (stride, numDims, numSymbols);
832
-
833
- return success ();
834
- }
835
-
836
718
LogicalResult MemRefType::getStridesAndOffset (SmallVectorImpl<int64_t > &strides,
837
719
int64_t &offset) {
838
- // Happy path: the type uses the strided layout directly.
839
- if (auto strided = llvm::dyn_cast<StridedLayoutAttr>(getLayout ())) {
840
- llvm::append_range (strides, strided.getStrides ());
841
- offset = strided.getOffset ();
842
- return success ();
843
- }
844
-
845
- // Otherwise, defer to the affine fallback as layouts are supposed to be
846
- // convertible to affine maps.
847
- AffineExpr offsetExpr;
848
- SmallVector<AffineExpr, 4 > strideExprs;
849
- if (failed (::getStridesAndOffset (*this , strideExprs, offsetExpr)))
850
- return failure ();
851
- if (auto cst = llvm::dyn_cast<AffineConstantExpr>(offsetExpr))
852
- offset = cst.getValue ();
853
- else
854
- offset = ShapedType::kDynamic ;
855
- for (auto e : strideExprs) {
856
- if (auto c = llvm::dyn_cast<AffineConstantExpr>(e))
857
- strides.push_back (c.getValue ());
858
- else
859
- strides.push_back (ShapedType::kDynamic );
860
- }
861
- return success ();
720
+ return getLayout ().getStridesAndOffset (getShape (), strides, offset);
862
721
}
863
722
864
723
std::pair<SmallVector<int64_t >, int64_t > MemRefType::getStridesAndOffset () {
0 commit comments