Skip to content

Commit 01fbc56

Browse files
authored
[mlir][vector] Add support for linearizing Insert VectorOp in VectorLinearize (#92370)
Building on top of [#88204](#88204), this PR adds support for converting `vector.insert` into an equivalent `vector.shuffle` operation that operates on linearized (1-D) vectors.
1 parent 1da52ca commit 01fbc56

File tree

2 files changed

+125
-1
lines changed

2 files changed

+125
-1
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,19 @@ static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
4444
return true;
4545
}
4646

47+
static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) {
48+
VectorType vecType = dyn_cast<VectorType>(t);
49+
// Reject index since getElementTypeBitWidth will abort for Index types.
50+
if (!vecType || vecType.getElementType().isIndex())
51+
return false;
52+
// There are no dimension to fold if it is a 0-D vector.
53+
if (vecType.getRank() == 0)
54+
return false;
55+
unsigned trailingVecDimBitWidth =
56+
vecType.getShape().back() * vecType.getElementTypeBitWidth();
57+
return trailingVecDimBitWidth <= targetBitWidth;
58+
}
59+
4760
namespace {
4861
struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
4962
using OpConversionPattern::OpConversionPattern;
@@ -355,6 +368,88 @@ struct LinearizeVectorExtract final
355368
return success();
356369
}
357370

371+
private:
372+
unsigned targetVectorBitWidth;
373+
};
374+
375+
/// This pattern converts the InsertOp to a ShuffleOp that works on a
376+
/// linearized vector.
377+
/// Following,
378+
/// vector.insert %source %destination [ position ]
379+
/// is converted to :
380+
/// %source_1d = vector.shape_cast %source
381+
/// %destination_1d = vector.shape_cast %destination
382+
/// %out_1d = vector.shuffle %destination_1d, %source_1d [ shuffle_indices_1d
383+
/// ] %out_nd = vector.shape_cast %out_1d
384+
/// `shuffle_indices_1d` is computed using the position of the original insert.
385+
struct LinearizeVectorInsert final
386+
: public OpConversionPattern<vector::InsertOp> {
387+
using OpConversionPattern::OpConversionPattern;
388+
LinearizeVectorInsert(
389+
const TypeConverter &typeConverter, MLIRContext *context,
390+
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
391+
PatternBenefit benefit = 1)
392+
: OpConversionPattern(typeConverter, context, benefit),
393+
targetVectorBitWidth(targetVectBitWidth) {}
394+
LogicalResult
395+
matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
396+
ConversionPatternRewriter &rewriter) const override {
397+
Type dstTy = getTypeConverter()->convertType(insertOp.getDestVectorType());
398+
assert(!(insertOp.getDestVectorType().isScalable() ||
399+
cast<VectorType>(dstTy).isScalable()) &&
400+
"scalable vectors are not supported.");
401+
402+
if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(),
403+
targetVectorBitWidth))
404+
return rewriter.notifyMatchFailure(
405+
insertOp, "Can't flatten since targetBitWidth < OpSize");
406+
407+
// dynamic position is not supported
408+
if (insertOp.hasDynamicPosition())
409+
return rewriter.notifyMatchFailure(insertOp,
410+
"dynamic position is not supported.");
411+
auto srcTy = insertOp.getSourceType();
412+
auto srcAsVec = dyn_cast<VectorType>(srcTy);
413+
uint64_t srcSize = 0;
414+
if (srcAsVec) {
415+
srcSize = srcAsVec.getNumElements();
416+
} else {
417+
return rewriter.notifyMatchFailure(insertOp,
418+
"scalars are not supported.");
419+
}
420+
421+
auto dstShape = insertOp.getDestVectorType().getShape();
422+
const auto dstSize = insertOp.getDestVectorType().getNumElements();
423+
auto dstSizeForOffsets = dstSize;
424+
425+
// compute linearized offset
426+
int64_t linearizedOffset = 0;
427+
auto offsetsNd = insertOp.getStaticPosition();
428+
for (auto [dim, offset] : llvm::enumerate(offsetsNd)) {
429+
dstSizeForOffsets /= dstShape[dim];
430+
linearizedOffset += offset * dstSizeForOffsets;
431+
}
432+
433+
llvm::SmallVector<int64_t, 2> indices(dstSize);
434+
auto origValsUntil = indices.begin();
435+
std::advance(origValsUntil, linearizedOffset);
436+
std::iota(indices.begin(), origValsUntil,
437+
0); // original values that remain [0, offset)
438+
auto newValsUntil = origValsUntil;
439+
std::advance(newValsUntil, srcSize);
440+
std::iota(origValsUntil, newValsUntil,
441+
dstSize); // new values [offset, offset+srcNumElements)
442+
std::iota(newValsUntil, indices.end(),
443+
linearizedOffset + srcSize); // the rest of original values
444+
// [offset+srcNumElements, end)
445+
446+
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
447+
insertOp, dstTy, adaptor.getDest(), adaptor.getSource(),
448+
rewriter.getI64ArrayAttr(indices));
449+
450+
return success();
451+
}
452+
358453
private:
359454
unsigned targetVectorBitWidth;
360455
};
@@ -410,6 +505,6 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
410505
: true;
411506
});
412507
patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
413-
LinearizeVectorExtractStridedSlice>(
508+
LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
414509
typeConverter, patterns.getContext(), targetBitWidth);
415510
}

