15
15
#include " mlir/Dialect/MemRef/IR/MemRef.h"
16
16
#include " mlir/Dialect/MemRef/Transforms/Passes.h"
17
17
#include " mlir/Dialect/MemRef/Transforms/Transforms.h"
18
+ #include " mlir/Dialect/MemRef/Utils/MemRefUtils.h"
18
19
#include " mlir/Dialect/Vector/IR/VectorOps.h"
19
20
#include " mlir/Transforms/DialectConversion.h"
20
21
#include " llvm/Support/FormatVariadic.h"
@@ -27,102 +28,6 @@ using namespace mlir;
27
28
// Utility functions
28
29
// ===----------------------------------------------------------------------===//
29
30
30
- // / The emulation only works on 1D memref types.
31
- // / To make this work on N-D memref, we need to linearize the offset.
32
- // /
33
- // / For example, to emulate i4 to i8, the following op:
34
- // /
35
- // / %0 = memref.load %arg0[%v0, %v1] :
36
- // / memref<?x?xi4, strided<[?, ?], offset: ?>>
37
- // /
38
- // / can be replaced with
39
- // /
40
- // / %b, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %0
41
- // /
42
- // / %linearized_offset = %v0 * %stride#0 + %v1 * %stride#1
43
- // / %linearized_size = %size0 * %size1
44
- // / %scaled_linear_offset = %linearized_offset / 8 * 4
45
- // / %scaled_base_offset = %offset / 8 * 4
46
- // /
47
- // / %linearized = memref.reinterpret_cast %b, offset = [%scaled_base_offset],
48
- // / sizes = [%linearized_size], strides = [%stride#1]
49
- // /
50
- // / %new_load = memref.load %linearized[%scaled_linear_offset] :
51
- // / memref<?xi8, strided<[?], offset: ?>>
52
-
53
- static Value
54
- linearizeMemrefLoad (Location loc, MemRefType sourceType, int srcBits,
55
- int dstBits, SmallVector<Value> indices,
56
- memref::ExtractStridedMetadataOp stridedMetadata,
57
- OpBuilder &builder) {
58
- auto srcElementType = sourceType.getElementType ();
59
- unsigned sourceRank = indices.size ();
60
-
61
- Value baseBuffer = stridedMetadata.getBaseBuffer ();
62
- SmallVector<Value> baseSizes = stridedMetadata.getSizes ();
63
- SmallVector<Value> baseStrides = stridedMetadata.getStrides ();
64
- Value baseOffset = stridedMetadata.getOffset ();
65
- assert (indices.size () == baseStrides.size ());
66
-
67
- // Create the affine symbols and values for linearization.
68
- SmallVector<AffineExpr> symbols (2 * sourceRank + 2 );
69
- bindSymbolsList (builder.getContext (), MutableArrayRef{symbols});
70
- symbols[0 ] = builder.getAffineSymbolExpr (0 );
71
- AffineExpr addMulMap = symbols.front ();
72
- AffineExpr mulMap = symbols.front ();
73
-
74
- SmallVector<OpFoldResult> offsetValues (2 * sourceRank + 2 );
75
- offsetValues[0 ] = builder.getIndexAttr (0 );
76
- SmallVector<OpFoldResult> sizeValues (sourceRank + 1 );
77
- sizeValues[0 ] = builder.getIndexAttr (1 );
78
-
79
- for (unsigned i = 0 ; i < sourceRank; ++i) {
80
- unsigned offsetIdx = 2 * i + 1 ;
81
- addMulMap = addMulMap + symbols[offsetIdx] * symbols[offsetIdx + 1 ];
82
- offsetValues[offsetIdx] = indices[i];
83
- offsetValues[offsetIdx + 1 ] = baseStrides[i];
84
-
85
- unsigned sizeIdx = i + 1 ;
86
- mulMap = mulMap * symbols[sizeIdx];
87
- sizeValues[sizeIdx] = baseSizes[i];
88
- }
89
-
90
- // Adjust linearizedOffset by the scale factor (dstBits / srcBits).
91
- OpFoldResult scaler = builder.getIndexAttr (dstBits / srcBits);
92
- AffineExpr scaledAddMulMap = addMulMap.floorDiv (symbols.back ());
93
- offsetValues.back () = scaler;
94
-
95
- OpFoldResult linearizedOffset = affine::makeComposedFoldedAffineApply (
96
- builder, loc, scaledAddMulMap, offsetValues);
97
- OpFoldResult linearizedSize =
98
- affine::makeComposedFoldedAffineApply (builder, loc, mulMap, sizeValues);
99
-
100
- // Adjust baseOffset by the scale factor (dstBits / srcBits).
101
- AffineExpr s0, s1;
102
- bindSymbols (builder.getContext (), s0, s1);
103
- OpFoldResult adjustBaseOffset = affine::makeComposedFoldedAffineApply (
104
- builder, loc, s0.floorDiv (s1), {baseOffset, scaler});
105
-
106
- // Flatten n-D MemRef to 1-D MemRef.
107
- auto layoutAttr = StridedLayoutAttr::get (
108
- sourceType.getContext (), ShapedType::kDynamic , {ShapedType::kDynamic });
109
- int64_t staticShape = sourceType.hasStaticShape ()
110
- ? sourceType.getNumElements ()
111
- : ShapedType::kDynamic ;
112
- auto flattenMemrefType = MemRefType::get (
113
- staticShape, srcElementType, layoutAttr, sourceType.getMemorySpace ());
114
-
115
- auto reinterpret = builder.create <memref::ReinterpretCastOp>(
116
- loc, flattenMemrefType, baseBuffer,
117
- getValueOrCreateConstantIndexOp (builder, loc, adjustBaseOffset),
118
- getValueOrCreateConstantIndexOp (builder, loc, linearizedSize),
119
- baseStrides.back ());
120
-
121
- return builder.create <memref::LoadOp>(
122
- loc, srcElementType, reinterpret.getResult (),
123
- getValueOrCreateConstantIndexOp (builder, loc, linearizedOffset));
124
- }
125
-
126
31
/// When data is loaded/stored in `targetBits` granularity, but is used in
127
32
/// `sourceBits` granularity (`sourceBits` < `targetBits`), the `targetBits` is
128
33
/// treated as an array of elements of width `sourceBits`.
@@ -239,8 +144,13 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
239
144
240
145
lastIdx = stridedMetadata.getOffset ();
241
146
} else {
242
- newLoad = linearizeMemrefLoad (loc, sourceType, srcBits, dstBits, indices,
243
- stridedMetadata, rewriter);
147
+ auto [reinterpret, linearizedOffset] =
148
+ memref::getLinearizeMemRefAndOffset (loc, sourceType, srcBits, dstBits,
149
+ adaptor.getIndices (),
150
+ stridedMetadata, rewriter);
151
+
152
+ newLoad = rewriter.create <memref::LoadOp>(loc, srcElementType,
153
+ reinterpret, linearizedOffset);
244
154
245
155
lastIdx = adaptor.getIndices ().back ();
246
156
}
0 commit comments