
Commit af7290e

[MLIR] Fix VectorEmulateNarrowType constant op mask bug
This commit adds support for handling mask constants generated by the
`arith.constant` op in the `VectorEmulateNarrowType` pattern. Previously, the
pattern would fail to match because `getCompressedMaskOp` had no handling for
constant masks. The changes include:

1. Updating `getCompressedMaskOp` to recognize and handle `arith.constant` ops
   as mask value sources.
2. Handling cases where the mask is not aligned with the emulated load width:
   the compressed mask is adjusted to account for the front offset.

Limitations:
- The `arith.constant` op can only have a 1-dimensional constant value.

Resolves: #115742

Signed-off-by: Alan Li <[email protected]>
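As a worked illustration of the new behavior (my own trace of the algorithm,
not text from the commit): with 4 i2 elements emulated per byte, a constant
mask dense<[false, true, false, true, false, false]> loaded with a front
offset of 1 is first padded to
[false, false, true, false, true, false, false, false] so that it is
byte-aligned and a multiple of 4 long, and then each group of 4 bits is
OR-combined into one per-byte bit, producing the compressed mask
dense<[true, true]>.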

3 files changed: +170 -54 lines

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 107 additions & 54 deletions
@@ -30,9 +30,11 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
 #include <cstdint>
@@ -75,83 +77,134 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                   int numSrcElemsPerDest,
                                                   int numFrontPadElems = 0) {
 
-  assert(numFrontPadElems < numSrcElemsPerDest && "intraDataOffset must be less than scale");
+  assert(numFrontPadElems < numSrcElemsPerDest &&
+         "numFrontPadElems must be less than numSrcElemsPerDest");
 
   auto numElements = (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
                      numSrcElemsPerDest;
 
   Operation *maskOp = mask.getDefiningOp();
   SmallVector<vector::ExtractOp, 2> extractOps;
   // Finding the mask creation operation.
-  while (maskOp && !isa<vector::CreateMaskOp, vector::ConstantMaskOp>(maskOp)) {
+  while (maskOp &&
+         !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
+             maskOp)) {
     if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
       maskOp = extractOp.getVector().getDefiningOp();
       extractOps.push_back(extractOp);
     }
   }
-  auto createMaskOp = dyn_cast_or_null<vector::CreateMaskOp>(maskOp);
-  auto constantMaskOp = dyn_cast_or_null<vector::ConstantMaskOp>(maskOp);
-  if (!createMaskOp && !constantMaskOp)
+
+  if (!isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
+          maskOp))
     return failure();
 
   // Computing the "compressed" mask. All the emulation logic (i.e. computing
   // new mask index) only happens on the last dimension of the vectors.
-  Operation *newMask = nullptr;
-  SmallVector<int64_t> shape(
+  SmallVector<int64_t> maskShape(
       cast<VectorType>(maskOp->getResultTypes()[0]).getShape());
-  shape.back() = numElements;
-  auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
-  if (createMaskOp) {
-    OperandRange maskOperands = createMaskOp.getOperands();
-    size_t numMaskOperands = maskOperands.size();
-    AffineExpr s0;
-    bindSymbols(rewriter.getContext(), s0);
-    s0 = s0 + numSrcElemsPerDest - 1;
-    s0 = s0.floorDiv(numSrcElemsPerDest);
-    OpFoldResult origIndex =
-        getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
-    OpFoldResult maskIndex =
-        affine::makeComposedFoldedAffineApply(rewriter, loc, s0, origIndex);
-    SmallVector<Value> newMaskOperands(maskOperands.drop_back());
-    newMaskOperands.push_back(
-        getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
-    newMask = rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
-                                                    newMaskOperands);
-  } else if (constantMaskOp) {
-    ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
-    size_t numMaskOperands = maskDimSizes.size();
-    int64_t origIndex = maskDimSizes[numMaskOperands - 1];
-    int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
-    int64_t maskIndex =
-        llvm::divideCeil(numFrontPadElems + origIndex, numSrcElemsPerDest);
-
-    // TODO: we only want the mask between [startIndex, maskIndex] to be true,
-    // the rest are false.
-    if (numFrontPadElems != 0 && maskDimSizes.size() > 1)
-      return failure();
-
-    SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
-    newMaskDimSizes.push_back(maskIndex);
-
-    if (numFrontPadElems == 0) {
-      newMask = rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
-                                                        newMaskDimSizes);
-    } else {
-      SmallVector<bool> newMaskValues;
-      for (int64_t i = 0; i < numElements; ++i)
-        newMaskValues.push_back(i >= startIndex && i < maskIndex);
-      auto denseAttr = DenseElementsAttr::get(newMaskType, newMaskValues);
-      newMask = rewriter.create<arith::ConstantOp>(loc, newMaskType, denseAttr);
-    }
-  }
+  maskShape.back() = numElements;
+  auto newMaskType = VectorType::get(maskShape, rewriter.getI1Type());
+  std::optional<Operation *> newMask =
+      TypeSwitch<Operation *, std::optional<Operation *>>(maskOp)
+          .Case<vector::CreateMaskOp>(
+              [&](auto createMaskOp) -> std::optional<Operation *> {
+                OperandRange maskOperands = createMaskOp.getOperands();
+                size_t numMaskOperands = maskOperands.size();
+                AffineExpr s0;
+                bindSymbols(rewriter.getContext(), s0);
+                s0 = s0 + numSrcElemsPerDest - 1;
+                s0 = s0.floorDiv(numSrcElemsPerDest);
+                OpFoldResult origIndex =
+                    getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
+                OpFoldResult maskIndex = affine::makeComposedFoldedAffineApply(
+                    rewriter, loc, s0, origIndex);
+                SmallVector<Value> newMaskOperands(maskOperands.drop_back());
+                newMaskOperands.push_back(
+                    getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
+                return rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
+                                                             newMaskOperands);
+              })
+          .Case<vector::ConstantMaskOp>([&](auto constantMaskOp)
+                                            -> std::optional<Operation *> {
+            ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
+            size_t numMaskOperands = maskDimSizes.size();
+            int64_t origIndex = maskDimSizes[numMaskOperands - 1];
+            int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
+            int64_t maskIndex = llvm::divideCeil(numFrontPadElems + origIndex,
+                                                 numSrcElemsPerDest);
+
+            // TODO: we only want the mask between [startIndex, maskIndex]
+            // to be true, the rest are false.
+            if (numFrontPadElems != 0 && maskDimSizes.size() > 1)
+              return std::nullopt;
+
+            SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
+            newMaskDimSizes.push_back(maskIndex);
+
+            if (numFrontPadElems == 0)
+              return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
+                                                             newMaskDimSizes);
+
+            SmallVector<bool> newMaskValues;
+            for (int64_t i = 0; i < numElements; ++i)
+              newMaskValues.push_back(i >= startIndex && i < maskIndex);
+            auto denseAttr = DenseElementsAttr::get(newMaskType, newMaskValues);
+            return rewriter.create<arith::ConstantOp>(loc, newMaskType,
+                                                      denseAttr);
+          })
+          .Case<arith::ConstantOp>([&](auto constantOp)
+                                       -> std::optional<Operation *> {
+            // TODO: Support multiple dimensions.
+            if (maskShape.size() != 1)
+              return std::nullopt;
+            // Rearrange the original mask values to cover the whole potential
+            // loading region. For example, in the case of using byte-size for
+            // emulation, given the following mask:
+            //
+            //   %mask = [false, true, false, true, false, false]
+            //
+            // With a front offset of 1, the mask will be padded with 0s in
+            // the front and back so that:
+            // 1. It is aligned with the effective loading bits.
+            // 2. Its length is a multiple of `numSrcElemsPerDest` (and the
+            //    total coverage size is a multiple of bytes). The new mask
+            //    will look like this before compressing:
+            //
+            //   %new_mask = [false, false, true, false, true, false, false,
+            //                false]
+            auto denseAttr =
+                dyn_cast<DenseIntElementsAttr>(constantOp.getValue());
+            if (!denseAttr)
+              return std::nullopt;
+            SmallVector<bool> maskValues(numFrontPadElems, false);
+            maskValues.append(denseAttr.template value_begin<bool>(),
+                              denseAttr.template value_end<bool>());
+            maskValues.resize(numElements * numSrcElemsPerDest, false);
+
+            // Compressing by combining every `numSrcElemsPerDest` elements:
+            SmallVector<bool> compressedMaskValues;
+            for (size_t i = 0; i < maskValues.size(); i += numSrcElemsPerDest) {
+              bool combinedValue = false;
+              for (int j = 0; j < numSrcElemsPerDest; ++j) {
+                combinedValue |= maskValues[i + j];
+              }
+              compressedMaskValues.push_back(combinedValue);
+            }
+            return rewriter.create<arith::ConstantOp>(
+                loc, DenseElementsAttr::get(newMaskType, compressedMaskValues));
+          });
+
+  if (!newMask)
+    return failure();
 
   while (!extractOps.empty()) {
     newMask = rewriter.create<vector::ExtractOp>(
-        loc, newMask->getResults()[0], extractOps.back().getMixedPosition());
+        loc, (*newMask)->getResults()[0], extractOps.back().getMixedPosition());
     extractOps.pop_back();
   }
 
-  return newMask;
+  return *newMask;
 }
 
 /// Extracts 1-D subvector from a 1-D vector. It is a wrapper function for
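
For readers who want to experiment with the compression step in isolation,
here is a minimal standalone C++ sketch of the pad-and-OR logic from the
arith::ConstantOp case above. It uses plain standard-library types instead of
MLIR's SmallVector/DenseElementsAttr, and the function name compressMask is
illustrative, not part of the MLIR API.

#include <cstddef>
#include <cstdio>
#include <vector>

// OR-combines every `scale` source-element mask bits into one
// emulated-element bit, after shifting the mask right by `frontPad`
// positions and padding the tail up to a multiple of `scale`.
std::vector<bool> compressMask(const std::vector<bool> &mask, int scale,
                               int frontPad) {
  std::vector<bool> padded(frontPad, false);
  padded.insert(padded.end(), mask.begin(), mask.end());
  // Round the length up to a whole number of emulated elements.
  padded.resize((padded.size() + scale - 1) / scale * scale, false);

  std::vector<bool> compressed;
  for (std::size_t i = 0; i < padded.size(); i += scale) {
    bool any = false;
    for (int j = 0; j < scale; ++j)
      any = any || padded[i + j];
    compressed.push_back(any);
  }
  return compressed;
}

int main() {
  // The example from the comment in the patch: front offset 1, and 4 i2
  // elements emulated per i8 byte (scale = 4).
  std::vector<bool> mask = {false, true, false, true, false, false};
  for (bool bit : compressMask(mask, /*scale=*/4, /*frontPad=*/1))
    std::printf("%d ", bit ? 1 : 0); // prints: 1 1
  std::printf("\n");
}

Running it prints "1 1", matching the example in the code comment: both
emulated bytes overlap the original mask, so both must be loaded.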

mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir

Lines changed: 38 additions & 0 deletions
@@ -249,3 +249,41 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>,
 // CHECK: %[[IN8:.+]] = vector.insert %[[EX8]], %[[IN7]] [1] : i2 into vector<3xi2>
 // CHECK: %[[EX9:.+]] = vector.extract %[[SELECT]][%[[INCIDX2]]] : i2 from vector<8xi2>
 // CHECK: %[[IN9:.+]] = vector.insert %[[EX9]], %[[IN8]] [2] : i2 into vector<3xi2>
+
+// -----
+
+func.func @vector_maskedload_i4_constant_mask_unaligned(%passthru: vector<5xi2>) -> vector<5xi2> {
+  %0 = memref.alloc() : memref<3x5xi2>
+  %mask = arith.constant dense<[false, true, true, true, false]> : vector<5xi1>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %1 = vector.maskedload %0[%c1, %c0], %mask, %passthru :
+    memref<3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
+  return %1 : vector<5xi2>
+}
+
+// CHECK: func @vector_maskedload_i4_constant_mask_unaligned(
+// CHECK-SAME: %[[PTH:.+]]: vector<5xi2>) -> vector<5xi2>
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<4xi8>
+// CHECK: %[[MASK:.+]] = arith.constant dense<[false, true, true, true, false]> : vector<5xi1>
+
+// CHECK: %[[CST0:.+]] = arith.constant dense<true> : vector<2xi1>
+// CHECK: %[[CST1:.+]] = arith.constant dense<0> : vector<8xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[PTH]], %[[CST1]]
+// CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi2> into vector<8xi2>
+
+// Emulated masked load from alloc:
+// CHECK: %[[BCAST:.+]] = vector.bitcast %[[INSERT]] : vector<8xi2> to vector<2xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[MASKLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[C1]]], %[[CST0]], %[[BCAST]]
+// CHECK: %[[BCAST2:.+]] = vector.bitcast %[[MASKLOAD]] : vector<2xi8> to vector<8xi2>
+
+// Select from emulated loaded vector and passthru vector:
+// TODO: fold this part if possible.
+// CHECK: %[[CST2:.+]] = arith.constant dense<false> : vector<8xi1>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[MASK]], %[[CST2]]
+// CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi1> into vector<8xi1>
+// CHECK: %[[SELECT:.+]] = arith.select %[[INSERT2]], %[[BCAST2]], %[[INSERT]] : vector<8xi1>, vector<8xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SELECT]]
+// CHECK-SAME: {offsets = [1], sizes = [5], strides = [1]} : vector<8xi2> to vector<5xi2>
+// CHECK: return %[[EXTRACT]] : vector<5xi2>
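
A note on the expected offsets (my own arithmetic, not part of the test): the
3x5xi2 memref holds 15 i2 values = 30 bits, which rounds up to the 4-byte
alloc. The load at [%c1, %c0] starts at linearized element 1 * 5 + 0 = 5;
with 4 i2 elements per byte, that is byte 1 with a front offset of 5 mod 4 =
1. This is why the passthru and mask are inserted at offset 1 into 8-element
vectors, and why the emulated load reads 2 bytes starting at index 1 with the
all-true vector<2xi1> compressed mask.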

mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir

Lines changed: 25 additions & 0 deletions
@@ -624,3 +624,28 @@ func.func @vector_maskedstore_i4_constant_mask(
 // CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4>
 // CHECK32: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<1xi32>
 // CHECK32: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<3xi32>, vector<1xi1>, vector<1xi32>
+
+// -----
+
+func.func @vector_maskedload_i4_arith_constant(%passthru: vector<8xi4>) -> vector<8xi4> {
+  %0 = memref.alloc() : memref<3x8xi4>
+  %cst = arith.constant dense<0> : vector<8xi4>
+  %mask = arith.constant dense<[false, true, true, true, true, false, false, false]> : vector<8xi1>
+  %c0 = arith.constant 0 : index
+  %1 = vector.maskedload %0[%c0, %c0], %mask, %passthru :
+    memref<3x8xi4>, vector<8xi1>, vector<8xi4> into vector<8xi4>
+  return %1 : vector<8xi4>
+}
+
+// CHECK: func @vector_maskedload_i4_arith_constant(
+// CHECK-SAME: %[[PASSTHRU:[a-zA-Z0-9]+]]
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
+// CHECK: %[[MASK:.+]] = arith.constant dense<[false, true, true, true, true, false, false, false]> : vector<8xi1>
+// CHECK: %[[COMP_MASK:.+]] = arith.constant dense<[true, true, true, false]> : vector<4xi1>
+// CHECK: %[[PTHU_UPCAST:.+]] = vector.bitcast %[[PASSTHRU]] : vector<8xi4> to vector<4xi8>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[C0]]], %[[COMP_MASK]], %[[PTHU_UPCAST]]
+// CHECK-SAME: memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
+// CHECK: %[[LOAD_DOWNCAST:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4>
+// CHECK: %[[SELECT:.+]] = arith.select %[[MASK]], %[[LOAD_DOWNCAST]], %[[PASSTHRU]] : vector<8xi1>, vector<8xi4>
+// CHECK: return %[[SELECT]] : vector<8xi4>
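
For reference (my arithmetic, not emitted by the test): with two i4 elements
per emulated byte, the 8-element mask
[false, true, true, true, true, false, false, false] is OR-combined pairwise,
(f|t), (t|t), (t|f), (f|f), giving the compressed vector<4xi1> mask
[true, true, true, false] checked above. The load is at [%c0, %c0], so there
is no front offset and no extra padding is needed.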
