#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
+ #include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <cstdint>
@@ -143,19 +144,19 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
/// Extracts a 1-D subvector from a 1-D vector. It is a wrapper function for
/// emitting `vector.extract_strided_slice`.
static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
-                                        VectorType extractType, Value source,
-                                        int64_t frontOffset,
+                                        Value source, int64_t frontOffset,
                                        int64_t subvecSize) {
-  auto vectorType = cast<VectorType>(source.getType());
-  assert((vectorType.getRank() == 1 && extractType.getRank() == 1) &&
-         "expected 1-D source and destination types");
-  (void)vectorType;
+  auto vectorType = llvm::cast<VectorType>(source.getType());
+  assert(vectorType.getRank() == 1 && "expected 1-D source type");

  auto offsets = rewriter.getI64ArrayAttr({frontOffset});
  auto sizes = rewriter.getI64ArrayAttr({subvecSize});
  auto strides = rewriter.getI64ArrayAttr({1});
+
+  auto resultVectorType =
+      VectorType::get({subvecSize}, vectorType.getElementType());
  return rewriter
-      .create<vector::ExtractStridedSliceOp>(loc, extractType, source, offsets,
-                                             sizes, strides)
+      .create<vector::ExtractStridedSliceOp>(loc, resultVectorType, source,
+                                             offsets, sizes, strides)
      ->getResult(0);
}
@@ -164,12 +165,10 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
/// emitting `vector.insert_strided_slice`.
static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
                                       Value src, Value dest, int64_t offset) {
-  auto srcType = cast<VectorType>(src.getType());
-  auto destType = cast<VectorType>(dest.getType());
+  [[maybe_unused]] auto srcType = cast<VectorType>(src.getType());
+  [[maybe_unused]] auto destType = cast<VectorType>(dest.getType());
  assert(srcType.getRank() == 1 && destType.getRank() == 1 &&
         "expected 1-D source and dest vectors");
-  (void)srcType;
-  (void)destType;

  auto offsets = rewriter.getI64ArrayAttr({offset});
  auto strides = rewriter.getI64ArrayAttr({1});
  return rewriter.create<vector::InsertStridedSliceOp>(loc, dest.getType(), src,
@@ -236,6 +235,63 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
                            newLoad);
}

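+ /// Store a byte-aligned chunk of a sub-byte vector with a plain (non-atomic)
+ /// `vector.store`: `value` is bitcast to the wider memref element type and
+ /// written out directly. No mask is involved, so no read-modify-write
+ /// sequence is needed.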
+ static void nonAtomicStore(ConversionPatternRewriter &rewriter, Location loc,
+                            Value memref, Value index, Value value) {
+   // Both types are known here, so use checked-in-debug `cast` rather than
+   // `dyn_cast`, whose result was dereferenced unconditionally anyway.
+   auto originType = cast<VectorType>(value.getType());
+   auto memrefElemType = cast<MemRefType>(memref.getType()).getElementType();
+   auto scale = memrefElemType.getIntOrFloatBitWidth() /
+                originType.getElementType().getIntOrFloatBitWidth();
+   auto storeType =
+       VectorType::get({originType.getNumElements() / scale}, memrefElemType);
+   auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType, value);
+   rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memref, index);
+ }
+
+ /// Atomically store a subbyte-sized value to memory, with a mask.
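+ /// The mask selects which sub-byte lanes of the byte are replaced by `value`;
+ /// the remaining lanes may be owned by concurrent stores, so the byte is
+ /// updated with a read-modify-write inside `memref.generic_atomic_rmw`. A
+ /// generic atomic region is needed because a masked lane merge is not among
+ /// the builtin atomic RMW kinds.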
+ static Value atomicStore(OpBuilder &rewriter, Location loc,
+                          Value emulatedMemref, Value emulatedIndex,
+                          TypedValue<VectorType> value, Value mask,
+                          int64_t scale) {
+   auto atomicOp = rewriter.create<memref::GenericAtomicRMWOp>(
+       loc, emulatedMemref, ValueRange{emulatedIndex});
+   OpBuilder builder =
+       OpBuilder::atBlockEnd(atomicOp.getBody(), rewriter.getListener());
+   Value origValue = atomicOp.getCurrentValue();
+
+   // Wrap the loaded i8 as vector<1xi8>, then bitcast it to a vector of
+   // `scale` sub-byte elements so it can be merged lane by lane.
+   auto oneVectorType = VectorType::get({1}, origValue.getType());
+   auto fromElem = builder.create<vector::FromElementsOp>(loc, oneVectorType,
+                                                          ValueRange{origValue});
+   auto vectorBitCast =
+       builder.create<vector::BitCastOp>(loc, value.getType(), fromElem);
+
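+   // Lanes where `mask` is true take the new values; the rest keep the bits
+   // already in memory. The merged vector is then packed back into a single
+   // byte and yielded as the atomic result.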
+   auto select =
+       builder.create<arith::SelectOp>(loc, mask, value, vectorBitCast);
+   auto bitcast2 = builder.create<vector::BitCastOp>(loc, oneVectorType, select);
+   auto extract = builder.create<vector::ExtractOp>(loc, bitcast2, 0);
+   builder.create<memref::AtomicYieldOp>(loc, extract.getResult());
+   return atomicOp;
+ }
+
+ /// Extract a slice of a vector and insert it, at `byteOffset`, into an
+ /// otherwise zero-initialized vector that spans exactly one byte.
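+ /// E.g. for i2 elements (scale = 4), sliceOffset = 0, sliceNumElements = 3,
+ /// and byteOffset = 1, the result is [0, v0, v1, v2].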
+ static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
+                                   Location loc, TypedValue<VectorType> vector,
+                                   int64_t sliceOffset, int64_t sliceNumElements,
+                                   int64_t byteOffset) {
+   auto vectorElementType = vector.getType().getElementType();
+   assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
+          "vector element must be a valid sub-byte type");
+   auto scale = 8 / vectorElementType.getIntOrFloatBitWidth();
+   auto emptyByteVector = rewriter.create<arith::ConstantOp>(
+       loc, VectorType::get({scale}, vectorElementType),
+       rewriter.getZeroAttr(VectorType::get({scale}, vectorElementType)));
+   auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
+                                               sliceOffset, sliceNumElements);
+   auto inserted = staticallyInsertSubvector(rewriter, loc, extracted,
+                                             emptyByteVector, byteOffset);
+   return inserted;
+ }
+
namespace {

//===----------------------------------------------------------------------===//
@@ -256,7 +312,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
    auto loc = op.getLoc();
    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
-   Type oldElementType = op.getValueToStore().getType().getElementType();
+   auto valueToStore = op.getValueToStore();
+   Type oldElementType = valueToStore.getType().getElementType();
    Type newElementType = convertedType.getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = newElementType.getIntOrFloatBitWidth();
@@ -280,30 +337,121 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
    // vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
    // vector<4xi8>

-   auto origElements = op.getValueToStore().getType().getNumElements();
-   if (origElements % scale != 0)
-     return failure();
+   auto origElements = valueToStore.getType().getNumElements();
+   bool isUnalignedEmulation = origElements % scale != 0;
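+     // "Unaligned" means the store covers a number of sub-byte elements that
+     // is not a multiple of `scale`, so the first and/or last byte it touches
+     // may be shared with adjacent data and only partially written.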

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());

    OpFoldResult linearizedIndices;
-   std::tie(std::ignore, linearizedIndices) =
+   memref::LinearizedMemRefInfo linearizedInfo;
+   std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

-   auto numElements = origElements / scale;
-   auto bitCast = rewriter.create<vector::BitCastOp>(
-       loc, VectorType::get(numElements, newElementType),
-       op.getValueToStore());
+   auto foldedIntraVectorOffset =
+       isUnalignedEmulation
+           ? getConstantIntValue(linearizedInfo.intraDataOffset)
+           : 0;
+
+   if (!foldedIntraVectorOffset) {
+     // Dynamic front padding (intra-byte offset) is not implemented yet.
+     return failure();
+   }
+
+   if (!isUnalignedEmulation) {
+     auto numElements = origElements / scale;
+     auto bitCast = rewriter.create<vector::BitCastOp>(
+         loc, VectorType::get(numElements, newElementType),
+         op.getValueToStore());
+     rewriter.replaceOpWithNewOp<vector::StoreOp>(
+         op, bitCast.getResult(), adaptor.getBase(),
+         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+     return success();
+   }
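+
+   // Unaligned case: split the store into (1) a masked atomic
+   // read-modify-write for the leading partial byte, (2) plain vector.stores
+   // for the whole bytes in the middle, and (3) a masked atomic
+   // read-modify-write for the trailing partial byte. For example, storing
+   // vector<7xi2> at an intra-byte offset of one element becomes one atomic
+   // update writing source elements 0..2 into lanes 1..3 of the first byte,
+   // one non-atomic store of the next full byte (elements 3..6), and no
+   // trailing byte.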
+
+   Value emulatedMemref = adaptor.getBase();
+   // The index of the byte in the target memref we are currently storing to.
+   Value currentDestIndex =
+       getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+   auto constantOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+   auto atomicMaskType = VectorType::get({scale}, rewriter.getI1Type());
+   // The index of the next unprocessed element in the source vector.
+   auto currentSourceIndex = 0;
+
+   // 1. Atomic read-modify-write for the leading partial byte.
+   auto frontAtomicStoreElem = (scale - *foldedIntraVectorOffset) % scale;
+   if (frontAtomicStoreElem != 0) {
+     auto frontMaskValues = llvm::SmallVector<bool>(scale, false);
+     if (*foldedIntraVectorOffset + origElements < scale) {
+       // The whole store fits in one byte: mask exactly origElements lanes.
+       std::fill_n(frontMaskValues.begin() + *foldedIntraVectorOffset,
+                   origElements, true);
+       frontAtomicStoreElem = origElements;
+     } else {
+       // Mask the trailing frontAtomicStoreElem lanes of the first byte.
+       // (The fill count was previously *foldedIntraVectorOffset, which
+       // over- or under-fills whenever the offset differs from the number
+       // of lanes being stored.)
+       std::fill_n(frontMaskValues.end() - frontAtomicStoreElem,
+                   frontAtomicStoreElem, true);
+     }
+     auto frontMask = rewriter.create<arith::ConstantOp>(
+         loc, DenseElementsAttr::get(atomicMaskType, frontMaskValues));
+
+     currentSourceIndex = scale - (*foldedIntraVectorOffset);
+     auto value = extractSliceIntoByte(
+         rewriter, loc, cast<TypedValue<VectorType>>(valueToStore), 0,
+         frontAtomicStoreElem, *foldedIntraVectorOffset);
+
+     atomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                 cast<TypedValue<VectorType>>(value), frontMask.getResult(),
+                 scale);
+
+     currentDestIndex = rewriter.create<arith::AddIOp>(
+         loc, rewriter.getIndexType(), currentDestIndex, constantOne);
+   }
+
+   // If the entire store fit into the leading byte, we are done.
+   if (currentSourceIndex >= origElements) {
+     rewriter.eraseOp(op);
+     return success();
+   }
+
+   // 2. Non-atomic stores for the aligned middle part. These bytes are
+   //    covered entirely by this store, so plain vector.store is safe.
+   int64_t nonAtomicStoreSize = (origElements - currentSourceIndex) / scale;
+   int64_t numNonAtomicElements = nonAtomicStoreSize * scale;
+   if (nonAtomicStoreSize != 0) {
+     auto nonAtomicStorePart = staticallyExtractSubvector(
+         rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
+         currentSourceIndex, numNonAtomicElements);
+
+     nonAtomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                    nonAtomicStorePart);
+
+     currentSourceIndex += numNonAtomicElements;
+     currentDestIndex = rewriter.create<arith::AddIOp>(
+         loc, rewriter.getIndexType(), currentDestIndex,
+         rewriter.create<arith::ConstantIndexOp>(loc, nonAtomicStoreSize));
+   }
+
+   // 3. Atomic read-modify-write for the trailing partial byte.
+   auto remainingElements = origElements - currentSourceIndex;
+   if (remainingElements != 0) {
+     auto atomicStorePart = extractSliceIntoByte(
+         rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
+         currentSourceIndex, remainingElements, 0);
+
+     // Mask the leading remainingElements lanes of the last byte.
+     auto maskValues = llvm::SmallVector<bool>(scale, false);
+     std::fill_n(maskValues.begin(), remainingElements, true);
+     auto backMask = rewriter.create<arith::ConstantOp>(
+         loc, DenseElementsAttr::get(atomicMaskType, maskValues));
+
+     atomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                 cast<TypedValue<VectorType>>(atomicStorePart),
+                 backMask.getResult(), scale);
+   }

-   rewriter.replaceOpWithNewOp<vector::StoreOp>(
-       op, bitCast.getResult(), adaptor.getBase(),
-       getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+   rewriter.eraseOp(op);
    return success();
  }
};
@@ -511,9 +659,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
          rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
          linearizedInfo.intraDataOffset, origElements);
    } else if (isUnalignedEmulation) {
-     result =
-         staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                    *foldedIntraVectorOffset, origElements);
+     result = staticallyExtractSubvector(
+         rewriter, loc, result, *foldedIntraVectorOffset, origElements);
    }
    rewriter.replaceOp(op, result);
    return success();
@@ -672,9 +819,8 @@ struct ConvertVectorMaskedLoad final
          rewriter, loc, dyn_cast<TypedValue<VectorType>>(result),
          op.getPassThru(), linearizedInfo.intraDataOffset, origElements);
    } else if (isUnalignedEmulation) {
-     result =
-         staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                    *foldedIntraVectorOffset, origElements);
+     result = staticallyExtractSubvector(
+         rewriter, loc, result, *foldedIntraVectorOffset, origElements);
    }
    rewriter.replaceOp(op, result);
@@ -757,9 +903,8 @@ struct ConvertVectorTransferRead final
                                      linearizedInfo.intraDataOffset,
                                      origElements);
    } else if (isUnalignedEmulation) {
-     result =
-         staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                    *foldedIntraVectorOffset, origElements);
+     result = staticallyExtractSubvector(
+         rewriter, loc, result, *foldedIntraVectorOffset, origElements);
    }
    rewriter.replaceOp(op, result);