[MLIR] Fix VectorEmulateNarrowType constant op mask bug #116064

Merged (1 commit, Nov 15, 2024)
169 changes: 109 additions & 60 deletions mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -75,83 +75,134 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
int numSrcElemsPerDest,
int numFrontPadElems = 0) {

assert(numFrontPadElems < numSrcElemsPerDest && "intraDataOffset must be less than scale");
assert(numFrontPadElems < numSrcElemsPerDest &&
"numFrontPadElems must be less than numSrcElemsPerDest");

auto numElements = (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
numSrcElemsPerDest;
auto numDestElems =
(numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
numSrcElemsPerDest;

Operation *maskOp = mask.getDefiningOp();
SmallVector<vector::ExtractOp, 2> extractOps;
// TODO: add support to `vector.splat`.
// Finding the mask creation operation.
while (maskOp && !isa<vector::CreateMaskOp, vector::ConstantMaskOp>(maskOp)) {
while (maskOp &&
!isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
maskOp)) {
if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
maskOp = extractOp.getVector().getDefiningOp();
extractOps.push_back(extractOp);
}
}
auto createMaskOp = dyn_cast_or_null<vector::CreateMaskOp>(maskOp);
auto constantMaskOp = dyn_cast_or_null<vector::ConstantMaskOp>(maskOp);
if (!createMaskOp && !constantMaskOp)

if (!isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
maskOp))
return failure();

// Computing the "compressed" mask. All the emulation logic (i.e. computing
// new mask index) only happens on the last dimension of the vectors.
Operation *newMask = nullptr;
SmallVector<int64_t> shape(
SmallVector<int64_t> maskShape(
cast<VectorType>(maskOp->getResultTypes()[0]).getShape());
shape.back() = numElements;
auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
if (createMaskOp) {
OperandRange maskOperands = createMaskOp.getOperands();
size_t numMaskOperands = maskOperands.size();
AffineExpr s0;
bindSymbols(rewriter.getContext(), s0);
s0 = s0 + numSrcElemsPerDest - 1;
s0 = s0.floorDiv(numSrcElemsPerDest);
OpFoldResult origIndex =
getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
OpFoldResult maskIndex =
affine::makeComposedFoldedAffineApply(rewriter, loc, s0, origIndex);
SmallVector<Value> newMaskOperands(maskOperands.drop_back());
newMaskOperands.push_back(
getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
newMask = rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
newMaskOperands);
} else if (constantMaskOp) {
ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
size_t numMaskOperands = maskDimSizes.size();
int64_t origIndex = maskDimSizes[numMaskOperands - 1];
int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
int64_t maskIndex =
llvm::divideCeil(numFrontPadElems + origIndex, numSrcElemsPerDest);

// TODO: we only want the mask between [startIndex, maskIndex] to be true,
// the rest are false.
if (numFrontPadElems != 0 && maskDimSizes.size() > 1)
return failure();

SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
newMaskDimSizes.push_back(maskIndex);

if (numFrontPadElems == 0) {
newMask = rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
newMaskDimSizes);
} else {
SmallVector<bool> newMaskValues;
for (int64_t i = 0; i < numElements; ++i)
newMaskValues.push_back(i >= startIndex && i < maskIndex);
auto denseAttr = DenseElementsAttr::get(newMaskType, newMaskValues);
newMask = rewriter.create<arith::ConstantOp>(loc, newMaskType, denseAttr);
}
}
maskShape.back() = numDestElems;
auto newMaskType = VectorType::get(maskShape, rewriter.getI1Type());
std::optional<Operation *> newMask =
TypeSwitch<Operation *, std::optional<Operation *>>(maskOp)
.Case<vector::CreateMaskOp>(
[&](auto createMaskOp) -> std::optional<Operation *> {
OperandRange maskOperands = createMaskOp.getOperands();
size_t numMaskOperands = maskOperands.size();
AffineExpr s0;
bindSymbols(rewriter.getContext(), s0);
s0 = s0 + numSrcElemsPerDest - 1;
s0 = s0.floorDiv(numSrcElemsPerDest);
OpFoldResult origIndex =
getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
OpFoldResult maskIndex = affine::makeComposedFoldedAffineApply(
rewriter, loc, s0, origIndex);
SmallVector<Value> newMaskOperands(maskOperands.drop_back());
newMaskOperands.push_back(
getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
return rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
newMaskOperands);
})
.Case<vector::ConstantMaskOp>([&](auto constantMaskOp)
-> std::optional<Operation *> {
ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
size_t numMaskOperands = maskDimSizes.size();
int64_t origIndex = maskDimSizes[numMaskOperands - 1];
int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
int64_t maskIndex = llvm::divideCeil(numFrontPadElems + origIndex,
numSrcElemsPerDest);

// TODO: we only want the mask between [startIndex, maskIndex]
// to be true, the rest are false.
if (numFrontPadElems != 0 && maskDimSizes.size() > 1)
return std::nullopt;

SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
newMaskDimSizes.push_back(maskIndex);

if (numFrontPadElems == 0)
return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
newMaskDimSizes);

SmallVector<bool> newMaskValues;
for (int64_t i = 0; i < numDestElems; ++i)
newMaskValues.push_back(i >= startIndex && i < maskIndex);
auto newMask = DenseElementsAttr::get(newMaskType, newMaskValues);
return rewriter.create<arith::ConstantOp>(loc, newMaskType,
newMask);
})
.Case<arith::ConstantOp>([&](auto constantOp)
-> std::optional<Operation *> {
// TODO: Support multiple dimensions.
if (maskShape.size() != 1)
return std::nullopt;
// Rearrange the original mask values to cover the whole potential
// loading region. For example, in the case of using byte-size for
// emulation, given the following mask:
//
// %mask = [0, 1, 0, 1, 0, 0]
//
// With front offset of 1, the mask will be padded 0s in the front
// and back so that:
// 1. It is aligned with the effective loading bits
// 2. Its length is a multiple of `numSrcElemsPerDest` (and the total
// coverage size is a multiple of bytes). The new mask will be like
// this before compressing:
//
// %new_mask = [0, 0, 1, 0, 1, 0, 0, 0]
auto originalMask =
cast<DenseIntElementsAttr>(constantOp.getValue());
SmallVector<bool> paddedMaskValues(numFrontPadElems, false);
paddedMaskValues.append(originalMask.template value_begin<bool>(),
originalMask.template value_end<bool>());
paddedMaskValues.resize(numDestElems * numSrcElemsPerDest, false);

// Compressing by combining every `numSrcElemsPerDest` elements:
SmallVector<bool> compressedMaskValues;
for (size_t i = 0; i < paddedMaskValues.size();
i += numSrcElemsPerDest) {
bool combinedValue = false;
for (int j = 0; j < numSrcElemsPerDest; ++j) {
combinedValue |= paddedMaskValues[i + j];
}
compressedMaskValues.push_back(combinedValue);
}
return rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(newMaskType, compressedMaskValues));
});

if (!newMask)
return failure();

while (!extractOps.empty()) {
newMask = rewriter.create<vector::ExtractOp>(
loc, newMask->getResults()[0], extractOps.back().getMixedPosition());
loc, (*newMask)->getResults()[0], extractOps.back().getMixedPosition());
extractOps.pop_back();
}

return newMask;
return *newMask;
}
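
For reference, the pad-and-compress logic that the new `arith::ConstantOp` case performs can be exercised in isolation. The sketch below is a minimal, editor-added C++ illustration of the same arithmetic using plain containers; `compressMask` and its parameters are names chosen here for clarity and are not part of the patch or of any MLIR API.

```cpp
#include <cassert>
#include <cstdio>
#include <vector>

// Editorial sketch (not MLIR code): reproduce the pad-and-compress step that
// the arith::ConstantOp case performs on a constant i1 mask.
static std::vector<bool> compressMask(const std::vector<bool> &mask,
                                      int numSrcElemsPerDest,
                                      int numFrontPadElems) {
  assert(numFrontPadElems < numSrcElemsPerDest &&
         "numFrontPadElems must be less than numSrcElemsPerDest");

  // Pad with `false` in front (the unaligned offset) and at the back so the
  // total length becomes a multiple of numSrcElemsPerDest.
  std::vector<bool> padded(numFrontPadElems, false);
  padded.insert(padded.end(), mask.begin(), mask.end());
  size_t numDestElems =
      (padded.size() + numSrcElemsPerDest - 1) / numSrcElemsPerDest;
  padded.resize(numDestElems * numSrcElemsPerDest, false);

  // A destination (wide) element is active iff any source (narrow) element
  // it covers is active.
  std::vector<bool> compressed;
  for (size_t i = 0; i < padded.size(); i += numSrcElemsPerDest) {
    bool any = false;
    for (int j = 0; j < numSrcElemsPerDest; ++j)
      any = any || padded[i + j];
    compressed.push_back(any);
  }
  return compressed;
}

int main() {
  // Mirrors the example in the code comment above: i2 elements emulated in i8
  // (4 narrow elements per byte) with a front offset of 1.
  // [0,1,0,1,0,0] -> padded [0,0,1,0,1,0,0,0] -> compressed [1,1].
  std::vector<bool> mask = {false, true, false, true, false, false};
  for (bool b : compressMask(mask, /*numSrcElemsPerDest=*/4,
                             /*numFrontPadElems=*/1))
    std::printf("%d ", b ? 1 : 0);
  std::printf("\n");
  return 0;
}
```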

/// Extracts 1-D subvector from a 1-D vector. It is a wrapper function for
@@ -185,12 +236,10 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
/// `vector.insert_strided_slice`.
static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
Value src, Value dest, int64_t offset) {
auto srcType = cast<VectorType>(src.getType());
auto destType = cast<VectorType>(dest.getType());
[[maybe_unused]] auto srcType = cast<VectorType>(src.getType());
[[maybe_unused]] auto destType = cast<VectorType>(dest.getType());
assert(srcType.getRank() == 1 && destType.getRank() == 1 &&
"expected source and dest to be vector type");
(void)srcType;
(void)destType;
auto offsets = rewriter.getI64ArrayAttr({offset});
auto strides = rewriter.getI64ArrayAttr({1});
return rewriter.create<vector::InsertStridedSliceOp>(loc, dest.getType(), src,
38 changes: 38 additions & 0 deletions mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -249,3 +249,41 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>,
// CHECK: %[[IN8:.+]] = vector.insert %[[EX8]], %[[IN7]] [1] : i2 into vector<3xi2>
// CHECK: %[[EX9:.+]] = vector.extract %[[SELECT]][%[[INCIDX2]]] : i2 from vector<8xi2>
// CHECK: %[[IN9:.+]] = vector.insert %[[EX9]], %[[IN8]] [2] : i2 into vector<3xi2>

// -----

func.func @vector_maskedload_i4_constant_mask_unaligned(%passthru: vector<5xi2>) -> vector<5xi2> {
Review comment (Contributor): I missed it in the previous review. Could you remind me why it is labeled i4 but the test is loading i2 types?

Reply (Member, Author): @hanhanW copy/paste mismatch, I should have double-checked it. I will update it later in a batch refactoring.

%0 = memref.alloc() : memref<3x5xi2>
%mask = arith.constant dense<[false, true, true, true, false]> : vector<5xi1>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%1 = vector.maskedload %0[%c1, %c0], %mask, %passthru :
memref<3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
return %1 : vector<5xi2>
}

// CHECK: func @vector_maskedload_i4_constant_mask_unaligned(
// CHECK-SAME: %[[PTH:.+]]: vector<5xi2>) -> vector<5xi2>
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<4xi8>
// CHECK: %[[MASK:.+]] = arith.constant dense<[false, true, true, true, false]> : vector<5xi1>

// CHECK: %[[COMPRESSED_MASK:.+]] = arith.constant dense<true> : vector<2xi1>
// CHECK: %[[EMPTY:.+]] = arith.constant dense<0> : vector<8xi2>
// CHECK: %[[PTH_PADDED:.+]] = vector.insert_strided_slice %[[PTH]], %[[EMPTY]]
// CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi2> into vector<8xi2>

// Emulated masked load from alloc:
// CHECK: %[[PTH_PADDED_UPCAST:.+]] = vector.bitcast %[[PTH_PADDED]] : vector<8xi2> to vector<2xi8>
// CHECK: %[[C1:.+]] = arith.constant 1 : index
// CHECK: %[[MASKLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[C1]]], %[[COMPRESSED_MASK]], %[[PTH_PADDED_UPCAST]]
// CHECK: %[[MASKLOAD_DOWNCAST:.+]] = vector.bitcast %[[MASKLOAD]] : vector<2xi8> to vector<8xi2>

// Select from emulated loaded vector and passthru vector:
// TODO: fold this part if possible.
// CHECK: %[[EMPTY_MASK:.+]] = arith.constant dense<false> : vector<8xi1>
// CHECK: %[[MASK_PADDED:.+]] = vector.insert_strided_slice %[[MASK]], %[[EMPTY_MASK]]
// CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi1> into vector<8xi1>
// CHECK: %[[SELECT:.+]] = arith.select %[[MASK_PADDED]], %[[MASKLOAD_DOWNCAST]], %[[PTH_PADDED]] : vector<8xi1>, vector<8xi2>
// CHECK: %[[RESULT:.+]] = vector.extract_strided_slice %[[SELECT]]
// CHECK-SAME: {offsets = [1], sizes = [5], strides = [1]} : vector<8xi2> to vector<5xi2>
// CHECK: return %[[RESULT]] : vector<5xi2>
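
The compressed mask and the byte index in the checks above follow from a short calculation. The editor-added C++ sketch below spells it out; the constants mirror the test, but the code itself is illustrative only and not part of the patch.

```cpp
#include <cassert>

int main() {
  // i2 elements packed into i8 bytes: 4 narrow elements per byte, so the
  // 3x5xi2 alloc (15 elements) is emulated as memref<4xi8>.
  const int numSrcElemsPerDest = 8 / 2;

  // The load starts at %0[1, 0]: linearized narrow-element index 1 * 5 + 0.
  const int linearIndex = 1 * 5 + 0;
  const int byteIndex = linearIndex / numSrcElemsPerDest;        // 1 -> %[[C1]]
  const int numFrontPadElems = linearIndex % numSrcElemsPerDest; // 1 pad elem
  assert(byteIndex == 1 && numFrontPadElems == 1);

  // Original mask [0,1,1,1,0], padded by one false in front and to a multiple
  // of 4 at the back: [0,0,1,1,1,0,0,0]. Every group of 4 has an active bit,
  // so the compressed mask is [true, true], i.e. dense<true> : vector<2xi1>.
  const bool padded[8] = {false, false, true, true, true, false, false, false};
  bool compressed[2] = {false, false};
  for (int i = 0; i < 8; ++i)
    compressed[i / 4] = compressed[i / 4] || padded[i];
  assert(compressed[0] && compressed[1]);
  return 0;
}
```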
51 changes: 51 additions & 0 deletions mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -275,6 +275,30 @@ func.func @vector_maskedload_i4_constant_mask(%arg1: index, %arg2: index, %passt

// -----

func.func @vector_maskedload_i4_arith_constant(%passthru: vector<8xi4>) -> vector<8xi4> {
%0 = memref.alloc() : memref<3x8xi4>
%cst = arith.constant dense<0> : vector<8xi4>
%mask = arith.constant dense<[false, true, true, true, true, false, false, false]> : vector<8xi1>
%c0 = arith.constant 0 : index
%1 = vector.maskedload %0[%c0, %c0], %mask, %passthru :
memref<3x8xi4>, vector<8xi1>, vector<8xi4> into vector<8xi4>
return %1 : vector<8xi4>
}

// CHECK: func @vector_maskedload_i4_arith_constant(
// CHECK-SAME: %[[PASSTHRU:[a-zA-Z0-9]+]]
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
// CHECK: %[[MASK:.+]] = arith.constant dense<[false, true, true, true, true, false, false, false]> : vector<8xi1>

// Emit a new, compressed mask for emulated maskedload:
// CHECK: %[[COMPRESSED_MASK:.+]] = arith.constant dense<[true, true, true, false]> : vector<4xi1>
// CHECK: %[[PTHU_UPCAST:.+]] = vector.bitcast %[[PASSTHRU]] : vector<8xi4> to vector<4xi8>
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[C0]]], %[[COMPRESSED_MASK]], %[[PTHU_UPCAST]]
// CHECK: %[[LOAD_DOWNCAST:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4>
// CHECK: %[[SELECT:.+]] = arith.select %[[MASK]], %[[LOAD_DOWNCAST]], %[[PASSTHRU]] : vector<8xi1>, vector<8xi4>
// CHECK: return %[[SELECT]] : vector<8xi4>
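
The `dense<[true, true, true, false]>` compressed mask above is the pairwise OR of the original 8-element mask: two i4 elements share one i8 byte, and the access is aligned, so there is no front padding. A tiny editor-added C++ check of that reduction, for illustration only:

```cpp
#include <cassert>

int main() {
  // Two i4 elements per i8 byte, aligned load (no front padding).
  const bool mask[8] = {false, true, true, true, true, false, false, false};
  bool compressed[4] = {};
  for (int i = 0; i < 8; ++i)
    compressed[i / 2] = compressed[i / 2] || mask[i];
  // Matches %[[COMPRESSED_MASK]] = dense<[true, true, true, false]>.
  assert(compressed[0] && compressed[1] && compressed[2] && !compressed[3]);
  return 0;
}
```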

///----------------------------------------------------------------------------------------
/// vector.extract -> vector.masked_load
///----------------------------------------------------------------------------------------
@@ -624,3 +648,30 @@ func.func @vector_maskedstore_i4_constant_mask(
// CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4>
// CHECK32: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<1xi32>
// CHECK32: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<3xi32>, vector<1xi1>, vector<1xi32>

// -----

func.func @vector_maskedstore_i4_arith_constant(%val_to_store: vector<8xi4>) {
%0 = memref.alloc() : memref<5x8xi4>
%cst = arith.constant dense<0> : vector<8xi4>
%mask = arith.constant dense<[false, true, true, true, true, true, false, false]> : vector<8xi1>
%c0 = arith.constant 0 : index
%c3 = arith.constant 3 : index
vector.maskedstore %0[%c3, %c0], %mask, %val_to_store :
memref<5x8xi4>, vector<8xi1>, vector<8xi4>
return
}

// CHECK-LABEL: func @vector_maskedstore_i4_arith_constant
// CHECK-SAME: %[[VAL_TO_STORE:[a-zA-Z0-9]+]]:
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<20xi8>
// CHECK: %[[MASK:.+]] = arith.constant dense<[false, true, true, true, true, true, false, false]> : vector<8xi1>
// Flattened byte index: %c3 * (8 i4 elements per row / 2 elements per byte) = 3 * 4 = 12
// CHECK: %[[IDX_FLATTENED:.+]] = arith.constant 12 : index
// CHECK: %[[COMPRESSED_MASK:.+]] = arith.constant dense<[true, true, true, false]> : vector<4xi1>
// CHECK: %[[EMPTY:.+]] = arith.constant dense<0> : vector<4xi8>
// CHECK: %[[MASKEDLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[IDX_FLATTENED]]], %[[COMPRESSED_MASK]], %[[EMPTY]]
// CHECK: %[[LOAD_UPCAST:.+]] = vector.bitcast %[[MASKEDLOAD]]
// CHECK: %[[SELECT:.+]] = arith.select %[[MASK]], %[[VAL_TO_STORE]], %[[LOAD_UPCAST]]
// CHECK: %[[SELECT_DOWNCAST:.+]] = vector.bitcast %[[SELECT]]
// CHECK: vector.maskedstore %[[ALLOC]][%[[IDX_FLATTENED]]], %[[COMPRESSED_MASK]], %[[SELECT_DOWNCAST]]
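
The store case reuses the same compressed mask but needs a read-modify-write: the covered bytes are loaded, masked-off i4 lanes keep their loaded values via `arith.select`, and the merged bytes are stored back. The editor-added sketch below illustrates the flattened byte index and the nibble-level select; the packing order of the two i4 values within a byte is assumed here for illustration and is not taken from the patch.

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // Flattened start: element (3, 0) of memref<5x8xi4> is narrow element
  // 3 * 8 + 0 = 24; two i4 elements per byte gives byte offset 24 / 2 = 12.
  const int byteOffset = (3 * 8 + 0) / 2;
  assert(byteOffset == 12);

  // Read-modify-write at the i4 level: bytes covered by the store are first
  // loaded, masked-off nibbles keep their loaded value (arith.select), and
  // the merged bytes are written back under the compressed per-byte mask.
  const bool mask[8] = {false, true, true, true, true, true, false, false};
  uint8_t loadedBytes[4] = {0xAB, 0xCD, 0xEF, 0x01};       // emulated load
  const uint8_t valToStore[8] = {1, 2, 3, 4, 5, 6, 7, 8};  // i4 values

  for (int i = 0; i < 8; ++i) {
    if (!mask[i])
      continue;                    // keep the loaded nibble for inactive lanes
    const int shift = (i % 2) * 4; // low-nibble-first packing (assumed)
    loadedBytes[i / 2] = static_cast<uint8_t>(
        (loadedBytes[i / 2] & ~(0xF << shift)) |
        ((valToStore[i] & 0xF) << shift));
  }
  assert((loadedBytes[0] & 0xF) == 0xB);      // lane 0 kept from the load
  assert(((loadedBytes[0] >> 4) & 0xF) == 2); // lane 1 overwritten
  return 0;
}
```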