#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <cstdint>
+#include <optional>

using namespace mlir;

@@ -102,6 +103,23 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
  return newMask;
}

+/// Returns the front padding size, i.e. the number of subbyte elements that
+/// precede the first accessed element within the first emulated byte, if it
+/// can be folded to a compile-time constant, and std::nullopt otherwise.
+/// Byte-aligned accesses have no front padding.
+static std::optional<int64_t>
+getFrontPaddingSize(ConversionPatternRewriter &rewriter, Location loc,
+                    const memref::LinearizedMemRefInfo linearizedInfo,
+                    bool isUnalignedEmulation) {
+  if (!isUnalignedEmulation)
+    return 0;
+  auto foldedFrontPaddingSize = getValueOrCreateConstantIndexOp(
+      rewriter, loc, linearizedInfo.frontPaddingSize);
+  // Try to fold the front padding size into a constant.
+  if (auto frontPadding = dyn_cast_or_null<arith::ConstantIndexOp>(
+          foldedFrontPaddingSize.getDefiningOp())) {
+    return frontPadding.value();
+  }
+  return std::nullopt;
+}
+
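For illustration (a sketch, not part of the patch): with i2 elements emulated
on i8 storage, scale = dstBits / srcBits = 4, and the front padding is the
linearized subbyte index modulo scale. A vector<3xi2> access into
memref<3x3xi2> at [%c1, %c0] linearizes to i2 index 1 * 3 + 0 = 3, so the
padding folds to a constant and the helper returns 3 (SSA name illustrative):

    %frontPadding = arith.constant 3 : index

With dynamic indices the size does not fold to an arith.constant, the helper
returns std::nullopt, and the patterns below reject the op.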
namespace {

//===----------------------------------------------------------------------===//
@@ -142,29 +160,66 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
    //   vector<4xi8>

    auto origElements = op.getValueToStore().getType().getNumElements();
-    if (origElements % scale != 0)
-      return failure();
+
+    // If the size of the vector we are storing is not byte-aligned, extra
+    // handling is needed.
+    bool isUnalignedEmulation = origElements % scale != 0;

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());

    OpFoldResult linearizedIndices;
-    std::tie(std::ignore, linearizedIndices) =
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

-    auto numElements = origElements / scale;
-    auto bitCast = rewriter.create<vector::BitCastOp>(
-        loc, VectorType::get(numElements, newElementType),
-        op.getValueToStore());
+    auto foldedFrontPaddingSize = getFrontPaddingSize(
+        rewriter, loc, linearizedInfo, isUnalignedEmulation);

-    rewriter.replaceOpWithNewOp<vector::StoreOp>(
-        op, bitCast.getResult(), adaptor.getBase(),
-        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+    if (!foldedFrontPaddingSize) {
+      // Unimplemented case: dynamic front padding size.
+      return failure();
+    }
+
+    // Round the total number of subbyte elements (front padding + payload)
+    // up to a whole number of bytes.
+    auto numElements =
+        (*foldedFrontPaddingSize + origElements + scale - 1) / scale;
+    auto newVectorType = VectorType::get(numElements, newElementType);
+
+    if (isUnalignedEmulation) {
+      auto insertedVectorType =
+          VectorType::get(numElements * scale, oldElementType);
+
+      auto linearizedIndicesValue =
+          getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+      // Load the destination bytes first so the subbyte elements surrounding
+      // the stored slice are preserved (read-modify-write).
+      auto passThru =
+          rewriter.create<vector::LoadOp>(loc, newVectorType, adaptor.getBase(),
+                                          ValueRange{linearizedIndicesValue});
+      auto bitcastedPassThru =
+          rewriter.create<vector::BitCastOp>(loc, insertedVectorType, passThru);
+
+      // Insert the value to store into the loaded vector at the front padding
+      // offset.
+      auto insertStridedSlice = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, insertedVectorType, op.getValueToStore(), bitcastedPassThru,
+          rewriter.getI64ArrayAttr({*foldedFrontPaddingSize}),
+          rewriter.getI64ArrayAttr({1}));
+      // Bit-cast back to the emulated byte vector type before storing.
+      auto bitCast = rewriter.create<vector::BitCastOp>(loc, newVectorType,
+                                                        insertStridedSlice);
+
+      rewriter.replaceOpWithNewOp<vector::StoreOp>(
+          op, bitCast.getResult(), adaptor.getBase(), linearizedIndicesValue);
+    } else {
+      auto bitCast = rewriter.create<vector::BitCastOp>(loc, newVectorType,
+                                                        op.getValueToStore());
+      rewriter.replaceOpWithNewOp<vector::StoreOp>(
+          op, bitCast.getResult(), adaptor.getBase(),
+          getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+    }
    return success();
  }
};
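For reference, a sketch of the IR the unaligned store path produces, assuming
the same i2-on-i8 example (front padding 3, numElements = (3 + 3 + 4 - 1) / 4
= 2; SSA names are illustrative):

    %pt   = vector.load %base[%idx] : memref<3xi8>, vector<2xi8>
    %ptI2 = vector.bitcast %pt : vector<2xi8> to vector<8xi2>
    %ins  = vector.insert_strided_slice %val, %ptI2
            {offsets = [3], strides = [1]} : vector<3xi2> into vector<8xi2>
    %st   = vector.bitcast %ins : vector<8xi2> to vector<2xi8>
    vector.store %st, %base[%idx] : memref<3xi8>, vector<2xi8>

The load-then-store is needed because neighboring i2 elements share bytes with
the stored slice; note that this read-modify-write is not atomic.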
@@ -294,35 +349,67 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
    //   %1 = vector.load %0[%linear_index] : memref<6xi8>, vector<2xi8>
    //   %2 = vector.bitcast %1 : vector<2xi8> to vector<4xi4>
    //
-    // TODO: Currently, only the even number of elements loading is supported.
-    // To deal with the odd number of elements, one has to extract the
-    // subvector at the proper offset after bit-casting.
+    // There are cases where the number of elements to load is not
+    // byte-aligned, for example:
+    //
+    //   %1 = vector.load %0[%c1, %c0] : memref<3x3xi2>, vector<3xi2>
+    //
+    // In such cases we have to load extra bytes and extract the exact slice
+    // in between:
+    //
+    //   %1 = vector.load %0[%c0] : memref<3xi8>, vector<2xi8>
+    //   %2 = vector.bitcast %1 : vector<2xi8> to vector<8xi2>
+    //   %3 = vector.extract_strided_slice %2 {offsets = [3], sizes = [3],
+    //        strides = [1]} : vector<8xi2> to vector<3xi2>
+    //
+    // TODO: Currently the extract_strided_slice's attributes must be known at
+    // compile time as they must be constants.

    auto origElements = op.getVectorType().getNumElements();
-    if (origElements % scale != 0)
-      return failure();
+    bool isUnalignedEmulation = origElements % scale != 0;

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());

    OpFoldResult linearizedIndices;
-    std::tie(std::ignore, linearizedIndices) =
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

-    auto numElements = (origElements + scale - 1) / scale;
+    auto foldedFrontPaddingSize = getFrontPaddingSize(
+        rewriter, loc, linearizedInfo, isUnalignedEmulation);
+
+    if (!foldedFrontPaddingSize) {
+      // Unimplemented case: dynamic front padding size.
+      return failure();
+    }
+
+    // Round the total number of subbyte elements (front padding + payload)
+    // up to a whole number of bytes.
+    auto numElements =
+        (*foldedFrontPaddingSize + origElements + scale - 1) / scale;
+    auto loadVectorType = VectorType::get(numElements, newElementType);
    auto newLoad = rewriter.create<vector::LoadOp>(
-        loc, VectorType::get(numElements, newElementType), adaptor.getBase(),
+        loc, loadVectorType, adaptor.getBase(),
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));

+    auto newBitCastType = VectorType::get(numElements * scale, oldElementType);
    auto bitCast =
-        rewriter.create<vector::BitCastOp>(loc, op.getType(), newLoad);
-
-    rewriter.replaceOp(op, bitCast->getResult(0));
+        rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad);
+
+    if (newBitCastType.getNumElements() != origElements) {
+      // Extract the requested elements, skipping the front padding.
+      auto extractStridedSlice = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, op.getType(), bitCast,
+          rewriter.getI64ArrayAttr({*foldedFrontPaddingSize}),
+          rewriter.getI64ArrayAttr({origElements}),
+          rewriter.getI64ArrayAttr({1}));
+      rewriter.replaceOp(op, extractStridedSlice.getResult());
+    } else {
+      rewriter.replaceOp(op, bitCast->getResult(0));
+    }
    return success();
  }
};
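Note that even when the front padding folds to 0, the extract is still needed
whenever the element count itself is unaligned (illustrative, same
memref<3x3xi2> source): loading vector<3xi2> at [%c0, %c0] gives
numElements = (0 + 3 + 4 - 1) / 4 = 1, hence:

    %1 = vector.load %0[%c0] : memref<3xi8>, vector<1xi8>
    %2 = vector.bitcast %1 : vector<1xi8> to vector<4xi2>
    %3 = vector.extract_strided_slice %2 {offsets = [0], sizes = [3],
         strides = [1]} : vector<4xi2> to vector<3xi2>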
@@ -464,8 +551,8 @@ struct ConvertVectorTransferRead final
    int scale = dstBits / srcBits;

    auto origElements = op.getVectorType().getNumElements();
-    if (origElements % scale != 0)
-      return failure();
+
+    bool isUnalignedEmulation = origElements % scale != 0;

    auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
                                                      adaptor.getPadding());
@@ -474,26 +561,47 @@ struct ConvertVectorTransferRead final
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getSource());

    OpFoldResult linearizedIndices;
-    std::tie(std::ignore, linearizedIndices) =
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

-    auto numElements = (origElements + scale - 1) / scale;
+    auto foldedFrontPaddingSize = getFrontPaddingSize(
+        rewriter, loc, linearizedInfo, isUnalignedEmulation);
+
+    if (!foldedFrontPaddingSize) {
+      // Unimplemented case: dynamic front padding size.
+      return failure();
+    }
+
+    auto numElements =
+        (*foldedFrontPaddingSize + origElements + scale - 1) / scale;
    auto newReadType = VectorType::get(numElements, newElementType);

    auto newRead = rewriter.create<vector::TransferReadOp>(
        loc, newReadType, adaptor.getSource(),
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
        newPadding);

+    auto bitCastType = VectorType::get(numElements * scale, oldElementType);
    auto bitCast =
-        rewriter.create<vector::BitCastOp>(loc, op.getType(), newRead);
+        rewriter.create<vector::BitCastOp>(loc, bitCastType, newRead);
+
+    if (isUnalignedEmulation) {
+      // We only extract the requested portion of the vector, skipping the
+      // front padding.
+      rewriter.replaceOpWithNewOp<vector::ExtractStridedSliceOp>(
+          op, op.getType(), bitCast,
+          rewriter.getI64ArrayAttr({*foldedFrontPaddingSize}),
+          rewriter.getI64ArrayAttr({origElements}),
+          rewriter.getI64ArrayAttr({1}));
+    } else {
+      rewriter.replaceOp(op, bitCast->getResult(0));
+    }

-    rewriter.replaceOp(op, bitCast->getResult(0));
    return success();
  }
};
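Likewise, a sketch of the unaligned transfer_read lowering under the same
i2-on-i8 assumptions (%src and %pad are illustrative names):

    %p = arith.extui %pad : i2 to i8
    %r = vector.transfer_read %src[%c0], %p : memref<3xi8>, vector<2xi8>
    %b = vector.bitcast %r : vector<2xi8> to vector<8xi2>
    %v = vector.extract_strided_slice %b {offsets = [3], sizes = [3],
         strides = [1]} : vector<8xi2> to vector<3xi2>

The padding value is widened to the byte type with arith.extui first so that
the new transfer_read is well-typed in the emulated byte domain.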