Commit c3c3ccc

[MLIR] support dynamic indexing of vector.maskedload in VectorEmulateNarrowTypes (#115070)
Based on the existing emulation scheme, this patch adds support for dynamic indexing by dynamically creating an intermediate new mask and a new pass-thru vector, and by dynamically inserting the result into the destination vector. The dynamic parts are constructed from multiple `vector.extract` and `vector.insert` ops that rearrange the original mask/pass-thru vector, since `vector.insert_strided_slice` and `vector.extract_strided_slice` only take static offsets and indices.

Note: currently only `vector.maskedload` with masks created by `vector.constant_mask` is supported; `vector.create_mask` does not work yet.

---------

Co-authored-by: hasekawa-takumi <[email protected]>
1 parent 5a09424 commit c3c3ccc
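
As an illustration of the approach described above (a minimal sketch with made-up value names, not the exact IR the pass emits; the test added in this commit shows the real output), inserting a small vector into a larger one at a dynamic offset is unrolled into per-element vector.extract/vector.insert pairs, with the insertion index advanced by arith.addi:

  // Illustrative names: %src is the original vector, %dst the widened
  // destination, %off the dynamic offset.
  %e0   = vector.extract %src[0] : i1 from vector<3xi1>
  %v0   = vector.insert %e0, %dst [%off] : i1 into vector<8xi1>
  %c1   = arith.constant 1 : index
  %off1 = arith.addi %off, %c1 : index
  %e1   = vector.extract %src[1] : i1 from vector<3xi1>
  %v1   = vector.insert %e1, %v0 [%off1] : i1 into vector<8xi1>
  // ... and so on for the remaining elements of %src

The same chain builds the widened mask and pass-thru vectors and writes the selected elements back into the destination vector.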

File tree: 2 files changed (+132, -42 lines)

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 66 additions & 42 deletions
@@ -52,7 +52,9 @@ using namespace mlir;
 ///
 /// %mask = [1, 1, 0, 0, 0, 0]
 ///
-/// will first be padded with number of `intraDataOffset` zeros:
+/// will first be padded in the front with number of `intraDataOffset` zeros,
+/// and pad zeros in the back to make the number of elements a multiple of
+/// `scale` (just to make it easier to compute). The new mask will be:
 /// %mask = [0, 1, 1, 0, 0, 0, 0, 0]
 ///
 /// then it will return the following new compressed mask:
@@ -62,7 +64,8 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                   Location loc, Value mask,
                                                   int origElements, int scale,
                                                   int intraDataOffset = 0) {
-  auto numElements = (intraDataOffset + origElements + scale - 1) / scale;
+  assert(intraDataOffset < scale && "intraDataOffset must be less than scale");
+  auto numElements = llvm::divideCeil(intraDataOffset + origElements, scale);
 
   Operation *maskOp = mask.getDefiningOp();
   SmallVector<vector::ExtractOp, 2> extractOps;
@@ -194,6 +197,26 @@ static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
   return dest;
 }
 
+/// Inserts a 1-D subvector into a 1-D `dest` vector at index `destOffsetVar`.
+static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc,
+                                        TypedValue<VectorType> source,
+                                        Value dest, OpFoldResult destOffsetVar,
+                                        size_t length) {
+  assert(length > 0 && "length must be greater than 0");
+  Value destOffsetVal =
+      getValueOrCreateConstantIndexOp(rewriter, loc, destOffsetVar);
+  for (size_t i = 0; i < length; ++i) {
+    auto insertLoc = i == 0
+                         ? destOffsetVal
+                         : rewriter.create<arith::AddIOp>(
+                               loc, rewriter.getIndexType(), destOffsetVal,
+                               rewriter.create<arith::ConstantIndexOp>(loc, i));
+    auto extractOp = rewriter.create<vector::ExtractOp>(loc, source, i);
+    dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, insertLoc);
+  }
+  return dest;
+}
+
 /// Returns the op sequence for an emulated sub-byte data type vector load.
 /// specifically, use `emulatedElemType` for loading a vector of `origElemType`.
 /// The load location is given by `base` and `linearizedIndices`, and the
@@ -466,18 +489,16 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
         emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
                            numElements, oldElementType, newElementType);
 
-    if (foldedIntraVectorOffset) {
-      if (isUnalignedEmulation) {
-        result =
-            staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                       *foldedIntraVectorOffset, origElements);
-      }
-    } else {
+    if (!foldedIntraVectorOffset) {
       auto resultVector = rewriter.create<arith::ConstantOp>(
           loc, op.getType(), rewriter.getZeroAttr(op.getType()));
       result = dynamicallyExtractSubVector(
           rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
          linearizedInfo.intraDataOffset, origElements);
+    } else if (isUnalignedEmulation) {
+      result =
+          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
+                                     *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
     return success();
@@ -572,27 +593,26 @@ struct ConvertVectorMaskedLoad final
             ? getConstantIntValue(linearizedInfo.intraDataOffset)
             : 0;
 
-    if (!foldedIntraVectorOffset) {
-      // unimplemented case for dynamic intra vector offset
-      return failure();
-    }
-
-    FailureOr<Operation *> newMask =
-        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale,
-                            *foldedIntraVectorOffset);
+    int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
+    FailureOr<Operation *> newMask = getCompressedMaskOp(
+        rewriter, loc, op.getMask(), origElements, scale, maxIntraDataOffset);
     if (failed(newMask))
       return failure();
 
+    Value passthru = op.getPassThru();
+
     auto numElements =
-        llvm::divideCeil(*foldedIntraVectorOffset + origElements, scale);
+        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
     auto loadType = VectorType::get(numElements, newElementType);
     auto newBitcastType = VectorType::get(numElements * scale, oldElementType);
 
-    Value passthru = op.getPassThru();
-    if (isUnalignedEmulation) {
-      // create an empty vector of the new type
-      auto emptyVector = rewriter.create<arith::ConstantOp>(
-          loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
+    auto emptyVector = rewriter.create<arith::ConstantOp>(
+        loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
+    if (!foldedIntraVectorOffset) {
+      passthru = dynamicallyInsertSubVector(
+          rewriter, loc, dyn_cast<TypedValue<VectorType>>(passthru),
+          emptyVector, linearizedInfo.intraDataOffset, origElements);
+    } else if (isUnalignedEmulation) {
       passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
                                            *foldedIntraVectorOffset);
     }
@@ -611,20 +631,27 @@ struct ConvertVectorMaskedLoad final
         rewriter.create<vector::BitCastOp>(loc, newBitcastType, newLoad);
 
     Value mask = op.getMask();
-    if (isUnalignedEmulation) {
-      auto newSelectMaskType =
-          VectorType::get(numElements * scale, rewriter.getI1Type());
-      // TODO: can fold if op's mask is constant
-      auto emptyVector = rewriter.create<arith::ConstantOp>(
-          loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
-      mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyVector,
+    auto newSelectMaskType =
+        VectorType::get(numElements * scale, rewriter.getI1Type());
+    // TODO: try to fold if op's mask is constant
+    auto emptyMask = rewriter.create<arith::ConstantOp>(
+        loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
+    if (!foldedIntraVectorOffset) {
+      mask = dynamicallyInsertSubVector(
+          rewriter, loc, dyn_cast<TypedValue<VectorType>>(mask), emptyMask,
+          linearizedInfo.intraDataOffset, origElements);
+    } else if (isUnalignedEmulation) {
+      mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
                                        *foldedIntraVectorOffset);
     }
 
     Value result =
         rewriter.create<arith::SelectOp>(loc, mask, bitCast, passthru);
-
-    if (isUnalignedEmulation) {
+    if (!foldedIntraVectorOffset) {
+      result = dynamicallyExtractSubVector(
+          rewriter, loc, dyn_cast<TypedValue<VectorType>>(result),
+          op.getPassThru(), linearizedInfo.intraDataOffset, origElements);
+    } else if (isUnalignedEmulation) {
       result =
           staticallyExtractSubvector(rewriter, loc, op.getType(), result,
                                      *foldedIntraVectorOffset, origElements);
@@ -685,10 +712,9 @@ struct ConvertVectorTransferRead final
             ? getConstantIntValue(linearizedInfo.intraDataOffset)
             : 0;
 
-    auto maxIntraVectorOffset =
-        foldedIntraVectorOffset ? *foldedIntraVectorOffset : scale - 1;
+    int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
     auto numElements =
-        llvm::divideCeil(maxIntraVectorOffset + origElements, scale);
+        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
 
     auto newRead = rewriter.create<vector::TransferReadOp>(
         loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
@@ -699,18 +725,16 @@ struct ConvertVectorTransferRead final
         loc, VectorType::get(numElements * scale, oldElementType), newRead);
 
     Value result = bitCast->getResult(0);
-    if (foldedIntraVectorOffset) {
-      if (isUnalignedEmulation) {
-        result =
-            staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                       *foldedIntraVectorOffset, origElements);
-      }
-    } else {
+    if (!foldedIntraVectorOffset) {
      auto zeros = rewriter.create<arith::ConstantOp>(
          loc, op.getType(), rewriter.getZeroAttr(op.getType()));
      result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
                                           linearizedInfo.intraDataOffset,
                                           origElements);
+    } else if (isUnalignedEmulation) {
+      result =
+          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
+                                     *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir

Lines changed: 66 additions & 0 deletions
@@ -183,3 +183,69 @@ func.func @vector_transfer_read_i2_dynamic_indexing_mixed(%idx1: index) -> vecto
 // CHECK: %[[C2:.+]] = arith.constant 2 : index
 // CHECK: %[[ADDI2:.+]] = arith.addi %[[LOADADDR2]], %[[C2]] : index
 // CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[ADDI2]]] : i2 from vector<8xi2>
+// -----
+
+func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>, %idx: index) -> vector<3xi2> {
+  %0 = memref.alloc() : memref<3x3xi2>
+  %cst = arith.constant dense<0> : vector<3x3xi2>
+  %c2 = arith.constant 2 : index
+  %mask = vector.constant_mask [3] : vector<3xi1>
+  %1 = vector.maskedload %0[%idx, %c2], %mask, %passthru :
+    memref<3x3xi2>, vector<3xi1>, vector<3xi2> into vector<3xi2>
+  return %1 : vector<3xi2>
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> ((s0 * 3 + 2) floordiv 4)>
+// CHECK: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 3 - ((s0 * 3 + 2) floordiv 4) * 4 + 2)>
+// CHECK: func @vector_maskedload_i2_dynamic_indexing_mixed(
+// CHECK-SAME: %[[PTH:.+]]: vector<3xi2>, %[[IDX:.+]]: index) -> vector<3xi2>
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[MASK:.+]] = vector.constant_mask [3] : vector<3xi1>
+// CHECK: %[[LINEAR1:.+]] = affine.apply #map()[%[[IDX]]]
+// CHECK: %[[LINEAR2:.+]] = affine.apply #map1()[%[[IDX]]]
+// CHECK: %[[ONE:.+]] = arith.constant dense<true> : vector<2xi1>
+// CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<8xi2>
+
+// Extract passthru vector, and insert into zero vector, this is for constructing a new passthru
+// CHECK: %[[EX1:.+]] = vector.extract %[[PTH]][0] : i2 from vector<3xi2>
+// CHECK: %[[IN1:.+]] = vector.insert %[[EX1]], %[[ZERO]] [%[[LINEAR2]]] : i2 into vector<8xi2>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[INCIDX:.+]] = arith.addi %[[LINEAR2]], %[[C1]] : index
+// CHECK: %[[EX2:.+]] = vector.extract %[[PTH]][1] : i2 from vector<3xi2>
+// CHECK: %[[IN2:.+]] = vector.insert %[[EX2]], %[[IN1]] [%[[INCIDX]]] : i2 into vector<8xi2>
+// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[INCIDX2:.+]] = arith.addi %[[LINEAR2]], %[[C2]] : index
+// CHECK: %[[EX3:.+]] = vector.extract %[[PTH]][2] : i2 from vector<3xi2>
+// CHECK: %[[NEW_PASSTHRU:.+]] = vector.insert %[[EX3]], %[[IN2]] [%[[INCIDX2]]] : i2 into vector<8xi2>
+
+// Bitcast the new passthru vector to emulated i8 vector
+// CHECK: %[[BCAST_PASSTHRU:.+]] = vector.bitcast %[[NEW_PASSTHRU]] : vector<8xi2> to vector<2xi8>
+
+// Use the emulated i8 vector for masked load from the source memory
+// CHECK: %[[SOURCE:.+]] = vector.maskedload %[[ALLOC]][%[[LINEAR1]]], %[[ONE]], %[[BCAST_PASSTHRU]]
+// CHECK-SAME: memref<3xi8>, vector<2xi1>, vector<2xi8> into vector<2xi8>
+
+// Bitcast back to i2 vector
+// CHECK: %[[BCAST_MASKLOAD:.+]] = vector.bitcast %[[SOURCE]] : vector<2xi8> to vector<8xi2>
+
+// CHECK: %[[CST1:.+]] = arith.constant dense<false> : vector<8xi1>
+
+// Create a mask vector
+// Note that if indices are known then we can fold the part generating mask.
+// CHECK: %[[EX4:.+]] = vector.extract %[[MASK]][0] : i1 from vector<3xi1>
+// CHECK: %[[IN4:.+]] = vector.insert %[[EX4]], %[[CST1]] [%[[LINEAR2]]] : i1 into vector<8xi1>
+// CHECK: %[[EX5:.+]] = vector.extract %[[MASK]][1] : i1 from vector<3xi1>
+// CHECK: %[[IN5:.+]] = vector.insert %[[EX5]], %[[IN4]] [%[[INCIDX]]] : i1 into vector<8xi1>
+// CHECK: %[[EX6:.+]] = vector.extract %[[MASK]][2] : i1 from vector<3xi1>
+// CHECK: %[[NEW_MASK:.+]] = vector.insert %[[EX6]], %[[IN5]] [%[[INCIDX2]]] : i1 into vector<8xi1>
+
+// Select the effective part from the source and passthru vectors
+// CHECK: %[[SELECT:.+]] = arith.select %[[NEW_MASK]], %[[BCAST_MASKLOAD]], %[[NEW_PASSTHRU]] : vector<8xi1>, vector<8xi2>
+
+// Finally, insert the selected parts into actual passthru vector.
+// CHECK: %[[EX7:.+]] = vector.extract %[[SELECT]][%[[LINEAR2]]] : i2 from vector<8xi2>
+// CHECK: %[[IN7:.+]] = vector.insert %[[EX7]], %[[PTH]] [0] : i2 into vector<3xi2>
+// CHECK: %[[EX8:.+]] = vector.extract %[[SELECT]][%[[INCIDX]]] : i2 from vector<8xi2>
+// CHECK: %[[IN8:.+]] = vector.insert %[[EX8]], %[[IN7]] [1] : i2 into vector<3xi2>
+// CHECK: %[[EX9:.+]] = vector.extract %[[SELECT]][%[[INCIDX2]]] : i2 from vector<8xi2>
+// CHECK: %[[IN9:.+]] = vector.insert %[[EX9]], %[[IN8]] [2] : i2 into vector<3xi2>
