Skip to content

Commit 5eebcc0

Browse files
committed
Implement dynamic indexing for MaskedLoads
1 parent a6fdfef commit 5eebcc0

File tree

2 files changed

+120
-33
lines changed

2 files changed

+120
-33
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 68 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
5353
Location loc, Value mask,
5454
int origElements, int scale,
5555
int intraDataOffset = 0) {
56+
assert(intraDataOffset < scale && "intraDataOffset must be less than scale");
5657
auto numElements = (intraDataOffset + origElements + scale - 1) / scale;
5758

5859
Operation *maskOp = mask.getDefiningOp();
@@ -182,6 +183,27 @@ static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
182183
return dest;
183184
}
184185

186+
/// Inserts a 1-D subvector into a 1-D `dest` vector at index `offset`.
187+
static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc,
188+
TypedValue<VectorType> source,
189+
Value dest, OpFoldResult destOffsetVar,
190+
int64_t length) {
191+
assert(length > 0 && "length must be greater than 0");
192+
for (int i = 0; i < length; ++i) {
193+
Value insertLoc;
194+
if (i == 0) {
195+
insertLoc = destOffsetVar.dyn_cast<Value>();
196+
} else {
197+
insertLoc = rewriter.create<arith::AddIOp>(
198+
loc, rewriter.getIndexType(), destOffsetVar.dyn_cast<Value>(),
199+
rewriter.create<arith::ConstantIndexOp>(loc, i));
200+
}
201+
auto extractOp = rewriter.create<vector::ExtractOp>(loc, source, i);
202+
dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, insertLoc);
203+
}
204+
return dest;
205+
}
206+
185207
/// Returns the op sequence for an emulated sub-byte data type vector load.
186208
/// specifically, use `emulatedElemType` for loading a vector of `origElemType`.
187209
/// The load location is given by `base` and `linearizedIndices`, and the
@@ -199,7 +221,7 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
199221
return rewriter.create<vector::BitCastOp>(
200222
loc, VectorType::get(numEmultedElementsToLoad * scale, origElemType),
201223
newLoad);
202-
};
224+
}
203225

204226
namespace {
205227

@@ -546,29 +568,30 @@ struct ConvertVectorMaskedLoad final
546568
? getConstantIntValue(linearizedInfo.intraDataOffset)
547569
: 0;
548570

549-
if (!foldedIntraVectorOffset) {
550-
// unimplemented case for dynamic intra vector offset
551-
return failure();
552-
}
553-
554-
FailureOr<Operation *> newMask =
555-
getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale,
556-
*foldedIntraVectorOffset);
571+
auto maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
572+
FailureOr<Operation *> newMask = getCompressedMaskOp(
573+
rewriter, loc, op.getMask(), origElements, scale, maxIntraDataOffset);
557574
if (failed(newMask))
558575
return failure();
559576

577+
Value passthru = op.getPassThru();
578+
560579
auto numElements =
561-
llvm::divideCeil(*foldedIntraVectorOffset + origElements, scale);
580+
llvm::divideCeil(maxIntraDataOffset + origElements, scale);
562581
auto loadType = VectorType::get(numElements, newElementType);
563582
auto newBitcastType = VectorType::get(numElements * scale, oldElementType);
564583

