@@ -53,6 +53,7 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
53
53
Location loc, Value mask,
54
54
int origElements, int scale,
55
55
int intraDataOffset = 0 ) {
56
+ assert (intraDataOffset < scale && " intraDataOffset must be less than scale" );
56
57
auto numElements = (intraDataOffset + origElements + scale - 1 ) / scale;
57
58
58
59
Operation *maskOp = mask.getDefiningOp ();
@@ -182,6 +183,27 @@ static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
182
183
return dest;
183
184
}
184
185
186
+ // / Inserts a 1-D subvector into a 1-D `dest` vector at index `offset`.
187
+ static Value dynamicallyInsertSubVector (RewriterBase &rewriter, Location loc,
188
+ TypedValue<VectorType> source,
189
+ Value dest, OpFoldResult destOffsetVar,
190
+ int64_t length) {
191
+ assert (length > 0 && " length must be greater than 0" );
192
+ for (int i = 0 ; i < length; ++i) {
193
+ Value insertLoc;
194
+ if (i == 0 ) {
195
+ insertLoc = destOffsetVar.dyn_cast <Value>();
196
+ } else {
197
+ insertLoc = rewriter.create <arith::AddIOp>(
198
+ loc, rewriter.getIndexType (), destOffsetVar.dyn_cast <Value>(),
199
+ rewriter.create <arith::ConstantIndexOp>(loc, i));
200
+ }
201
+ auto extractOp = rewriter.create <vector::ExtractOp>(loc, source, i);
202
+ dest = rewriter.create <vector::InsertOp>(loc, extractOp, dest, insertLoc);
203
+ }
204
+ return dest;
205
+ }
206
+
185
207
// / Returns the op sequence for an emulated sub-byte data type vector load.
186
208
// / specifically, use `emulatedElemType` for loading a vector of `origElemType`.
187
209
// / The load location is given by `base` and `linearizedIndices`, and the
@@ -199,7 +221,7 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
199
221
return rewriter.create <vector::BitCastOp>(
200
222
loc, VectorType::get (numEmultedElementsToLoad * scale, origElemType),
201
223
newLoad);
202
- };
224
+ }
203
225
204
226
namespace {
205
227
@@ -546,29 +568,30 @@ struct ConvertVectorMaskedLoad final
546
568
? getConstantIntValue (linearizedInfo.intraDataOffset )
547
569
: 0 ;
548
570
549
- if (!foldedIntraVectorOffset) {
550
- // unimplemented case for dynamic intra vector offset
551
- return failure ();
552
- }
553
-
554
- FailureOr<Operation *> newMask =
555
- getCompressedMaskOp (rewriter, loc, op.getMask (), origElements, scale,
556
- *foldedIntraVectorOffset);
571
+ auto maxIntraDataOffset = foldedIntraVectorOffset.value_or (scale - 1 );
572
+ FailureOr<Operation *> newMask = getCompressedMaskOp (
573
+ rewriter, loc, op.getMask (), origElements, scale, maxIntraDataOffset);
557
574
if (failed (newMask))
558
575
return failure ();
559
576
577
+ Value passthru = op.getPassThru ();
578
+
560
579
auto numElements =
561
- llvm::divideCeil (*foldedIntraVectorOffset + origElements, scale);
580
+ llvm::divideCeil (maxIntraDataOffset + origElements, scale);
562
581
auto loadType = VectorType::get (numElements, newElementType);
563
582
auto newBitcastType = VectorType::get (numElements * scale, oldElementType);
564
583
565
- Value passthru = op.getPassThru ();
566
- if (isUnalignedEmulation) {
567
- // create an empty vector of the new type
568
- auto emptyVector = rewriter.create <arith::ConstantOp>(
569
- loc, newBitcastType, rewriter.getZeroAttr (newBitcastType));
570
- passthru = staticallyInsertSubvector (rewriter, loc, passthru, emptyVector,
571
- *foldedIntraVectorOffset);
584
+ auto emptyVector = rewriter.create <arith::ConstantOp>(
585
+ loc, newBitcastType, rewriter.getZeroAttr (newBitcastType));
586
+ if (foldedIntraVectorOffset) {
587
+ if (isUnalignedEmulation) {
588
+ passthru = staticallyInsertSubvector (
589
+ rewriter, loc, passthru, emptyVector, *foldedIntraVectorOffset);
590
+ }
591
+ } else {
592
+ passthru = dynamicallyInsertSubVector (
593
+ rewriter, loc, dyn_cast<TypedValue<VectorType>>(passthru),
594
+ emptyVector, linearizedInfo.intraDataOffset , origElements);
572
595
}
573
596
auto newPassThru =
574
597
rewriter.create <vector::BitCastOp>(loc, loadType, passthru);
@@ -585,23 +608,36 @@ struct ConvertVectorMaskedLoad final
585
608
rewriter.create <vector::BitCastOp>(loc, newBitcastType, newLoad);
586
609
587
610
Value mask = op.getMask ();
588
- if (isUnalignedEmulation) {
589
- auto newSelectMaskType =
590
- VectorType::get (numElements * scale, rewriter.getI1Type ());
591
- // TODO: can fold if op's mask is constant
592
- auto emptyVector = rewriter.create <arith::ConstantOp>(
593
- loc, newSelectMaskType, rewriter.getZeroAttr (newSelectMaskType));
594
- mask = staticallyInsertSubvector (rewriter, loc, op.getMask (), emptyVector,
595
- *foldedIntraVectorOffset);
611
+ auto newSelectMaskType =
612
+ VectorType::get (numElements * scale, rewriter.getI1Type ());
613
+ // TODO: try to fold if op's mask is constant
614
+ auto emptyMask = rewriter.create <arith::ConstantOp>(
615
+ loc, newSelectMaskType, rewriter.getZeroAttr (newSelectMaskType));
616
+ if (foldedIntraVectorOffset) {
617
+ if (isUnalignedEmulation) {
618
+ mask = staticallyInsertSubvector (rewriter, loc, op.getMask (), emptyMask,
619
+ *foldedIntraVectorOffset);
620
+ }
621
+ } else {
622
+ mask = dynamicallyInsertSubVector (
623
+ rewriter, loc, dyn_cast<TypedValue<VectorType>>(mask), emptyMask,
624
+ linearizedInfo.intraDataOffset , origElements);
596
625
}
597
626
598
627
Value result =
599
628
rewriter.create <arith::SelectOp>(loc, mask, bitCast, passthru);
600
-
601
- if (isUnalignedEmulation) {
602
- result =
603
- staticallyExtractSubvector (rewriter, loc, op.getType (), result,
604
- *foldedIntraVectorOffset, origElements);
629
+ if (foldedIntraVectorOffset) {
630
+ if (isUnalignedEmulation) {
631
+ result =
632
+ staticallyExtractSubvector (rewriter, loc, op.getType (), result,
633
+ *foldedIntraVectorOffset, origElements);
634
+ }
635
+ } else {
636
+ auto resultVector = rewriter.create <arith::ConstantOp>(
637
+ loc, op.getType (), rewriter.getZeroAttr (op.getType ()));
638
+ result = dynamicallyExtractSubVector (
639
+ rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
640
+ linearizedInfo.intraDataOffset , origElements);
605
641
}
606
642
rewriter.replaceOp (op, result);
607
643
@@ -659,10 +695,9 @@ struct ConvertVectorTransferRead final
659
695
? getConstantIntValue (linearizedInfo.intraDataOffset )
660
696
: 0 ;
661
697
662
- auto maxIntraVectorOffset =
663
- foldedIntraVectorOffset ? *foldedIntraVectorOffset : scale - 1 ;
698
+ auto maxIntraDataOffset = foldedIntraVectorOffset.value_or (scale - 1 );
664
699
auto numElements =
665
- llvm::divideCeil (maxIntraVectorOffset + origElements, scale);
700
+ llvm::divideCeil (maxIntraDataOffset + origElements, scale);
666
701
667
702
auto newRead = rewriter.create <vector::TransferReadOp>(
668
703
loc, VectorType::get (numElements, newElementType), adaptor.getSource (),
0 commit comments