@@ -52,7 +52,9 @@ using namespace mlir;
 ///
 /// %mask = [1, 1, 0, 0, 0, 0]
 ///
-/// will first be padded with number of `intraDataOffset` zeros:
+/// will first be padded at the front with `intraDataOffset` zeros, and then
+/// padded at the back with zeros so that the number of elements is a multiple
+/// of `scale` (just to make the computation easier). The new mask will be:
 ///
 /// %mask = [0, 1, 1, 0, 0, 0, 0, 0]
 ///
 /// then it will return the following new compressed mask:
@@ -62,7 +64,8 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                    Location loc, Value mask,
                                                    int origElements, int scale,
                                                    int intraDataOffset = 0) {
-  auto numElements = (intraDataOffset + origElements + scale - 1) / scale;
+  assert(intraDataOffset < scale && "intraDataOffset must be less than scale");
+  auto numElements = llvm::divideCeil(intraDataOffset + origElements, scale);
 
   Operation *maskOp = mask.getDefiningOp();
   SmallVector<vector::ExtractOp, 2> extractOps;
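
As a quick sanity check of the new sizing logic, here is a standalone C++ sketch (not part of the patch) that only assumes the example values from the comment above: origElements = 6, scale = 2, intraDataOffset = 1. For non-negative inputs, llvm::divideCeil rounds up exactly like the removed expression; the new assert simply makes the intraDataOffset < scale precondition explicit.

// Standalone sketch: rounded-up emulated element count for the example above.
// divideCeil here mirrors llvm::divideCeil for non-negative operands.
#include <cassert>
#include <cstdint>

static int64_t divideCeil(int64_t numerator, int64_t denominator) {
  return (numerator + denominator - 1) / denominator;
}

int main() {
  int64_t origElements = 6, scale = 2, intraDataOffset = 1; // assumed example
  assert(intraDataOffset < scale && "intraDataOffset must be less than scale");
  // 1 + 6 = 7 sub-byte elements round up to 4 emulated elements, which is also
  // the length of the compressed mask.
  assert(divideCeil(intraDataOffset + origElements, scale) == 4);
  return 0;
}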
@@ -194,6 +197,26 @@ static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
   return dest;
 }
 
+/// Inserts a 1-D subvector into a 1-D `dest` vector at index `destOffsetVar`.
+static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc,
+                                        TypedValue<VectorType> source,
+                                        Value dest, OpFoldResult destOffsetVar,
+                                        size_t length) {
+  assert(length > 0 && "length must be greater than 0");
+  Value destOffsetVal =
+      getValueOrCreateConstantIndexOp(rewriter, loc, destOffsetVar);
+  for (size_t i = 0; i < length; ++i) {
+    auto insertLoc = i == 0
+                         ? destOffsetVal
+                         : rewriter.create<arith::AddIOp>(
+                               loc, rewriter.getIndexType(), destOffsetVal,
+                               rewriter.create<arith::ConstantIndexOp>(loc, i));
+    auto extractOp = rewriter.create<vector::ExtractOp>(loc, source, i);
+    dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, insertLoc);
+  }
+  return dest;
+}
+
 /// Returns the op sequence for an emulated sub-byte data type vector load.
 /// specifically, use `emulatedElemType` for loading a vector of `origElemType`.
 /// The load location is given by `base` and `linearizedIndices`, and the
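
The new helper unrolls the insertion into one vector.extract/vector.insert pair per element, offsetting the insert index by the runtime offset. The plain-C++ model below (a sketch, not MLIR; the name insertSubVector is hypothetical) captures the same element-wise semantics and reproduces the padded-mask example from the header comment.

#include <cassert>
#include <cstddef>
#include <vector>

// Scalar model of dynamicallyInsertSubVector: copy `length` elements of
// `source` into `dest` starting at a runtime offset, one element at a time.
static void insertSubVector(const std::vector<int> &source,
                            std::vector<int> &dest, size_t destOffset,
                            size_t length) {
  assert(length > 0 && "length must be greater than 0");
  assert(destOffset + length <= dest.size() && length <= source.size());
  for (size_t i = 0; i < length; ++i)
    dest[destOffset + i] = source[i]; // one extract/insert pair per element
}

int main() {
  std::vector<int> source = {1, 1, 0, 0, 0, 0};
  std::vector<int> dest(8, 0);
  insertSubVector(source, dest, /*destOffset=*/1, source.size());
  // dest is now {0, 1, 1, 0, 0, 0, 0, 0}, matching the padded-mask example.
  return 0;
}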
@@ -466,18 +489,16 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
         emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
                            numElements, oldElementType, newElementType);
 
-    if (foldedIntraVectorOffset) {
-      if (isUnalignedEmulation) {
-        result =
-            staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                       *foldedIntraVectorOffset, origElements);
-      }
-    } else {
+    if (!foldedIntraVectorOffset) {
       auto resultVector = rewriter.create<arith::ConstantOp>(
           loc, op.getType(), rewriter.getZeroAttr(op.getType()));
       result = dynamicallyExtractSubVector(
           rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
           linearizedInfo.intraDataOffset, origElements);
+    } else if (isUnalignedEmulation) {
+      result =
+          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
+                                     *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
     return success();
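
The dynamic branch above relies on the pre-existing dynamicallyExtractSubVector, the counterpart of the new insert helper. For context, a plain-C++ sketch of its element-wise behaviour (the name extractSubVector is hypothetical, not MLIR):

#include <cassert>
#include <cstddef>
#include <vector>

// Scalar model of dynamicallyExtractSubVector: copy `length` elements out of
// `src` starting at a runtime offset into the front of `dest`.
static void extractSubVector(const std::vector<int> &src,
                             std::vector<int> &dest, size_t srcOffset,
                             size_t length) {
  assert(srcOffset + length <= src.size() && length <= dest.size());
  for (size_t i = 0; i < length; ++i)
    dest[i] = src[srcOffset + i]; // one extract/insert pair per element
}

int main() {
  std::vector<int> src = {0, 1, 1, 0, 0, 0, 0, 0}, dest(6, 0);
  extractSubVector(src, dest, /*srcOffset=*/1, dest.size());
  // dest is now {1, 1, 0, 0, 0, 0}: the original-length window at the offset.
  return 0;
}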
@@ -572,27 +593,26 @@ struct ConvertVectorMaskedLoad final
             ? getConstantIntValue(linearizedInfo.intraDataOffset)
             : 0;
 
-    if (!foldedIntraVectorOffset) {
-      // unimplemented case for dynamic intra vector offset
-      return failure();
-    }
-
-    FailureOr<Operation *> newMask =
-        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale,
-                            *foldedIntraVectorOffset);
+    int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
+    FailureOr<Operation *> newMask = getCompressedMaskOp(
+        rewriter, loc, op.getMask(), origElements, scale, maxIntraDataOffset);
     if (failed(newMask))
       return failure();
 
+    Value passthru = op.getPassThru();
+
     auto numElements =
-        llvm::divideCeil(*foldedIntraVectorOffset + origElements, scale);
+        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
     auto loadType = VectorType::get(numElements, newElementType);
     auto newBitcastType = VectorType::get(numElements * scale, oldElementType);
 
-    Value passthru = op.getPassThru();
-    if (isUnalignedEmulation) {
-      // create an empty vector of the new type
-      auto emptyVector = rewriter.create<arith::ConstantOp>(
-          loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
+    auto emptyVector = rewriter.create<arith::ConstantOp>(
+        loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
+    if (!foldedIntraVectorOffset) {
+      passthru = dynamicallyInsertSubVector(
+          rewriter, loc, dyn_cast<TypedValue<VectorType>>(passthru),
+          emptyVector, linearizedInfo.intraDataOffset, origElements);
+    } else if (isUnalignedEmulation) {
       passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
                                            *foldedIntraVectorOffset);
     }
@@ -611,20 +631,27 @@ struct ConvertVectorMaskedLoad final
         rewriter.create<vector::BitCastOp>(loc, newBitcastType, newLoad);
 
     Value mask = op.getMask();
-    if (isUnalignedEmulation) {
-      auto newSelectMaskType =
-          VectorType::get(numElements * scale, rewriter.getI1Type());
-      // TODO: can fold if op's mask is constant
-      auto emptyVector = rewriter.create<arith::ConstantOp>(
-          loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
-      mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyVector,
+    auto newSelectMaskType =
+        VectorType::get(numElements * scale, rewriter.getI1Type());
+    // TODO: try to fold if op's mask is constant
+    auto emptyMask = rewriter.create<arith::ConstantOp>(
+        loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
+    if (!foldedIntraVectorOffset) {
+      mask = dynamicallyInsertSubVector(
+          rewriter, loc, dyn_cast<TypedValue<VectorType>>(mask), emptyMask,
+          linearizedInfo.intraDataOffset, origElements);
+    } else if (isUnalignedEmulation) {
+      mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
                                        *foldedIntraVectorOffset);
     }
 
     Value result =
         rewriter.create<arith::SelectOp>(loc, mask, bitCast, passthru);
-
-    if (isUnalignedEmulation) {
+    if (!foldedIntraVectorOffset) {
+      result = dynamicallyExtractSubVector(
+          rewriter, loc, dyn_cast<TypedValue<VectorType>>(result),
+          op.getPassThru(), linearizedInfo.intraDataOffset, origElements);
+    } else if (isUnalignedEmulation) {
       result =
           staticallyExtractSubvector(rewriter, loc, op.getType(), result,
                                      *foldedIntraVectorOffset, origElements);
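
Putting the masked-load pieces together: the passthru and mask are widened into scale-aligned, zero-filled buffers, the select runs on the bit-cast wide load, and the original-length window is pulled back out at the (possibly dynamic) offset. The following standalone scalar model shows that sequence (a sketch under assumed example values scale = 2, offset = 1, six original elements; it models the data flow, not the emitted IR):

#include <cstddef>
#include <vector>

int main() {
  // Assumed example: 6 sub-byte elements emulated on a wider type with
  // scale = 2 and a dynamic intra-vector offset of 1.
  const size_t scale = 2, offset = 1, origElements = 6;
  const size_t numWide = (offset + origElements + scale - 1) / scale;
  const size_t widened = numWide * scale;

  std::vector<int> loaded(widened, 7);        // bit-cast result of the wide load
  std::vector<bool> mask(origElements, true); // original mask
  std::vector<int> passthru(origElements, -1);

  // Insert mask and passthru into zero-filled, scale-aligned buffers at the
  // intra-vector offset (dynamicallyInsertSubVector in the patch).
  std::vector<bool> wideMask(widened, false);
  std::vector<int> widePassthru(widened, 0);
  for (size_t i = 0; i < origElements; ++i) {
    wideMask[offset + i] = mask[i];
    widePassthru[offset + i] = passthru[i];
  }

  // arith.select on the widened vectors.
  std::vector<int> wideResult(widened);
  for (size_t i = 0; i < widened; ++i)
    wideResult[i] = wideMask[i] ? loaded[i] : widePassthru[i];

  // Extract the original-length window back out at the offset
  // (dynamicallyExtractSubVector in the patch).
  std::vector<int> result(wideResult.begin() + offset,
                          wideResult.begin() + offset + origElements);
  return result.size() == origElements ? 0 : 1;
}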
@@ -685,10 +712,9 @@ struct ConvertVectorTransferRead final
             ? getConstantIntValue(linearizedInfo.intraDataOffset)
             : 0;
 
-    auto maxIntraVectorOffset =
-        foldedIntraVectorOffset ? *foldedIntraVectorOffset : scale - 1;
+    int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
     auto numElements =
-        llvm::divideCeil(maxIntraVectorOffset + origElements, scale);
+        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
 
     auto newRead = rewriter.create<vector::TransferReadOp>(
         loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
@@ -699,18 +725,16 @@ struct ConvertVectorTransferRead final
         loc, VectorType::get(numElements * scale, oldElementType), newRead);
 
     Value result = bitCast->getResult(0);
-    if (foldedIntraVectorOffset) {
-      if (isUnalignedEmulation) {
-        result =
-            staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                       *foldedIntraVectorOffset, origElements);
-      }
-    } else {
+    if (!foldedIntraVectorOffset) {
       auto zeros = rewriter.create<arith::ConstantOp>(
           loc, op.getType(), rewriter.getZeroAttr(op.getType()));
       result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
                                            linearizedInfo.intraDataOffset,
                                            origElements);
+    } else if (isUnalignedEmulation) {
+      result =
+          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
+                                     *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
     return success();