#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
+ #include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <cstdint>
@@ -143,13 +144,10 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
/// Extracts 1-D subvector from a 1-D vector. It is a wrapper function for
/// emitting `vector.extract_strided_slice`.
static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
-                                         VectorType extractType, Value source,
-                                         int64_t frontOffset,
+                                         Value source, int64_t frontOffset,
                                        int64_t subvecSize) {
  auto vectorType = cast<VectorType>(source.getType());
-   assert((vectorType.getRank() == 1 && extractType.getRank() == 1) &&
-          "expected 1-D source and destination types");
-   (void)vectorType;
+   assert(vectorType.getRank() == 1 && "expected 1-D source types");
  assert(frontOffset + subvecSize <= vectorType.getNumElements() &&
         "subvector out of bounds");

@@ -160,9 +158,12 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
  auto offsets = rewriter.getI64ArrayAttr({frontOffset});
  auto sizes = rewriter.getI64ArrayAttr({subvecSize});
  auto strides = rewriter.getI64ArrayAttr({1});
+
+   auto resultVectorType =
+       VectorType::get({subvecSize}, vectorType.getElementType());
  return rewriter
-       .create<vector::ExtractStridedSliceOp>(loc, extractType, source, offsets,
-                                              sizes, strides)
+       .create<vector::ExtractStridedSliceOp>(loc, resultVectorType, source,
+                                              offsets, sizes, strides)
      ->getResult(0);
}

@@ -171,12 +172,10 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
/// `vector.insert_strided_slice`.
static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
                                       Value src, Value dest, int64_t offset) {
-   auto srcType = cast<VectorType>(src.getType());
-   auto destType = cast<VectorType>(dest.getType());
+   [[maybe_unused]] auto srcType = cast<VectorType>(src.getType());
+   [[maybe_unused]] auto destType = cast<VectorType>(dest.getType());
  assert(srcType.getRank() == 1 && destType.getRank() == 1 &&
         "expected source and dest to be vector type");
-   (void)srcType;
-   (void)destType;
  auto offsets = rewriter.getI64ArrayAttr({offset});
  auto strides = rewriter.getI64ArrayAttr({1});
  return rewriter.create<vector::InsertStridedSliceOp>(loc, dest.getType(), src,
@@ -243,6 +242,63 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
      newLoad);
}

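+ /// Store a sub-byte-sized vector `value` to `memref` at `index` without
+ /// atomics: the value is bitcast to the memref's byte-sized element type and
+ /// written with a regular vector.store. Callers pass a value that spans
+ /// whole bytes.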
+ static void nonAtomicStore(ConversionPatternRewriter &rewriter, Location loc,
+                            Value memref, Value index, Value value) {
+   auto originType = dyn_cast<VectorType>(value.getType());
+   auto memrefElemType = dyn_cast<MemRefType>(memref.getType()).getElementType();
+   auto scale = memrefElemType.getIntOrFloatBitWidth() /
+                originType.getElementType().getIntOrFloatBitWidth();
+   auto storeType =
+       VectorType::get({originType.getNumElements() / scale}, memrefElemType);
+   auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType, value);
+   rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memref, index);
+ }
+
+ /// Atomically store a sub-byte-sized vector `value` to memory at
+ /// `emulatedIndex` within `emulatedMemref`, overwriting only the byte's
+ /// elements selected by `mask`.
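+ /// The read-modify-write is done inside `memref.generic_atomic_rmw` so that
+ /// concurrent stores to the other sub-byte elements packed in the same byte
+ /// are not clobbered.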
+ static Value atomicStore(OpBuilder &rewriter, Location loc,
+                          Value emulatedMemref, Value emulatedIndex,
+                          TypedValue<VectorType> value, Value mask,
+                          int64_t scale) {
+   auto atomicOp = rewriter.create<memref::GenericAtomicRMWOp>(
+       loc, emulatedMemref, ValueRange{emulatedIndex});
+   OpBuilder builder =
+       OpBuilder::atBlockEnd(atomicOp.getBody(), rewriter.getListener());
+   Value origValue = atomicOp.getCurrentValue();
+
+   // i8 -> vector<1xi8>, then bitcast to vector<scale x sub-byte element>.
+   auto oneVectorType = VectorType::get({1}, origValue.getType());
+   auto fromElem = builder.create<vector::FromElementsOp>(loc, oneVectorType,
+                                                          ValueRange{origValue});
+   auto vectorBitCast =
+       builder.create<vector::BitCastOp>(loc, value.getType(), fromElem);
+
+   auto select =
+       builder.create<arith::SelectOp>(loc, mask, value, vectorBitCast);
+   auto bitcast2 = builder.create<vector::BitCastOp>(loc, oneVectorType, select);
+   auto extract = builder.create<vector::ExtractOp>(loc, bitcast2, 0);
+   builder.create<memref::AtomicYieldOp>(loc, extract.getResult());
+   return atomicOp;
+ }
+
+ // Extract a slice of a vector, and insert it into a zero-initialized
+ // byte-wide vector at `byteOffset`.
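+ // For example, with a vector<8xi4> source, `sliceOffset` = 1,
+ // `sliceNumElements` = 1 and `byteOffset` = 1, the result is a vector<2xi4>
+ // equal to [0, source[1]].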
+ static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
+                                   Location loc, TypedValue<VectorType> vector,
+                                   int64_t sliceOffset, int64_t sliceNumElements,
+                                   int64_t byteOffset) {
+   auto vectorElementType = vector.getType().getElementType();
+   assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
+          "vector element must be a valid sub-byte type");
+   auto scale = 8 / vectorElementType.getIntOrFloatBitWidth();
+   auto emptyByteVector = rewriter.create<arith::ConstantOp>(
+       loc, VectorType::get({scale}, vectorElementType),
+       rewriter.getZeroAttr(VectorType::get({scale}, vectorElementType)));
+   auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
+                                               sliceOffset, sliceNumElements);
+   auto inserted = staticallyInsertSubvector(rewriter, loc, extracted,
+                                             emptyByteVector, byteOffset);
+   return inserted;
+ }
+

namespace {

//===----------------------------------------------------------------------===//
@@ -263,7 +319,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {

    auto loc = op.getLoc();
    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
-     Type oldElementType = op.getValueToStore().getType().getElementType();
+     auto valueToStore = op.getValueToStore();
+     Type oldElementType = valueToStore.getType().getElementType();
    Type newElementType = convertedType.getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = newElementType.getIntOrFloatBitWidth();
@@ -287,30 +344,124 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
    // vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
    // vector<4xi8>

-     auto origElements = op.getValueToStore().getType().getNumElements();
-     if (origElements % scale != 0)
-       return failure();
+     auto origElements = valueToStore.getType().getNumElements();
+     bool isUnalignedEmulation = origElements % scale != 0;

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());

    OpFoldResult linearizedIndices;
-     std::tie(std::ignore, linearizedIndices) =
+     memref::LinearizedMemRefInfo linearizedInfo;
+     std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

-     auto numElements = origElements / scale;
-     auto bitCast = rewriter.create<vector::BitCastOp>(
-         loc, VectorType::get(numElements, newElementType),
-         op.getValueToStore());
+     auto foldedIntraVectorOffset =
+         isUnalignedEmulation
+             ? getConstantIntValue(linearizedInfo.intraDataOffset)
+             : 0;
+
+     if (!foldedIntraVectorOffset) {
+       // Unimplemented case: the front padding size is dynamic.
+       return failure();
+     }
+
+     // Conditions under which no atomic RMW stores are needed:
+     // 1. The source vector size is a multiple of the byte size, and
+     // 2. the address of the store is byte aligned.
+     if (!isUnalignedEmulation && *foldedIntraVectorOffset == 0) {
+       auto numElements = origElements / scale;
+       auto bitCast = rewriter.create<vector::BitCastOp>(
+           loc, VectorType::get(numElements, newElementType),
+           op.getValueToStore());
+       rewriter.replaceOpWithNewOp<vector::StoreOp>(
+           op, bitCast.getResult(), adaptor.getBase(),
+           getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+       return llvm::success();
+     }
+
+     Value emulatedMemref = adaptor.getBase();
+     // The index into the target memref we are storing to.
+     Value currentDestIndex =
+         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+     auto constantOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+     auto atomicMaskType = VectorType::get({scale}, rewriter.getI1Type());
+     // The index into the source vector we are currently processing.
+     auto currentSourceIndex = 0;
+
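+     // The unaligned store is emitted in three parts: a partially covered
+     // leading byte written with an atomic read-modify-write, a run of fully
+     // covered bytes written with a regular vector.store, and a partially
+     // covered trailing byte written with another atomic read-modify-write.
+     // For example, storing vector<7xi2> (scale = 4) at an intra-byte offset
+     // of 2 writes 2 elements into the first byte, 4 elements into the second
+     // byte, and 1 element into the last byte.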
+     // 1. Atomic RMW store for the first, partially covered byte.
+     auto frontAtomicStoreElem = (scale - *foldedIntraVectorOffset) % scale;
+     if (frontAtomicStoreElem != 0) {
+       auto frontMaskValues = llvm::SmallVector<bool>(scale, false);
+       if (*foldedIntraVectorOffset + origElements < scale) {
+         std::fill_n(frontMaskValues.begin() + *foldedIntraVectorOffset,
+                     origElements, true);
+         frontAtomicStoreElem = origElements;
+       } else {
+         std::fill_n(frontMaskValues.end() - frontAtomicStoreElem,
+                     *foldedIntraVectorOffset, true);
+       }
+       auto frontMask = rewriter.create<arith::ConstantOp>(
+           loc, DenseElementsAttr::get(atomicMaskType, frontMaskValues));
+
+       currentSourceIndex = scale - (*foldedIntraVectorOffset);
+       auto value = extractSliceIntoByte(
+           rewriter, loc, cast<TypedValue<VectorType>>(valueToStore), 0,
+           frontAtomicStoreElem, *foldedIntraVectorOffset);
+
+       atomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                   cast<TypedValue<VectorType>>(value), frontMask.getResult(),
+                   scale);
+
+       currentDestIndex = rewriter.create<arith::AddIOp>(
+           loc, rewriter.getIndexType(), currentDestIndex, constantOne);
+     }
+
+     if (currentSourceIndex >= origElements) {
+       rewriter.eraseOp(op);
+       return success();
+     }
+
+     // 2. Plain (non-atomic) store for the fully covered middle bytes.
+     int64_t nonAtomicStoreSize = (origElements - currentSourceIndex) / scale;
+     int64_t numNonAtomicElements = nonAtomicStoreSize * scale;
+     if (nonAtomicStoreSize != 0) {
+       auto nonAtomicStorePart = staticallyExtractSubvector(
+           rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
+           currentSourceIndex, numNonAtomicElements);
+
+       nonAtomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                      nonAtomicStorePart);
+
+       currentSourceIndex += numNonAtomicElements;
+       currentDestIndex = rewriter.create<arith::AddIOp>(
+           loc, rewriter.getIndexType(), currentDestIndex,
+           rewriter.create<arith::ConstantIndexOp>(loc, nonAtomicStoreSize));
+     }
+
+     // 3. Atomic RMW store for the last, partially covered byte.
+     auto remainingElements = origElements - currentSourceIndex;
+     if (remainingElements != 0) {
+       auto atomicStorePart = extractSliceIntoByte(
+           rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
+           currentSourceIndex, remainingElements, 0);
+
+       // Mask covering the remaining elements at the front of the byte.
+       auto maskValues = llvm::SmallVector<bool>(scale, 0);
+       std::fill_n(maskValues.begin(), remainingElements, 1);
+       auto backMask = rewriter.create<arith::ConstantOp>(
+           loc, DenseElementsAttr::get(atomicMaskType, maskValues));
+
+       atomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                   cast<TypedValue<VectorType>>(atomicStorePart),
+                   backMask.getResult(), scale);
+     }

-     rewriter.replaceOpWithNewOp<vector::StoreOp>(
-         op, bitCast.getResult(), adaptor.getBase(),
-         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+     rewriter.eraseOp(op);
    return success();
  }
};
@@ -518,9 +669,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
          rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
          linearizedInfo.intraDataOffset, origElements);
    } else if (isUnalignedEmulation) {
-       result =
-           staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                      *foldedIntraVectorOffset, origElements);
+       result = staticallyExtractSubvector(
+           rewriter, loc, result, *foldedIntraVectorOffset, origElements);
    }
    rewriter.replaceOp(op, result);
    return success();
@@ -679,9 +829,8 @@ struct ConvertVectorMaskedLoad final
          rewriter, loc, dyn_cast<TypedValue<VectorType>>(result),
          op.getPassThru(), linearizedInfo.intraDataOffset, origElements);
    } else if (isUnalignedEmulation) {
-       result =
-           staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                      *foldedIntraVectorOffset, origElements);
+       result = staticallyExtractSubvector(
+           rewriter, loc, result, *foldedIntraVectorOffset, origElements);
    }
    rewriter.replaceOp(op, result);
@@ -764,9 +913,8 @@ struct ConvertVectorTransferRead final
                                          linearizedInfo.intraDataOffset,
                                          origElements);
    } else if (isUnalignedEmulation) {
-       result =
-           staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                      *foldedIntraVectorOffset, origElements);
+       result = staticallyExtractSubvector(
+           rewriter, loc, result, *foldedIntraVectorOffset, origElements);
    }
    rewriter.replaceOp(op, result);