
Commit 674261b

[mlir][Vector] Add narrow type emulation pattern for vector.maskedload (#68443)
1 parent e64e478 commit 674261b
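
As context for the rewrite below (not part of the commit), the core arithmetic the new pattern relies on is a ceiling division from sub-byte elements to their byte-sized containers: a mask over N i4 lanes becomes a mask over ceil(N/scale) i8 lanes. A minimal standalone C++ sketch, with constants mirroring the i4-over-i8 example worked out in the new code comments:

// Standalone illustration only; the constants mirror the worked example in
// the pattern's comments (i4 elements emulated on i8 storage).
#include <cstdio>

int main() {
  const int srcBits = 4;               // narrow element type: i4
  const int dstBits = 8;               // container element type: i8
  const int scale = dstBits / srcBits; // 2 i4 elements per i8
  const int origMaskSize = 3;          // vector.constant_mask [3] : vector<6xi1>
  // Ceiling division: enough i8 lanes to cover all masked i4 lanes.
  const int newMaskSize = (origMaskSize + scale - 1) / scale;
  std::printf("compressed mask size: %d\n", newMaskSize); // prints 2
  return 0;
}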

File tree

3 files changed (+435, -4 lines)


mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 169 additions & 2 deletions
@@ -7,6 +7,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
@@ -103,6 +104,172 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertVectorMaskedLoad
+//===----------------------------------------------------------------------===//
+
+struct ConvertVectorMaskedLoad final
+    : OpConversionPattern<vector::MaskedLoadOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    auto loc = op.getLoc();
+    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
+    Type oldElementType = op.getType().getElementType();
+    Type newElementType = convertedType.getElementType();
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = newElementType.getIntOrFloatBitWidth();
+
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+    int scale = dstBits / srcBits;
+
+    // Adjust the number of elements to load when emulating narrow types,
+    // and then cast back to the original type with vector.bitcast op.
+    // For example, to emulate i4 to i8, the following op:
+    //
+    //   %mask = vector.constant_mask [3] : vector<6xi1>
+    //   %1 = vector.maskedload %0[%c0, %c0], %mask, %pass_thru :
+    //     memref<3x6xi4>, vector<6xi1>, vector<6xi4> into vector<6xi4>
+    //
+    // can be replaced with
+    //
+    //   %new_mask = vector.constant_mask [2] : vector<3xi1>
+    //   %new_pass_thru = vector.bitcast %pass_thru :
+    //     vector<6xi4> to vector<3xi8>
+    //   %1 = vector.maskedload %0[%linear_index], %new_mask, %new_pass_thru :
+    //     memref<9xi8>, vector<3xi1>, vector<3xi8> into vector<3xi8>
+    //   %2 = vector.bitcast %1 : vector<3xi8> to vector<6xi4>
+    //
+    // Since we are effectively loading 16 bits (2xi8) from the memref with the
+    // new mask, while originally we only wanted to effectively load 12 bits
+    // (3xi4) from the memref, we need to set the second half of the last i8
+    // that was effectively loaded (i.e. the second i8) to %pass_thru.
+    //
+    //   %3 = arith.select %mask, %2, %pass_thru : vector<6xi1>, vector<6xi4>
+    //
+    // Given these input values:
+    //   %mask = [1, 1, 1, 0, 0, 0]
+    //   %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6]
+    //   %pass_thru = [0x7, 0x8, 0x9, 0xA, 0xB, 0xC]
+    //
+    // we'll have:
+    //
+    //   expected output: [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
+    //
+    //   %new_mask = [1, 1, 0]
+    //   %new_pass_thru = [0x78, 0x9A, 0xBC]
+    //   %1 = [0x12, 0x34, 0xBC]
+    //   %2 = [0x1, 0x2, 0x3, 0x4, 0xB, 0xC]
+    //   %3 = [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
+    //
+    // TODO: Currently, only the even number of elements loading is supported.
+    // To deal with the odd number of elements, one has to extract the
+    // subvector at the proper offset after bit-casting.
+
+    auto origType = op.getVectorType();
+    auto origElements = origType.getNumElements();
+    if (origElements % scale != 0)
+      return failure();
+
+    auto stridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
+
+    OpFoldResult linearizedIndices;
+    std::tie(std::ignore, linearizedIndices) =
+        memref::getLinearizedMemRefOffsetAndSize(
+            rewriter, loc, srcBits, dstBits,
+            stridedMetadata.getConstifiedMixedOffset(),
+            stridedMetadata.getConstifiedMixedSizes(),
+            stridedMetadata.getConstifiedMixedStrides(),
+            getAsOpFoldResult(adaptor.getIndices()));
+
+    auto numElements = (origElements + scale - 1) / scale;
+    auto newType = VectorType::get(numElements, newElementType);
+
+    auto maskOp = op.getMask().getDefiningOp();
+    SmallVector<vector::ExtractOp, 2> extractOps;
+    // Finding the mask creation operation.
+    while (maskOp &&
+           !isa<vector::CreateMaskOp, vector::ConstantMaskOp>(maskOp)) {
+      if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
+        maskOp = extractOp.getVector().getDefiningOp();
+        extractOps.push_back(extractOp);
+      }
+    }
+    auto createMaskOp = dyn_cast_or_null<vector::CreateMaskOp>(maskOp);
+    auto constantMaskOp = dyn_cast_or_null<vector::ConstantMaskOp>(maskOp);
+    if (!createMaskOp && !constantMaskOp)
+      return failure();
+
+    // Computing the "compressed" mask. All the emulation logic (i.e. computing
+    // new mask index) only happens on the last dimension of the vectors.
+    Operation *newMask = nullptr;
+    auto shape = llvm::to_vector(
+        maskOp->getResultTypes()[0].cast<VectorType>().getShape().drop_back());
+    shape.push_back(numElements);
+    auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
+    if (createMaskOp) {
+      auto maskOperands = createMaskOp.getOperands();
+      auto numMaskOperands = maskOperands.size();
+      AffineExpr s0;
+      bindSymbols(rewriter.getContext(), s0);
+      s0 = s0 + scale - 1;
+      s0 = s0.floorDiv(scale);
+      OpFoldResult origIndex =
+          getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
+      OpFoldResult maskIndex =
+          affine::makeComposedFoldedAffineApply(rewriter, loc, s0, origIndex);
+      auto newMaskOperands = llvm::to_vector(maskOperands.drop_back());
+      newMaskOperands.push_back(
+          getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
+      newMask = rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
+                                                      newMaskOperands);
+    } else if (constantMaskOp) {
+      auto maskDimSizes = constantMaskOp.getMaskDimSizes().getValue();
+      auto numMaskOperands = maskDimSizes.size();
+      auto origIndex =
+          cast<IntegerAttr>(maskDimSizes[numMaskOperands - 1]).getInt();
+      auto maskIndex =
+          rewriter.getI64IntegerAttr((origIndex + scale - 1) / scale);
+      auto newMaskDimSizes = llvm::to_vector(maskDimSizes.drop_back());
+      newMaskDimSizes.push_back(maskIndex);
+      newMask = rewriter.create<vector::ConstantMaskOp>(
+          loc, newMaskType, rewriter.getArrayAttr(newMaskDimSizes));
+    }
+
+    while (!extractOps.empty()) {
+      newMask = rewriter.create<vector::ExtractOp>(
+          loc, newMask->getResults()[0], extractOps.back().getMixedPosition());
+      extractOps.pop_back();
+    }
+
+    auto newPassThru =
+        rewriter.create<vector::BitCastOp>(loc, newType, op.getPassThru());
+
+    // Generating the new masked load.
+    auto newLoad = rewriter.create<vector::MaskedLoadOp>(
+        loc, newType, adaptor.getBase(),
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
+        newMask->getResult(0), newPassThru);
+
+    // Setting the part that originally was not effectively loaded from memory
+    // to pass through.
+    auto bitCast =
+        rewriter.create<vector::BitCastOp>(loc, op.getType(), newLoad);
+    auto select = rewriter.create<arith::SelectOp>(loc, op.getMask(), bitCast,
+                                                   op.getPassThru());
+    rewriter.replaceOp(op, select->getResult(0));
+
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConvertVectorTransferRead
 //===----------------------------------------------------------------------===//
@@ -588,8 +755,8 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
     RewritePatternSet &patterns) {
 
   // Populate `vector.*` conversion patterns.
-  patterns.add<ConvertVectorLoad, ConvertVectorTransferRead>(
-      typeConverter, patterns.getContext());
+  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
+               ConvertVectorTransferRead>(typeConverter, patterns.getContext());
 }
 
 void vector::populateVectorNarrowTypeRewritePatterns(
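
For readers wondering how these conversion patterns are driven, below is a minimal sketch of a pass body that applies them. Only populateVectorNarrowTypeEmulationPatterns and arith::NarrowTypeEmulationConverter appear in this diff; the header paths, the dynamic-legality setup, and the overall pass wiring are assumptions based on MLIR's usual dialect-conversion boilerplate, not part of this commit.

// Sketch only: wire the vector narrow-type emulation patterns into a
// partial dialect conversion that rewrites sub-byte vectors onto i8 storage.
#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

static LogicalResult emulateNarrowVectorTypes(Operation *rootOp) {
  MLIRContext *ctx = rootOp->getContext();

  // Emulate sub-byte element types (e.g. i4) on top of 8-bit storage.
  arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);

  ConversionTarget target(*ctx);
  // An op is legal once its operand/result types no longer need conversion.
  target.addDynamicallyLegalOp<vector::LoadOp, vector::MaskedLoadOp,
                               vector::TransferReadOp>(
      [&](Operation *op) { return typeConverter.isLegal(op); });

  RewritePatternSet patterns(ctx);
  vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);

  return applyPartialConversion(rootOp, target, std::move(patterns));
}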
