
[MLIR] support dynamic indexing of vector.maskedload in VectorEmulateNarrowTypes #115070


Merged: 8 commits, Nov 12, 2024
108 changes: 66 additions & 42 deletions mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -43,7 +43,9 @@ using namespace mlir;
///
/// %mask = [1, 1, 0, 0, 0, 0]
///
/// will first be padded with number of `intraDataOffset` zeros:
/// will first be padded at the front with `intraDataOffset` zeros, and then
/// padded at the back with zeros so that the number of elements becomes a
/// multiple of `scale` (just to make the computation easier). The new mask
/// will be:
/// %mask = [0, 1, 1, 0, 0, 0, 0, 0]
///
/// then it will return the following new compressed mask:
@@ -53,7 +55,8 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
Location loc, Value mask,
int origElements, int scale,
int intraDataOffset = 0) {
auto numElements = (intraDataOffset + origElements + scale - 1) / scale;
assert(intraDataOffset < scale && "intraDataOffset must be less than scale");
Contributor:

This is not clear from the method name ...

What are origElements, scale and intraDataOffset?

Member Author:

  • origElements is the number of elements of the sub-byte vector.
  • scale is the emulated (byte) element type size divided by the original element type size. For example, if the original element type is i2, then the scale is sizeof(i8)/sizeof(i2) = 4.
  • intraDataOffset is the element offset into the emulated byte. For example, to extract the second slice of vector<3xi2> out of a vector<3x3xi2> (here we assume the sub-byte elements are stored packed in memory), we would need to load 2 bytes (the first and second byte) and extract bits [6, 12) from them; the first 3 loaded elements are irrelevant in this case, hence intraDataOffset == 3 (see the sketch below).
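
For illustration, a minimal standalone sketch (plain C++, not part of the PR; names such as rowStart and numBytes are invented for this example) that reproduces these numbers for the packed vector<3x3xi2> case:

```cpp
#include <cassert>
#include <cstdio>

int main() {
  // Emulating i2 elements with i8 storage: scale = 8 / 2 = 4 elements per byte.
  const int origBitWidth = 2, emulatedBitWidth = 8;
  const int scale = emulatedBitWidth / origBitWidth; // 4

  // Second row of a packed vector<3x3xi2>: elements 3..5 of the flat i2 array.
  const int rowStart = 3;     // flat element index where the slice begins
  const int origElements = 3; // number of i2 elements to read
  const int intraDataOffset = rowStart % scale; // 3, offset within the first byte

  // Number of bytes that must be loaded to cover the slice (ceiling division).
  const int numBytes = (intraDataOffset + origElements + scale - 1) / scale;

  assert(scale == 4 && intraDataOffset == 3 && numBytes == 2);
  std::printf("scale=%d intraDataOffset=%d bytesToLoad=%d\n", scale,
              intraDataOffset, numBytes);
  return 0;
}
```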

Contributor:

OK, this was not clear to me at all :)

I was trying to understand all of this a bit better and am just thinking that this logic needs TLC. The comment for this method needs updating to capture the info that you shared above. I think that it would also be good to provide more descriptive argument names.

Now, I appreciate that it wasn't you who wrote this to begin with and updating this shouldn't be a blocker for this PR. Some help would be appreciated. Also, I want to help:

Contributor:

Here's my attempt to improve the comments and input variable names:

Please let me know whether that makes sense to you, and any feedback is welcome.

Note, I've also created these two:

(again, your feedback would be appreciated). Last, but not least, this example seems off. In particular:

/// [Comment from Andrzej] 6 elements
///  %mask = [1, 1, 0, 0, 0, 0]
///
/// will first be padded with number of `intraDataOffset` zeros:
/// [Comment from Andrzej] 8 elements != 6 + 1
///   %mask = [0, 1, 1, 0, 0, 0, 0, 0]

Shouldn't the padded mask be: %mask = [0, 1, 1, 0, 0, 0, 0] (7 elements)?

Btw, thanks so much for working on this - your efforts are truly appreciated! Please don’t let my comments (and appetite to improve things overall) give you any other impression 😅.

Member Author:

You are right! Here I just exposed some intermediate calculation details in the comment: in this case scale == 2, so making the padded mask a multiple of scale in the intermediate result is easier (see the sketch below).
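
A small standalone sketch (plain C++, not the MLIR builder code; the compression step is modeled here by OR-ing each group of `scale` mask bits) walking through the example above with origElements = 6, intraDataOffset = 1 and scale = 2:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  const int scale = 2, intraDataOffset = 1;
  std::vector<bool> mask = {1, 1, 0, 0, 0, 0}; // origElements = 6

  // Pad the front with `intraDataOffset` zeros, then pad the back so that the
  // total length becomes a multiple of `scale`: 6 + 1 = 7, rounded up to 8.
  std::vector<bool> padded(intraDataOffset, false);
  padded.insert(padded.end(), mask.begin(), mask.end());
  while (padded.size() % scale != 0)
    padded.push_back(false);
  assert((padded == std::vector<bool>{0, 1, 1, 0, 0, 0, 0, 0}));

  // One bit per emulated byte: a byte must be loaded if any of the `scale`
  // sub-byte elements it covers is active.
  std::vector<bool> compressed;
  for (std::size_t i = 0; i < padded.size(); i += scale)
    compressed.push_back(padded[i] || padded[i + 1]); // scale == 2 here
  assert((compressed == std::vector<bool>{1, 1, 0, 0})); // ceil(7 / 2) == 4
  return 0;
}
```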

Member Author:

Slightly updated the comment. Can you take a look at it again?

auto numElements = llvm::divideCeil(intraDataOffset + origElements, scale);

Operation *maskOp = mask.getDefiningOp();
SmallVector<vector::ExtractOp, 2> extractOps;
@@ -185,6 +188,26 @@ static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
return dest;
}

/// Inserts a 1-D subvector into a 1-D `dest` vector at index `destOffsetVar`.
static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc,
TypedValue<VectorType> source,
Value dest, OpFoldResult destOffsetVar,
size_t length) {
assert(length > 0 && "length must be greater than 0");
Value destOffsetVal =
getValueOrCreateConstantIndexOp(rewriter, loc, destOffsetVar);
for (size_t i = 0; i < length; ++i) {
auto insertLoc = i == 0
? destOffsetVal
: rewriter.create<arith::AddIOp>(
loc, rewriter.getIndexType(), destOffsetVal,
rewriter.create<arith::ConstantIndexOp>(loc, i));
auto extractOp = rewriter.create<vector::ExtractOp>(loc, source, i);
dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, insertLoc);
}
return dest;
}
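
// ---------------------------------------------------------------------------
// Illustration only, not part of this diff: an element-level model of what the
// unrolled vector.extract / arith.addi / vector.insert chain emitted by
// dynamicallyInsertSubVector computes, assuming the offset stays in bounds.
// Standalone C++; compile and run separately from the pass.
#include <array>
#include <cassert>
#include <cstddef>

template <std::size_t N, std::size_t M>
void insertSubVectorModel(std::array<int, N> &dest,
                          const std::array<int, M> &source,
                          std::size_t destOffset, std::size_t length) {
  assert(length > 0 && length <= M && destOffset + length <= N);
  // Mirrors the loop above: element i of `source` goes to `destOffset + i`.
  for (std::size_t i = 0; i < length; ++i)
    dest[destOffset + i] = source[i];
}

int main() {
  std::array<int, 8> dest{};          // the zero-filled destination vector
  std::array<int, 3> source{1, 2, 3}; // e.g. three i2 passthru values
  insertSubVectorModel(dest, source, /*destOffset=*/2, /*length=*/3);
  assert((dest == std::array<int, 8>{0, 0, 1, 2, 3, 0, 0, 0}));
  return 0;
}
// ---------------------------------------------------------------------------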

/// Returns the op sequence for an emulated sub-byte data type vector load.
/// specifically, use `emulatedElemType` for loading a vector of `origElemType`.
/// The load location is given by `base` and `linearizedIndices`, and the
Expand Down Expand Up @@ -443,18 +466,16 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
numElements, oldElementType, newElementType);

if (foldedIntraVectorOffset) {
if (isUnalignedEmulation) {
result =
staticallyExtractSubvector(rewriter, loc, op.getType(), result,
*foldedIntraVectorOffset, origElements);
}
} else {
if (!foldedIntraVectorOffset) {
auto resultVector = rewriter.create<arith::ConstantOp>(
loc, op.getType(), rewriter.getZeroAttr(op.getType()));
result = dynamicallyExtractSubVector(
rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
linearizedInfo.intraDataOffset, origElements);
} else if (isUnalignedEmulation) {
result =
staticallyExtractSubvector(rewriter, loc, op.getType(), result,
*foldedIntraVectorOffset, origElements);
}
rewriter.replaceOp(op, result);
return success();
@@ -549,27 +570,26 @@ struct ConvertVectorMaskedLoad final
? getConstantIntValue(linearizedInfo.intraDataOffset)
: 0;

if (!foldedIntraVectorOffset) {
// unimplemented case for dynamic intra vector offset
return failure();
}

FailureOr<Operation *> newMask =
getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale,
*foldedIntraVectorOffset);
int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
FailureOr<Operation *> newMask = getCompressedMaskOp(
rewriter, loc, op.getMask(), origElements, scale, maxIntraDataOffset);
if (failed(newMask))
return failure();

Value passthru = op.getPassThru();

auto numElements =
llvm::divideCeil(*foldedIntraVectorOffset + origElements, scale);
llvm::divideCeil(maxIntraDataOffset + origElements, scale);
auto loadType = VectorType::get(numElements, newElementType);
auto newBitcastType = VectorType::get(numElements * scale, oldElementType);

Value passthru = op.getPassThru();
if (isUnalignedEmulation) {
// create an empty vector of the new type
auto emptyVector = rewriter.create<arith::ConstantOp>(
loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
auto emptyVector = rewriter.create<arith::ConstantOp>(
loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
if (!foldedIntraVectorOffset) {
passthru = dynamicallyInsertSubVector(
rewriter, loc, dyn_cast<TypedValue<VectorType>>(passthru),
emptyVector, linearizedInfo.intraDataOffset, origElements);
} else if (isUnalignedEmulation) {
passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
*foldedIntraVectorOffset);
}
@@ -588,20 +608,27 @@ struct ConvertVectorMaskedLoad final
rewriter.create<vector::BitCastOp>(loc, newBitcastType, newLoad);

Value mask = op.getMask();
if (isUnalignedEmulation) {
auto newSelectMaskType =
VectorType::get(numElements * scale, rewriter.getI1Type());
// TODO: can fold if op's mask is constant
auto emptyVector = rewriter.create<arith::ConstantOp>(
loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyVector,
auto newSelectMaskType =
VectorType::get(numElements * scale, rewriter.getI1Type());
// TODO: try to fold if op's mask is constant
auto emptyMask = rewriter.create<arith::ConstantOp>(
loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
if (!foldedIntraVectorOffset) {
mask = dynamicallyInsertSubVector(
rewriter, loc, dyn_cast<TypedValue<VectorType>>(mask), emptyMask,
linearizedInfo.intraDataOffset, origElements);
} else if (isUnalignedEmulation) {
mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
*foldedIntraVectorOffset);
}

Value result =
rewriter.create<arith::SelectOp>(loc, mask, bitCast, passthru);

if (isUnalignedEmulation) {
if (!foldedIntraVectorOffset) {
result = dynamicallyExtractSubVector(
rewriter, loc, dyn_cast<TypedValue<VectorType>>(result),
op.getPassThru(), linearizedInfo.intraDataOffset, origElements);
} else if (isUnalignedEmulation) {
result =
staticallyExtractSubvector(rewriter, loc, op.getType(), result,
*foldedIntraVectorOffset, origElements);
@@ -662,10 +689,9 @@ struct ConvertVectorTransferRead final
? getConstantIntValue(linearizedInfo.intraDataOffset)
: 0;

auto maxIntraVectorOffset =
foldedIntraVectorOffset ? *foldedIntraVectorOffset : scale - 1;
int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
auto numElements =
llvm::divideCeil(maxIntraVectorOffset + origElements, scale);
llvm::divideCeil(maxIntraDataOffset + origElements, scale);

auto newRead = rewriter.create<vector::TransferReadOp>(
loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
Expand All @@ -676,18 +702,16 @@ struct ConvertVectorTransferRead final
loc, VectorType::get(numElements * scale, oldElementType), newRead);

Value result = bitCast->getResult(0);
if (foldedIntraVectorOffset) {
if (isUnalignedEmulation) {
result =
staticallyExtractSubvector(rewriter, loc, op.getType(), result,
*foldedIntraVectorOffset, origElements);
}
} else {
if (!foldedIntraVectorOffset) {
auto zeros = rewriter.create<arith::ConstantOp>(
loc, op.getType(), rewriter.getZeroAttr(op.getType()));
result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
linearizedInfo.intraDataOffset,
origElements);
} else if (isUnalignedEmulation) {
result =
staticallyExtractSubvector(rewriter, loc, op.getType(), result,
*foldedIntraVectorOffset, origElements);
}
rewriter.replaceOp(op, result);

66 changes: 66 additions & 0 deletions mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -183,3 +183,69 @@ func.func @vector_transfer_read_i2_dynamic_indexing_mixed(%idx1: index) -> vecto
// CHECK: %[[C2:.+]] = arith.constant 2 : index
// CHECK: %[[ADDI2:.+]] = arith.addi %[[LOADADDR2]], %[[C2]] : index
// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[ADDI2]]] : i2 from vector<8xi2>
// -----

func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>, %idx: index) -> vector<3xi2> {
%0 = memref.alloc() : memref<3x3xi2>
%cst = arith.constant dense<0> : vector<3x3xi2>
%c2 = arith.constant 2 : index
%mask = vector.constant_mask [3] : vector<3xi1>
%1 = vector.maskedload %0[%idx, %c2], %mask, %passthru :
memref<3x3xi2>, vector<3xi1>, vector<3xi2> into vector<3xi2>
return %1 : vector<3xi2>
}

// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> ((s0 * 3 + 2) floordiv 4)>
// CHECK: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 3 - ((s0 * 3 + 2) floordiv 4) * 4 + 2)>
// CHECK: func @vector_maskedload_i2_dynamic_indexing_mixed(
// CHECK-SAME: %[[PTH:.+]]: vector<3xi2>, %[[IDX:.+]]: index) -> vector<3xi2>
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
// CHECK: %[[MASK:.+]] = vector.constant_mask [3] : vector<3xi1>
// CHECK: %[[LINEAR1:.+]] = affine.apply #map()[%[[IDX]]]
// CHECK: %[[LINEAR2:.+]] = affine.apply #map1()[%[[IDX]]]
// CHECK: %[[ONE:.+]] = arith.constant dense<true> : vector<2xi1>
// CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<8xi2>

// Extract elements from the passthru vector and insert them into the zero vector; this constructs a new passthru
// CHECK: %[[EX1:.+]] = vector.extract %[[PTH]][0] : i2 from vector<3xi2>
// CHECK: %[[IN1:.+]] = vector.insert %[[EX1]], %[[ZERO]] [%[[LINEAR2]]] : i2 into vector<8xi2>
// CHECK: %[[C1:.+]] = arith.constant 1 : index
// CHECK: %[[INCIDX:.+]] = arith.addi %[[LINEAR2]], %[[C1]] : index
// CHECK: %[[EX2:.+]] = vector.extract %[[PTH]][1] : i2 from vector<3xi2>
// CHECK: %[[IN2:.+]] = vector.insert %[[EX2]], %[[IN1]] [%[[INCIDX]]] : i2 into vector<8xi2>
// CHECK: %[[C2:.+]] = arith.constant 2 : index
// CHECK: %[[INCIDX2:.+]] = arith.addi %[[LINEAR2]], %[[C2]] : index
// CHECK: %[[EX3:.+]] = vector.extract %[[PTH]][2] : i2 from vector<3xi2>
// CHECK: %[[NEW_PASSTHRU:.+]] = vector.insert %[[EX3]], %[[IN2]] [%[[INCIDX2]]] : i2 into vector<8xi2>

// Bitcast the new passthru vector to emulated i8 vector
// CHECK: %[[BCAST_PASSTHRU:.+]] = vector.bitcast %[[NEW_PASSTHRU]] : vector<8xi2> to vector<2xi8>

// Use the emulated i8 vector for masked load from the source memory
// CHECK: %[[SOURCE:.+]] = vector.maskedload %[[ALLOC]][%[[LINEAR1]]], %[[ONE]], %[[BCAST_PASSTHRU]]
// CHECK-SAME: memref<3xi8>, vector<2xi1>, vector<2xi8> into vector<2xi8>

// Bitcast back to i2 vector
// CHECK: %[[BCAST_MASKLOAD:.+]] = vector.bitcast %[[SOURCE]] : vector<2xi8> to vector<8xi2>

// CHECK: %[[CST1:.+]] = arith.constant dense<false> : vector<8xi1>

// Create a mask vector
// Note that if the indices are known, we can fold the mask-generating part.
// CHECK: %[[EX4:.+]] = vector.extract %[[MASK]][0] : i1 from vector<3xi1>
// CHECK: %[[IN4:.+]] = vector.insert %[[EX4]], %[[CST1]] [%[[LINEAR2]]] : i1 into vector<8xi1>
// CHECK: %[[EX5:.+]] = vector.extract %[[MASK]][1] : i1 from vector<3xi1>
// CHECK: %[[IN5:.+]] = vector.insert %[[EX5]], %[[IN4]] [%[[INCIDX]]] : i1 into vector<8xi1>
// CHECK: %[[EX6:.+]] = vector.extract %[[MASK]][2] : i1 from vector<3xi1>
// CHECK: %[[NEW_MASK:.+]] = vector.insert %[[EX6]], %[[IN5]] [%[[INCIDX2]]] : i1 into vector<8xi1>

// Select the effective part from the source and passthru vectors
// CHECK: %[[SELECT:.+]] = arith.select %[[NEW_MASK]], %[[BCAST_MASKLOAD]], %[[NEW_PASSTHRU]] : vector<8xi1>, vector<8xi2>

// Finally, insert the selected parts into actual passthru vector.
// CHECK: %[[EX7:.+]] = vector.extract %[[SELECT]][%[[LINEAR2]]] : i2 from vector<8xi2>
// CHECK: %[[IN7:.+]] = vector.insert %[[EX7]], %[[PTH]] [0] : i2 into vector<3xi2>
// CHECK: %[[EX8:.+]] = vector.extract %[[SELECT]][%[[INCIDX]]] : i2 from vector<8xi2>
// CHECK: %[[IN8:.+]] = vector.insert %[[EX8]], %[[IN7]] [1] : i2 into vector<3xi2>
// CHECK: %[[EX9:.+]] = vector.extract %[[SELECT]][%[[INCIDX2]]] : i2 from vector<8xi2>
// CHECK: %[[IN9:.+]] = vector.insert %[[EX9]], %[[IN8]] [2] : i2 into vector<3xi2>