mlir/test/Dialect/Vector/linearize.mlir

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,3 +245,32 @@ func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> {
245245
%0 = vector.extract %arg0[1]: vector<8x2xf32> from vector<2x8x2xf32>
246246
return %0 : vector<8x2xf32>
247247
}
248+
249+
// -----
250+
// ALL-LABEL: test_vector_insert
251+
// ALL-SAME: (%[[DEST:.*]]: vector<2x8x4xf32>, %[[SRC:.*]]: vector<8x4xf32>) -> vector<2x8x4xf32> {
252+
func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>) -> vector<2x8x4xf32> {
253+
// DEFAULT: %[[ARG_SRC:.*]] = vector.shape_cast %[[SRC]] : vector<8x4xf32> to vector<32xf32>
254+
// DEFAULT: %[[ARG_DEST:.*]] = vector.shape_cast %[[DEST]] : vector<2x8x4xf32> to vector<64xf32>
255+
// DEFAULT: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG_DEST]], %[[ARG_SRC]]
256+
// DEFAULT-SAME: [64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87,
257+
// DEFAULT-SAME: 88, 89, 90, 91, 92, 93, 94, 95, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
258+
// DEFAULT-SAME: 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<64xf32>, vector<32xf32>
259+
// DEFAULT: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<64xf32> to vector<2x8x4xf32>
260+
// DEFAULT: return %[[RES]] : vector<2x8x4xf32>
261+
262+
// BW-128: %[[ARG_SRC:.*]] = vector.shape_cast %[[SRC]] : vector<8x4xf32> to vector<32xf32>
263+
// BW-128: %[[ARG_DEST:.*]] = vector.shape_cast %[[DEST]] : vector<2x8x4xf32> to vector<64xf32>
264+
// BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG_DEST]], %[[ARG_SRC]]
265+
// BW-128-SAME: [64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87,
266+
// BW-128-SAME: 88, 89, 90, 91, 92, 93, 94, 95, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
267+
// BW-128-SAME: 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<64xf32>, vector<32xf32>
268+
// BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<64xf32> to vector<2x8x4xf32>
269+
// BW-128: return %[[RES]] : vector<2x8x4xf32>
270+
271+
// BW-0: %[[RES:.*]] = vector.insert %[[SRC]], %[[DEST]] [0] : vector<8x4xf32> into vector<2x8x4xf32>
272+
// BW-0: return %[[RES]] : vector<2x8x4xf32>
273+
274+
%0 = vector.insert %arg1, %arg0[0]: vector<8x4xf32> into vector<2x8x4xf32>
275+
return %0 : vector<2x8x4xf32>
276+
}

0 commit comments

Comments
 (0)