
Commit 7f82c90

[mlir][vector] Add support for vector.maskedstore sub-type emulation. (#73871)
The idea is similar to the vector.maskedload + vector.store emulation. What the emulation does is (see the IR sketch below):

1. Get a compressed mask and use it to load the current data from the destination.
2. Bitcast the loaded data to the original vector type.
3. Select between `op.valueToStore` and the loaded data using the original mask.
4. Bitcast the selected value and store it to the destination using the compressed mask.
1 parent a112921 commit 7f82c90
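For concreteness, here is a rough sketch of the i8-on-i32 case exercised by the new @vector_maskedstore_i8 test below. The function names, value names, and exact formatting are illustrative, not verbatim pass output; `%mem_i32` stands for the memref produced by the type converter.

  // Before: a masked store of 8 x i8 into a row of a 3x8 i8 memref.
  func.func @masked_store_i8(%mem: memref<3x8xi8>, %i: index, %j: index,
                             %n: index, %val: vector<8xi8>) {
    %mask = vector.create_mask %n : vector<8xi1>
    vector.maskedstore %mem[%i, %j], %mask, %val
        : memref<3x8xi8>, vector<8xi1>, vector<8xi8>
    return
  }

  // After emulation on an i32-backed memref (scale = 32 / 8 = 4), the body
  // becomes roughly:
  func.func @masked_store_i8_emulated(%mem_i32: memref<6xi32>, %i: index,
                                      %j: index, %n: index, %val: vector<8xi8>) {
    %mask = vector.create_mask %n : vector<8xi1>        // original mask, kept for the select
    %lidx = affine.apply affine_map<()[s0, s1] -> (s0 * 2 + s1 floordiv 4)>()[%i, %j]
    %midx = affine.apply affine_map<()[s0] -> ((s0 + 3) floordiv 4)>()[%n]
    %new_mask = vector.create_mask %midx : vector<2xi1> // step 1: compressed mask
    %zeros = arith.constant dense<0> : vector<2xi32>
    %load = vector.maskedload %mem_i32[%lidx], %new_mask, %zeros
        : memref<6xi32>, vector<2xi1>, vector<2xi32> into vector<2xi32>   // step 1: load destination data
    %wide = vector.bitcast %load : vector<2xi32> to vector<8xi8>          // step 2: back to the original element type
    %sel = arith.select %mask, %val, %wide : vector<8xi1>, vector<8xi8>   // step 3: merge using the original mask
    %packed = vector.bitcast %sel : vector<8xi8> to vector<2xi32>         // step 4: repack
    vector.maskedstore %mem_i32[%lidx], %new_mask, %packed
        : memref<6xi32>, vector<2xi1>, vector<2xi32>                      // step 4: store with the compressed mask
    return
  }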

2 files changed: 241 additions, 62 deletions


mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 169 additions & 62 deletions
@@ -32,6 +32,79 @@ using namespace mlir;
 #define DBGSNL() (llvm::dbgs() << "\n")
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 
+/// Returns a compressed mask. The mask value is set only if any mask is present
+/// in the scale range. E.g., if `scale` equals to 2, the following mask:
+///
+///   %mask = [1, 1, 1, 0, 0, 0]
+///
+/// will return the following new compressed mask:
+///
+///   %mask = [1, 1, 0]
+static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
+                                                  Location loc, Value mask,
+                                                  int origElements, int scale) {
+  auto numElements = (origElements + scale - 1) / scale;
+
+  Operation *maskOp = mask.getDefiningOp();
+  SmallVector<vector::ExtractOp, 2> extractOps;
+  // Finding the mask creation operation.
+  while (maskOp && !isa<vector::CreateMaskOp, vector::ConstantMaskOp>(maskOp)) {
+    if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
+      maskOp = extractOp.getVector().getDefiningOp();
+      extractOps.push_back(extractOp);
+    }
+  }
+  auto createMaskOp = dyn_cast_or_null<vector::CreateMaskOp>(maskOp);
+  auto constantMaskOp = dyn_cast_or_null<vector::ConstantMaskOp>(maskOp);
+  if (!createMaskOp && !constantMaskOp)
+    return failure();
+
+  // Computing the "compressed" mask. All the emulation logic (i.e. computing
+  // new mask index) only happens on the last dimension of the vectors.
+  Operation *newMask = nullptr;
+  SmallVector<int64_t> shape(
+      maskOp->getResultTypes()[0].cast<VectorType>().getShape());
+  shape.back() = numElements;
+  auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
+  if (createMaskOp) {
+    OperandRange maskOperands = createMaskOp.getOperands();
+    size_t numMaskOperands = maskOperands.size();
+    AffineExpr s0;
+    bindSymbols(rewriter.getContext(), s0);
+    s0 = s0 + scale - 1;
+    s0 = s0.floorDiv(scale);
+    OpFoldResult origIndex =
+        getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
+    OpFoldResult maskIndex =
+        affine::makeComposedFoldedAffineApply(rewriter, loc, s0, origIndex);
+    SmallVector<Value> newMaskOperands(maskOperands.drop_back());
+    newMaskOperands.push_back(
+        getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
+    newMask = rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
+                                                    newMaskOperands);
+  } else if (constantMaskOp) {
+    ArrayRef<Attribute> maskDimSizes =
+        constantMaskOp.getMaskDimSizes().getValue();
+    size_t numMaskOperands = maskDimSizes.size();
+    auto origIndex =
+        cast<IntegerAttr>(maskDimSizes[numMaskOperands - 1]).getInt();
+    IntegerAttr maskIndexAttr =
+        rewriter.getI64IntegerAttr((origIndex + scale - 1) / scale);
+    SmallVector<Attribute> newMaskDimSizes(maskDimSizes.drop_back());
+    newMaskDimSizes.push_back(maskIndexAttr);
+    newMask = rewriter.create<vector::ConstantMaskOp>(
+        loc, newMaskType, rewriter.getArrayAttr(newMaskDimSizes));
+  }
+
+  while (!extractOps.empty()) {
+    newMask = rewriter.create<vector::ExtractOp>(
+        loc, newMask->getResults()[0], extractOps.back().getMixedPosition());
+    extractOps.pop_back();
+  }
+
+  return newMask;
+}
+
 namespace {
 
 //===----------------------------------------------------------------------===//
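At the IR level, getCompressedMaskOp only rewrites the trailing mask dimension. A minimal sketch for scale = 4 (the i8-on-i32 configuration used in the new tests), with %n an index value and all names illustrative:

  // Dynamic mask: the trailing bound is ceil-divided by the scale.
  %mask     = vector.create_mask %n : vector<8xi1>
  %midx     = affine.apply affine_map<()[s0] -> ((s0 + 3) floordiv 4)>()[%n]
  %new_mask = vector.create_mask %midx : vector<2xi1>

  // Constant mask: the trailing dimension size is compressed directly,
  // e.g. (4 + 3) floordiv 4 = 1.
  %cst_mask     = vector.constant_mask [4] : vector<8xi1>
  %new_cst_mask = vector.constant_mask [1] : vector<2xi1>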
@@ -99,6 +172,94 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertVectorMaskedStore
+//===----------------------------------------------------------------------===//
+
+struct ConvertVectorMaskedStore final
+    : OpConversionPattern<vector::MaskedStoreOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    auto loc = op.getLoc();
+    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
+    Type oldElementType = op.getValueToStore().getType().getElementType();
+    Type newElementType = convertedType.getElementType();
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = newElementType.getIntOrFloatBitWidth();
+
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+
+    int scale = dstBits / srcBits;
+    int origElements = op.getValueToStore().getType().getNumElements();
+    if (origElements % scale != 0)
+      return failure();
+
+    auto stridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
+    OpFoldResult linearizedIndicesOfr;
+    std::tie(std::ignore, linearizedIndicesOfr) =
+        memref::getLinearizedMemRefOffsetAndSize(
+            rewriter, loc, srcBits, dstBits,
+            stridedMetadata.getConstifiedMixedOffset(),
+            stridedMetadata.getConstifiedMixedSizes(),
+            stridedMetadata.getConstifiedMixedStrides(),
+            getAsOpFoldResult(adaptor.getIndices()));
+    Value linearizedIndices =
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndicesOfr);
+
+    // Load the whole data and use arith.select to handle the corner cases.
+    // E.g., given these input values:
+    //
+    //   %mask = [1, 1, 1, 0, 0, 0]
+    //   %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6]
+    //   %value_to_store = [0x7, 0x8, 0x9, 0xA, 0xB, 0xC]
+    //
+    // we'll have
+    //
+    //   expected output: [0x7, 0x8, 0x9, 0x4, 0x5, 0x6]
+    //
+    //   %new_mask = [1, 1, 0]
+    //   %maskedload = [0x12, 0x34, 0x0]
+    //   %bitcast = [0x1, 0x2, 0x3, 0x4, 0x0, 0x0]
+    //   %select_using_original_mask = [0x7, 0x8, 0x9, 0x4, 0x0, 0x0]
+    //   %packed_data = [0x78, 0x94, 0x00]
+    //
+    // Using the new mask to store %packed_data results in expected output.
+    FailureOr<Operation *> newMask =
+        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
+    if (failed(newMask))
+      return failure();
+
+    auto numElements = (origElements + scale - 1) / scale;
+    auto newType = VectorType::get(numElements, newElementType);
+    auto passThru = rewriter.create<arith::ConstantOp>(
+        loc, newType, rewriter.getZeroAttr(newType));
+
+    auto newLoad = rewriter.create<vector::MaskedLoadOp>(
+        loc, newType, adaptor.getBase(), linearizedIndices,
+        newMask.value()->getResult(0), passThru);
+
+    Value valueToStore = rewriter.create<vector::BitCastOp>(
+        loc, op.getValueToStore().getType(), newLoad);
+    valueToStore = rewriter.create<arith::SelectOp>(
+        loc, op.getMask(), op.getValueToStore(), valueToStore);
+    valueToStore =
+        rewriter.create<vector::BitCastOp>(loc, newType, valueToStore);
+
+    rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
+        op, adaptor.getBase(), linearizedIndices, newMask.value()->getResult(0),
+        valueToStore);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConvertVectorLoad
 //===----------------------------------------------------------------------===//
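To make the arithmetic in the comment above concrete: with i4 elements emulated on an i8 memref, srcBits = 4 and dstBits = 8, so scale = 8 / 4 = 2; a vector<6xi4> store therefore touches numElements = (6 + 2 - 1) / 2 = 3 bytes, and the 6-element mask [1, 1, 1, 0, 0, 0] compresses to the 3-element mask [1, 1, 0] that is used for both the maskedload and the final maskedstore.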
@@ -236,15 +397,13 @@ struct ConvertVectorMaskedLoad final
     // TODO: Currently, only the even number of elements loading is supported.
     // To deal with the odd number of elements, one has to extract the
     // subvector at the proper offset after bit-casting.
-
     auto origType = op.getVectorType();
     auto origElements = origType.getNumElements();
     if (origElements % scale != 0)
       return failure();
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
-
    OpFoldResult linearizedIndices;
     std::tie(std::ignore, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
@@ -254,74 +413,21 @@ struct ConvertVectorMaskedLoad final
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
-    auto numElements = (origElements + scale - 1) / scale;
-    auto newType = VectorType::get(numElements, newElementType);
-
-    auto maskOp = op.getMask().getDefiningOp();
-    SmallVector<vector::ExtractOp, 2> extractOps;
-    // Finding the mask creation operation.
-    while (maskOp &&
-           !isa<vector::CreateMaskOp, vector::ConstantMaskOp>(maskOp)) {
-      if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
-        maskOp = extractOp.getVector().getDefiningOp();
-        extractOps.push_back(extractOp);
-      }
-    }
-    auto createMaskOp = dyn_cast_or_null<vector::CreateMaskOp>(maskOp);
-    auto constantMaskOp = dyn_cast_or_null<vector::ConstantMaskOp>(maskOp);
-    if (!createMaskOp && !constantMaskOp)
+    FailureOr<Operation *> newMask =
+        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
+    if (failed(newMask))
       return failure();
 
-    // Computing the "compressed" mask. All the emulation logic (i.e. computing
-    // new mask index) only happens on the last dimension of the vectors.
-    Operation *newMask = nullptr;
-    auto shape = llvm::to_vector(
-        maskOp->getResultTypes()[0].cast<VectorType>().getShape().drop_back());
-    shape.push_back(numElements);
-    auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
-    if (createMaskOp) {
-      auto maskOperands = createMaskOp.getOperands();
-      auto numMaskOperands = maskOperands.size();
-      AffineExpr s0;
-      bindSymbols(rewriter.getContext(), s0);
-      s0 = s0 + scale - 1;
-      s0 = s0.floorDiv(scale);
-      OpFoldResult origIndex =
-          getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
-      OpFoldResult maskIndex =
-          affine::makeComposedFoldedAffineApply(rewriter, loc, s0, origIndex);
-      auto newMaskOperands = llvm::to_vector(maskOperands.drop_back());
-      newMaskOperands.push_back(
-          getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
-      newMask = rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
-                                                      newMaskOperands);
-    } else if (constantMaskOp) {
-      auto maskDimSizes = constantMaskOp.getMaskDimSizes().getValue();
-      auto numMaskOperands = maskDimSizes.size();
-      auto origIndex =
-          cast<IntegerAttr>(maskDimSizes[numMaskOperands - 1]).getInt();
-      auto maskIndex =
-          rewriter.getI64IntegerAttr((origIndex + scale - 1) / scale);
-      auto newMaskDimSizes = llvm::to_vector(maskDimSizes.drop_back());
-      newMaskDimSizes.push_back(maskIndex);
-      newMask = rewriter.create<vector::ConstantMaskOp>(
-          loc, newMaskType, rewriter.getArrayAttr(newMaskDimSizes));
-    }
-
-    while (!extractOps.empty()) {
-      newMask = rewriter.create<vector::ExtractOp>(
-          loc, newMask->getResults()[0], extractOps.back().getMixedPosition());
-      extractOps.pop_back();
-    }
-
+    auto numElements = (origElements + scale - 1) / scale;
+    auto newType = VectorType::get(numElements, newElementType);
     auto newPassThru =
         rewriter.create<vector::BitCastOp>(loc, newType, op.getPassThru());
 
     // Generating the new masked load.
     auto newLoad = rewriter.create<vector::MaskedLoadOp>(
         loc, newType, adaptor.getBase(),
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
-        newMask->getResult(0), newPassThru);
+        newMask.value()->getResult(0), newPassThru);
 
     // Setting the part that originally was not effectively loaded from memory
     // to pass through.
@@ -821,7 +927,8 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
 
   // Populate `vector.*` conversion patterns.
   patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
-               ConvertVectorTransferRead>(typeConverter, patterns.getContext());
+               ConvertVectorMaskedStore, ConvertVectorTransferRead>(
+      typeConverter, patterns.getContext());
 }
 
 void vector::populateVectorNarrowTypeRewritePatterns(

mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir

Lines changed: 72 additions & 0 deletions
@@ -428,3 +428,75 @@ func.func @vector_store_i4_dynamic(%arg0: vector<8xi4>, %arg1: index, %arg2: ind
 // CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG2]], %[[ARG4]]]
 // CHECK32: %[[VEC_I8:.+]] = vector.bitcast %[[ARG0]] : vector<8xi4> to vector<1xi32>
 // CHECK32: vector.store %[[VEC_I8:.+]], %[[ALLOC:.+]][%[[INDEX:.+]]] : memref<?xi32>, vector<1xi32>
+
+// -----
+
+func.func @vector_maskedstore_i8(%arg0: index, %arg1: index, %arg2: index, %value: vector<8xi8>) {
+  %0 = memref.alloc() : memref<3x8xi8>
+  %mask = vector.create_mask %arg2 : vector<8xi1>
+  vector.maskedstore %0[%arg0, %arg1], %mask, %value : memref<3x8xi8>, vector<8xi1>, vector<8xi8>
+  return
+}
+// Expect no conversions, i8 is supported.
+// CHECK: func @vector_maskedstore_i8(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[VAL:[a-zA-Z0-9]+]]
+// CHECK-NEXT: %[[ALLOC:.+]] = memref.alloc() : memref<3x8xi8>
+// CHECK-NEXT: %[[MASK:.+]] = vector.create_mask %[[ARG2]] : vector<8xi1>
+// CHECK-NEXT: vector.maskedstore %[[ALLOC]][%[[ARG0]], %[[ARG1]]], %[[MASK]], %[[VAL]]
+// CHECK-NEXT: return
+
+// CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 2 + s1 floordiv 4)>
+// CHECK32-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> ((s0 + 3) floordiv 4)>
+// CHECK32: func @vector_maskedstore_i8(
+// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK32-SAME: %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK32-SAME: %[[VAL:[a-zA-Z0-9]+]]
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<6xi32>
+// CHECK32: %[[ORIG_MASK:.+]] = vector.create_mask %[[ARG2]] : vector<8xi1>
+// CHECK32: %[[LIDX:.+]] = affine.apply #[[LOAD_IDX_MAP]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32: %[[MASK_IDX:.+]] = affine.apply #[[MASK_IDX_MAP]]()[%[[ARG2]]]
+// CHECK32: %[[NEW_MASK:.+]] = vector.create_mask %[[MASK_IDX]] : vector<2xi1>
+// CHECK32: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<2xi32>
+// CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]]
+// CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<2xi32> to vector<8xi8>
+// CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL]], %[[BITCAST]] : vector<8xi1>, vector<8xi8>
+// CHECK32: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi8> to vector<2xi32>
+// CHECK32: vector.maskedstore %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]]
+
+// -----
+
+func.func @vector_cst_maskedstore_i8(%arg0: index, %arg1: index, %value: vector<8xi8>) {
+  %0 = memref.alloc() : memref<3x8xi8>
+  %mask = vector.constant_mask [4] : vector<8xi1>
+  vector.maskedstore %0[%arg0, %arg1], %mask, %value : memref<3x8xi8>, vector<8xi1>, vector<8xi8>
+  return
+}
+// Expect no conversions, i8 is supported.
+// CHECK: func @vector_cst_maskedstore_i8(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[VAL:[a-zA-Z0-9]+]]
+// CHECK-NEXT: %[[ALLOC:.+]] = memref.alloc() : memref<3x8xi8>
+// CHECK-NEXT: %[[MASK:.+]] = vector.constant_mask [4] : vector<8xi1>
+// CHECK-NEXT: vector.maskedstore %[[ALLOC]][%[[ARG0]], %[[ARG1]]], %[[MASK]], %[[VAL]]
+// CHECK-NEXT: return
+
+// CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 2 + s1 floordiv 4)>
+// CHECK32: func @vector_cst_maskedstore_i8(
+// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK32-SAME: %[[VAL:[a-zA-Z0-9]+]]
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<6xi32>
+// CHECK32: %[[ORIG_MASK:.+]] = vector.constant_mask [4] : vector<8xi1>
+// CHECK32: %[[LIDX:.+]] = affine.apply #[[LOAD_IDX_MAP]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32: %[[NEW_MASK:.+]] = vector.constant_mask [1] : vector<2xi1>
+// CHECK32: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<2xi32>
+// CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]]
+// CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<2xi32> to vector<8xi8>
+// CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL]], %[[BITCAST]] : vector<8xi1>, vector<8xi8>
+// CHECK32: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi8> to vector<2xi32>
+// CHECK32: vector.maskedstore %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]]
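In the constant-mask test above the compression happens entirely at pattern-application time: with i8 elements packed into i32 words the scale is 32 / 8 = 4, so vector.constant_mask [4] : vector<8xi1> becomes vector.constant_mask [(4 + 3) floordiv 4] = [1] : vector<2xi1>, and no affine.apply for the mask index is needed.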