565-
Value passthru = op.getPassThru();
566-
if (isUnalignedEmulation) {
567-
// create an empty vector of the new type
568-
auto emptyVector = rewriter.create<arith::ConstantOp>(
569-
loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
570-
passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
571-
*foldedIntraVectorOffset);
584+
auto emptyVector = rewriter.create<arith::ConstantOp>(
585+
loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
586+
if (foldedIntraVectorOffset) {
587+
if (isUnalignedEmulation) {
588+
passthru = staticallyInsertSubvector(
589+
rewriter, loc, passthru, emptyVector, *foldedIntraVectorOffset);
590+
}
591+
} else {
592+
passthru = dynamicallyInsertSubVector(
593+
rewriter, loc, dyn_cast<TypedValue<VectorType>>(passthru),
594+
emptyVector, linearizedInfo.intraDataOffset, origElements);
572595
}
573596
auto newPassThru =
574597
rewriter.create<vector::BitCastOp>(loc, loadType, passthru);
@@ -585,23 +608,36 @@ struct ConvertVectorMaskedLoad final
585608
rewriter.create<vector::BitCastOp>(loc, newBitcastType, newLoad);
586609

587610
Value mask = op.getMask();
588-
if (isUnalignedEmulation) {
589-
auto newSelectMaskType =
590-
VectorType::get(numElements * scale, rewriter.getI1Type());
591-
// TODO: can fold if op's mask is constant
592-
auto emptyVector = rewriter.create<arith::ConstantOp>(
593-
loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
594-
mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyVector,
595-
*foldedIntraVectorOffset);
611+
auto newSelectMaskType =
612+
VectorType::get(numElements * scale, rewriter.getI1Type());
613+
// TODO: try to fold if op's mask is constant
614+
auto emptyMask = rewriter.create<arith::ConstantOp>(
615+
loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
616+
if (foldedIntraVectorOffset) {
617+
if (isUnalignedEmulation) {
618+
mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
619+
*foldedIntraVectorOffset);
620+
}
621+
} else {
622+
mask = dynamicallyInsertSubVector(
623+
rewriter, loc, dyn_cast<TypedValue<VectorType>>(mask), emptyMask,
624+
linearizedInfo.intraDataOffset, origElements);
596625
}
597626

598627
Value result =
599628
rewriter.create<arith::SelectOp>(loc, mask, bitCast, passthru);
600-
601-
if (isUnalignedEmulation) {
602-
result =
603-
staticallyExtractSubvector(rewriter, loc, op.getType(), result,
604-
*foldedIntraVectorOffset, origElements);
629+
if (foldedIntraVectorOffset) {
630+
if (isUnalignedEmulation) {
631+
result =
632+
staticallyExtractSubvector(rewriter, loc, op.getType(), result,
633+
*foldedIntraVectorOffset, origElements);
634+
}
635+
} else {
636+
auto resultVector = rewriter.create<arith::ConstantOp>(
637+
loc, op.getType(), rewriter.getZeroAttr(op.getType()));
638+
result = dynamicallyExtractSubVector(
639+
rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
640+
linearizedInfo.intraDataOffset, origElements);
605641
}
606642
rewriter.replaceOp(op, result);
607643

@@ -659,10 +695,9 @@ struct ConvertVectorTransferRead final
659695
? getConstantIntValue(linearizedInfo.intraDataOffset)
660696
: 0;
661697

662-
auto maxIntraVectorOffset =
663-
foldedIntraVectorOffset ? *foldedIntraVectorOffset : scale - 1;
698+
auto maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
664699
auto numElements =
665-
llvm::divideCeil(maxIntraVectorOffset + origElements, scale);
700+
llvm::divideCeil(maxIntraDataOffset + origElements, scale);
666701

667702
auto newRead = rewriter.create<vector::TransferReadOp>(
668703
loc, VectorType::get(numElements, newElementType), adaptor.getSource(),

mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,3 +183,55 @@ func.func @vector_transfer_read_i2_dynamic_indexing_mixed(%idx1: index) -> vecto
183183
// CHECK: %[[C2:.+]] = arith.constant 2 : index
184184
// CHECK: %[[ADDI2:.+]] = arith.addi %[[LOADADDR2]], %[[C2]] : index
185185
// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[ADDI2]]] : i2 from vector<8xi2>
186+
// -----
187+
188+
func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>, %idx: index) -> vector<3xi2> {
189+
%0 = memref.alloc() : memref<3x3xi2>
190+
%cst = arith.constant dense<0> : vector<3x3xi2>
191+
%c2 = arith.constant 2 : index
192+
%mask = vector.constant_mask [3] : vector<3xi1>
193+
%1 = vector.maskedload %0[%idx, %c2], %mask, %passthru :
194+
memref<3x3xi2>, vector<3xi1>, vector<3xi2> into vector<3xi2>
195+
return %1 : vector<3xi2>
196+
}
197+
198+
// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> ((s0 * 3 + 2) floordiv 4)>
199+
// CHECK: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 3 - ((s0 * 3 + 2) floordiv 4) * 4 + 2)>
200+
// CHECK: func @vector_maskedload_i2_dynamic_indexing_mixed(
201+
// CHECK-SAME: %[[PTH:.+]]: vector<3xi2>, %[[IDX:.+]]: index) -> vector<3xi2>
202+
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
203+
// CHECK: %[[MASK:.+]] = vector.constant_mask [3] : vector<3xi1>
204+
// CHECK: %[[LINEAR1:.+]] = affine.apply #map()[%[[IDX]]]
205+
// CHECK: %[[LINEAR2:.+]] = affine.apply #map1()[%[[IDX]]]
206+
// CHECK: %[[ONE:.+]] = arith.constant dense<true> : vector<2xi1>
207+
// CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<8xi2>
208+
// CHECK: %[[EX1:.+]] = vector.extract %[[PTH]][0] : i2 from vector<3xi2>
209+
// CHECK: %[[IN1:.+]] = vector.insert %[[EX1]], %[[ZERO]] [%[[LINEAR2]]] : i2 into vector<8xi2>
210+
// CHECK: %[[C1:.+]] = arith.constant 1 : index
211+
// CHECK: %[[INCIDX:.+]] = arith.addi %[[LINEAR2]], %[[C1]] : index
212+
// CHECK: %[[EX2:.+]] = vector.extract %[[PTH]][1] : i2 from vector<3xi2>
213+
// CHECK: %[[IN2:.+]] = vector.insert %[[EX2]], %[[IN1]] [%[[INCIDX]]] : i2 into vector<8xi2>
214+
// CHECK: %[[C2:.+]] = arith.constant 2 : index
215+
// CHECK: %[[INCIDX2:.+]] = arith.addi %[[LINEAR2]], %[[C2]] : index
216+
// CHECK: %[[EX3:.+]] = vector.extract %[[PTH]][2] : i2 from vector<3xi2>
217+
// CHECK: %[[IN3:.+]] = vector.insert %[[EX3]], %[[IN2]] [%[[INCIDX2]]] : i2 into vector<8xi2>
218+
// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[IN3]] : vector<8xi2> to vector<2xi8>
219+
// CHECK: %[[MASKEDLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LINEAR1]]], %[[ONE]], %[[BITCAST]]
220+
// CHECK-SAME: memref<3xi8>, vector<2xi1>, vector<2xi8> into vector<2xi8>
221+
// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[MASKEDLOAD]] : vector<2xi8> to vector<8xi2>
222+
// extracts:
223+
// CHECK: %[[CST1:.+]] = arith.constant dense<false> : vector<8xi1>
224+
// CHECK: %[[EX4:.+]] = vector.extract %[[MASK]][0] : i1 from vector<3xi1>
225+
// CHECK: %[[IN4:.+]] = vector.insert %[[EX4]], %[[CST1]] [%[[LINEAR2]]] : i1 into vector<8xi1>
226+
// CHECK: %[[EX5:.+]] = vector.extract %[[MASK]][1] : i1 from vector<3xi1>
227+
// CHECK: %[[IN5:.+]] = vector.insert %[[EX5]], %[[IN4]] [%[[INCIDX]]] : i1 into vector<8xi1>
228+
// CHECK: %[[EX6:.+]] = vector.extract %[[MASK]][2] : i1 from vector<3xi1>
229+
// CHECK: %[[IN6:.+]] = vector.insert %[[EX6]], %[[IN5]] [%[[INCIDX2]]] : i1 into vector<8xi1>
230+
// CHECK: %[[SELECT:.+]] = arith.select %[[IN6]], %[[BITCAST2]], %[[IN3]] : vector<8xi1>, vector<8xi2>
231+
// CHECK: %[[CST2:.+]] = arith.constant dense<0> : vector<3xi2>
232+
// CHECK: %[[EX7:.+]] = vector.extract %[[SELECT]][%[[LINEAR2]]] : i2 from vector<8xi2>
233+
// CHECK: %[[IN7:.+]] = vector.insert %[[EX7]], %[[CST2]] [0] : i2 into vector<3xi2>
234+
// CHECK: %[[EX8:.+]] = vector.extract %[[SELECT]][%[[INCIDX]]] : i2 from vector<8xi2>
235+
// CHECK: %[[IN8:.+]] = vector.insert %[[EX8]], %[[IN7]] [1] : i2 into vector<3xi2>
236+
// CHECK: %[[EX9:.+]] = vector.extract %[[SELECT]][%[[INCIDX2]]] : i2 from vector<8xi2>
237+
// CHECK: %[[IN9:.+]] = vector.insert %[[EX9]], %[[IN8]] [2] : i2 into vector<3xi2>

0 commit comments

Comments
 (0)