@@ -35,18 +35,18 @@ using namespace mlir;
 /// `sourceBits` equals to 4 and `targetBits` equals to 8, the x-th element is
 /// located at (x % 2) * 4. Because there are two elements in one i8, and one
 /// element has 4 bits.
-static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
-                                  int targetBits, OpBuilder &builder) {
+static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
+                                  int sourceBits, int targetBits,
+                                  OpBuilder &builder) {
   assert(targetBits % sourceBits == 0);
-  IntegerType targetType = builder.getIntegerType(targetBits);
-  IntegerAttr idxAttr =
-      builder.getIntegerAttr(targetType, targetBits / sourceBits);
-  auto idx = builder.create<arith::ConstantOp>(loc, targetType, idxAttr);
-  IntegerAttr srcBitsAttr = builder.getIntegerAttr(targetType, sourceBits);
-  auto srcBitsValue =
-      builder.create<arith::ConstantOp>(loc, targetType, srcBitsAttr);
-  auto m = builder.create<arith::RemUIOp>(loc, srcIdx, idx);
-  return builder.create<arith::MulIOp>(loc, targetType, m, srcBitsValue);
+  AffineExpr s0;
+  bindSymbols(builder.getContext(), s0);
+  int scaleFactor = targetBits / sourceBits;
+  OpFoldResult offsetVal = affine::makeComposedFoldedAffineApply(
+      builder, loc, (s0 % scaleFactor) * sourceBits, {srcIdx});
+  Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
+  IntegerType dstType = builder.getIntegerType(targetBits);
+  return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
 }
 
 namespace {
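
To make the new offset computation concrete: `getOffsetForBitwidth` now folds `(s0 % scaleFactor) * sourceBits` into an affine apply instead of emitting remui/muli by hand. Below is a minimal standalone C++ sketch of that arithmetic; `bitOffsetWithinWord` and the sample values are illustrative only, not part of the patch.

#include <cassert>
#include <cstdio>

// Illustrative only: the bit offset of the x-th narrow element inside its
// containing wide word, i.e. (x % scaleFactor) * sourceBits, the same
// expression the folded affine apply materializes.
static int bitOffsetWithinWord(int linearIdx, int sourceBits, int targetBits) {
  assert(targetBits % sourceBits == 0);
  int scaleFactor = targetBits / sourceBits;
  return (linearIdx % scaleFactor) * sourceBits;
}

int main() {
  // Two i4 elements per i8: element 0 sits at bit 0, element 1 at bit 4.
  for (int x = 0; x < 4; ++x)
    std::printf("i4 element %d -> bit offset %d\n", x,
                bitOffsetWithinWord(x, /*sourceBits=*/4, /*targetBits=*/8));
}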
@@ -61,15 +61,43 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
   LogicalResult
   matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Type newTy = getTypeConverter()->convertType(op.getType());
-    if (!newTy) {
+    auto currentType = op.getMemref().getType().cast<MemRefType>();
+    auto newResultType =
+        getTypeConverter()->convertType(op.getType()).dyn_cast<MemRefType>();
+    if (!newResultType) {
       return rewriter.notifyMatchFailure(
           op->getLoc(),
           llvm::formatv("failed to convert memref type: {0}", op.getType()));
     }
 
+    // Special case zero-rank memrefs.
+    if (currentType.getRank() == 0) {
+      rewriter.replaceOpWithNewOp<memref::AllocOp>(
+          op, newResultType, ValueRange{}, adaptor.getSymbolOperands(),
+          adaptor.getAlignmentAttr());
+      return success();
+    }
+
+    Location loc = op.getLoc();
+    OpFoldResult zero = rewriter.getIndexAttr(0);
+    SmallVector<OpFoldResult> indices(currentType.getRank(), zero);
+
+    // Get linearized type.
+    int srcBits = currentType.getElementType().getIntOrFloatBitWidth();
+    int dstBits = newResultType.getElementType().getIntOrFloatBitWidth();
+    SmallVector<OpFoldResult> sizes = op.getMixedSizes();
+
+    memref::LinearizedMemRefInfo linearizedMemRefInfo =
+        memref::getLinearizedMemRefOffsetAndSize(
+            rewriter, loc, srcBits, dstBits, /*offset=*/zero, sizes);
+    SmallVector<Value> dynamicLinearizedSize;
+    if (!newResultType.hasStaticShape()) {
+      dynamicLinearizedSize.push_back(getValueOrCreateConstantIndexOp(
+          rewriter, loc, linearizedMemRefInfo.linearizedSize));
+    }
+
     rewriter.replaceOpWithNewOp<memref::AllocOp>(
-        op, newTy, adaptor.getDynamicSizes(), adaptor.getSymbolOperands(),
+        op, newResultType, dynamicLinearizedSize, adaptor.getSymbolOperands(),
         adaptor.getAlignmentAttr());
     return success();
   }
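
For readers checking the size math in the rewritten ConvertMemRefAlloc: for a contiguous memref, the call to `getLinearizedMemRefOffsetAndSize` boils down to a ceil-division of the element count by the packing factor. A small standalone sketch under that assumption follows; `linearizedAllocSize` is a hypothetical helper used only for illustration, not code from this patch.

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

// Illustrative only: number of wide (dstBits) words needed to back a
// contiguous memref of narrow (srcBits) elements with the given sizes.
static int64_t linearizedAllocSize(const std::vector<int64_t> &sizes,
                                   int srcBits, int dstBits) {
  assert(dstBits % srcBits == 0);
  int64_t numElements = 1;
  for (int64_t s : sizes)
    numElements *= s;
  int64_t scale = dstBits / srcBits;
  return (numElements + scale - 1) / scale; // ceilDiv(numElements, scale)
}

int main() {
  // memref<5x4xi4> holds 20 i4 values, so the emulated alloc needs 10 i8s.
  std::printf("%lld\n", static_cast<long long>(linearizedAllocSize(
                            {5, 4}, /*srcBits=*/4, /*dstBits=*/8)));
}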
@@ -109,73 +137,68 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
   LogicalResult
   matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Type newTy = getTypeConverter()->convertType(op.getMemRefType());
-    if (!newTy) {
-      return rewriter.notifyMatchFailure(
-          op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
-                                      op.getMemRefType()));
-    }
-
-    if (op.getMemRefType() == newTy)
-      return failure();
-
-    auto loc = op.getLoc();
-    auto sourceType = cast<MemRefType>(adaptor.getMemref().getType());
-    unsigned sourceRank = sourceType.getRank();
-    SmallVector<Value> indices = adaptor.getIndices();
-    assert(indices.size() == sourceRank);
-
-    auto srcElementType = sourceType.getElementType();
+    auto convertedType = adaptor.getMemref().getType().cast<MemRefType>();
+    auto convertedElementType = convertedType.getElementType();
     auto oldElementType = op.getMemRefType().getElementType();
     int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = srcElementType.getIntOrFloatBitWidth();
+    int dstBits = convertedElementType.getIntOrFloatBitWidth();
     if (dstBits % srcBits != 0) {
       return rewriter.notifyMatchFailure(
           op, "only dstBits % srcBits == 0 supported");
     }
 
-    auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
-        loc, adaptor.getMemref());
-
-    Value newLoad, lastIdx;
-    if (sourceRank == 0) {
-      newLoad = rewriter.create<memref::LoadOp>(
-          loc, srcElementType, adaptor.getMemref(), adaptor.getIndices());
-
-      lastIdx = stridedMetadata.getOffset();
+    Location loc = op.getLoc();
+    // Special case 0-rank memref loads.
+    Value bitsLoad;
+    if (convertedType.getRank() == 0) {
+      bitsLoad = rewriter.create<memref::LoadOp>(loc, adaptor.getMemref(),
+                                                 ValueRange{});
     } else {
-      auto [reinterpret, linearizedOffset] =
-          memref::getLinearizeMemRefAndOffset(loc, sourceType, srcBits, dstBits,
-                                              adaptor.getIndices(),
-                                              stridedMetadata, rewriter);
-
-      newLoad = rewriter.create<memref::LoadOp>(loc, srcElementType,
-                                                reinterpret, linearizedOffset);
-
-      lastIdx = adaptor.getIndices().back();
+      SmallVector<OpFoldResult> indices =
+          getAsOpFoldResult(adaptor.getIndices());
+
+      auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
+          loc, op.getMemRef());
+
+      // Linearize the indices of the original load instruction. Do not account
+      // for the scaling yet. This will be accounted for later.
+      OpFoldResult linearizedIndices;
+      std::tie(std::ignore, linearizedIndices) =
+          memref::getLinearizedMemRefOffsetAndSize(
+              rewriter, loc, srcBits, srcBits,
+              stridedMetadata.getConstifiedMixedOffset(),
+              stridedMetadata.getConstifiedMixedSizes(),
+              stridedMetadata.getConstifiedMixedStrides(), indices);
+
+      AffineExpr s0;
+      bindSymbols(rewriter.getContext(), s0);
+      int64_t scaler = dstBits / srcBits;
+      OpFoldResult scaledLinearizedIndices =
+          affine::makeComposedFoldedAffineApply(
+              rewriter, loc, s0.floorDiv(scaler), {linearizedIndices});
+      Value newLoad = rewriter.create<memref::LoadOp>(
+          loc, adaptor.getMemref(),
+          getValueOrCreateConstantIndexOp(rewriter, loc,
+                                          scaledLinearizedIndices));
+
+      // Get the offset and shift the bits to the rightmost.
+      // Note, currently only the big-endian is supported.
+      Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices,
+                                                  srcBits, dstBits, rewriter);
+      bitsLoad = rewriter.create<arith::ShRSIOp>(loc, newLoad, bitwidthOffset);
     }
 
-    // Get the offset and shift the bits to the rightmost.
-    // Note, currently only the big-endian is supported.
-    auto castLastIdx =
-        rewriter.create<arith::IndexCastUIOp>(loc, srcElementType, lastIdx);
-
-    Value BitwidthOffset =
-        getOffsetForBitwidth(loc, castLastIdx, srcBits, dstBits, rewriter);
-    auto bitsLoad =
-        rewriter.create<arith::ShRSIOp>(loc, newLoad, BitwidthOffset);
-
     // Get the corresponding bits. If the arith computation bitwidth equals
     // to the emulated bitwidth, we apply a mask to extract the low bits.
     // It is not clear if this case actually happens in practice, but we keep
     // the operations just in case. Otherwise, if the arith computation bitwidth
     // is different from the emulated bitwidth we truncate the result.
     Operation *result;
     auto resultTy = getTypeConverter()->convertType(oldElementType);
-    if (resultTy == srcElementType) {
+    if (resultTy == convertedElementType) {
       auto mask = rewriter.create<arith::ConstantOp>(
-          loc, srcElementType,
-          rewriter.getIntegerAttr(srcElementType, (1 << srcBits) - 1));
+          loc, convertedElementType,
+          rewriter.getIntegerAttr(convertedElementType, (1 << srcBits) - 1));
 
       result = rewriter.create<arith::AndIOp>(loc, bitsLoad, mask);
     } else {
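
The rewritten load is easier to follow as plain value-level arithmetic: pick the wide word at linearizedIndex / scaler, shift it right by the bit offset from `getOffsetForBitwidth`, then mask (or truncate) down to srcBits. Below is a standalone sketch of that behavior for i4-in-i8, assuming a contiguous buffer with element 0 packed in the low bits; `emulatedLoadI4` is a hypothetical name used only for illustration.

#include <cstdint>
#include <cstdio>
#include <vector>

// Illustrative only: the value the emulated load chain (memref.load +
// arith.shrsi + arith.andi) is expected to produce for an i4 element
// stored inside an i8 buffer.
static int emulatedLoadI4(const std::vector<uint8_t> &words,
                          int64_t linearIdx) {
  const int srcBits = 4, dstBits = 8;
  const int64_t scaler = dstBits / srcBits;
  uint8_t word = words[linearIdx / scaler];   // scaled linearized index
  int shift = (linearIdx % scaler) * srcBits; // bitwidth offset
  return (word >> shift) & ((1 << srcBits) - 1);
}

int main() {
  // Two i8 words packing the i4 values 1, 2, 3, 4 (element 0 in the low bits).
  std::vector<uint8_t> words = {0x21, 0x43};
  for (int64_t i = 0; i < 4; ++i)
    std::printf("element %lld = %d\n", static_cast<long long>(i),
                emulatedLoadI4(words, i));
}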
@@ -200,6 +223,25 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
   patterns
       .add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefAssumeAlignment>(
           typeConverter, patterns.getContext());
+  memref::populateResolveExtractStridedMetadataPatterns(patterns);
+}
+
+static SmallVector<int64_t> getLinearizedShape(MemRefType ty, int srcBits,
+                                               int dstBits) {
+  if (ty.getRank() == 0)
+    return {};
+
+  int64_t linearizedShape = 1;
+  for (auto shape : ty.getShape()) {
+    if (shape == ShapedType::kDynamic)
+      return {ShapedType::kDynamic};
+    linearizedShape *= shape;
+  }
+  int scale = dstBits / srcBits;
+  // Scale the size to ceilDiv(linearizedShape, scale)
+  // to accommodate all the values.
+  linearizedShape = (linearizedShape + scale - 1) / scale;
+  return {linearizedShape};
 }
 
 void memref::populateMemRefNarrowTypeEmulationConversions(
@@ -215,11 +257,26 @@ void memref::populateMemRefNarrowTypeEmulationConversions(
         if (width >= loadStoreWidth)
           return ty;
 
+        // Currently only handle the case where the innermost stride is 1.
+        SmallVector<int64_t> strides;
+        int64_t offset;
+        if (failed(getStridesAndOffset(ty, strides, offset)))
+          return std::nullopt;
+        if (!strides.empty() && strides.back() != 1)
+          return std::nullopt;
+
         auto newElemTy = IntegerType::get(ty.getContext(), loadStoreWidth,
                                           intTy.getSignedness());
         if (!newElemTy)
          return std::nullopt;
 
-        return ty.cloneWith(std::nullopt, newElemTy);
+        StridedLayoutAttr layoutAttr;
+        if (offset != 0) {
+          layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
+                                              ArrayRef<int64_t>{1});
+        }
+
+        return MemRefType::get(getLinearizedShape(ty, width, loadStoreWidth),
+                               newElemTy, layoutAttr, ty.getMemorySpace());
       });
 }
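
For reference, a few conversions the updated type converter should now produce, assuming a load/store bitwidth of 8; these example types are worked out from the new logic above, not taken from the patch's tests:

  memref<3x5xi4>                         -> memref<8xi8>
  memref<?x?xi4>                         -> memref<?xi8>
  memref<5xi4, strided<[1], offset: 8>>  -> memref<3xi8, strided<[1], offset: 8>>
  memref<4xi4, strided<[2]>>             -> not converted (innermost stride != 1)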