Commit bd5d361

[mlir][vector] add support for linearizing vector.bitcast in VectorLinearize (#123110)
This PR adds support for converting vector::BitCastOp operating on nD (n > 1) vectors into the same op operating on linearized (1D) vectors.
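In effect, an n-D vector.bitcast is rewritten as a shape_cast / bitcast / shape_cast sequence over 1-D vectors. A minimal before/after sketch in MLIR, based on the new pattern's doc comment (SSA value names are illustrative):

  // Sketch based on the LinearizeVectorBitCast doc comment; names are illustrative.
  // Before linearization:
  %out_nd = vector.bitcast %v1 : vector<4x2xf32> to vector<4x4xf16>

  // After linearization (4x2 f32 = 8 f32 = 16 f16 = 4x4 f16):
  %v1_1d  = vector.shape_cast %v1 : vector<4x2xf32> to vector<8xf32>
  %out_1d = vector.bitcast %v1_1d : vector<8xf32> to vector<16xf16>
  %out_nd = vector.shape_cast %out_1d : vector<16xf16> to vector<4x4xf16>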
1 parent 285009f commit bd5d361

File tree

2 files changed: +112 -6 lines changed

  mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
  mlir/test/Dialect/Vector/linearize.mlir
mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Lines changed: 46 additions & 5 deletions
@@ -72,13 +72,14 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
     auto resType =
         getTypeConverter()->convertType<VectorType>(constOp.getType());
 
+    if (!resType)
+      return rewriter.notifyMatchFailure(loc, "can't convert return type");
+
     if (resType.isScalable() && !isa<SplatElementsAttr>(constOp.getValue()))
       return rewriter.notifyMatchFailure(
           loc,
           "Cannot linearize a constant scalable vector that's not a splat");
 
-    if (!resType)
-      return rewriter.notifyMatchFailure(loc, "can't convert return type");
     if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
           loc, "Can't flatten since targetBitWidth <= OpSize");
@@ -459,6 +460,45 @@ struct LinearizeVectorInsert final
 private:
   unsigned targetVectorBitWidth;
 };
+
+/// This pattern converts the BitCastOp that works on nD (n > 1)
+/// vectors to a BitCastOp that works on linearized vectors.
+/// Following,
+///   vector.bitcast %v1: vector<4x2xf32> to vector<4x4xf16>
+/// is converted to :
+///   %v1_1d = vector.shape_cast %v1: vector<4x2xf32> to vector<8xf32>
+///   %out_1d = vector.bitcast %v1_1d: vector<8xf32> to vector<16xf16>
+///   %out_nd = vector.shape_cast %out_1d: vector<16xf16> to vector<4x4xf16>
+struct LinearizeVectorBitCast final
+    : public OpConversionPattern<vector::BitCastOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LinearizeVectorBitCast(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+  LogicalResult
+  matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = castOp.getLoc();
+    auto resType = getTypeConverter()->convertType(castOp.getType());
+    if (!resType)
+      return rewriter.notifyMatchFailure(loc, "can't convert return type.");
+
+    if (!isLessThanTargetBitWidth(castOp, targetVectorBitWidth))
+      return rewriter.notifyMatchFailure(
+          loc, "Can't flatten since targetBitWidth <= OpSize");
+
+    rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType,
+                                                   adaptor.getSource());
+    return mlir::success();
+  }
+
+private:
+  unsigned targetVectorBitWidth;
+};
+
 } // namespace
 
 void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
@@ -485,7 +525,7 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
   typeConverter.addTargetMaterialization(materializeCast);
   target.markUnknownOpDynamicallyLegal(
       [=](Operation *op) -> std::optional<bool> {
-        if ((isa<arith::ConstantOp>(op) ||
+        if ((isa<arith::ConstantOp>(op) || isa<vector::BitCastOp>(op) ||
             op->hasTrait<OpTrait::Vectorizable>())) {
          return (isLessThanTargetBitWidth(op, targetBitWidth)
                      ? typeConverter.isLegal(op)
@@ -494,8 +534,9 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
        return std::nullopt;
      });
 
-  patterns.add<LinearizeConstant, LinearizeVectorizable>(
-      typeConverter, patterns.getContext(), targetBitWidth);
+  patterns
+      .add<LinearizeConstant, LinearizeVectorizable, LinearizeVectorBitCast>(
+          typeConverter, patterns.getContext(), targetBitWidth);
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(

mlir/test/Dialect/Vector/linearize.mlir

Lines changed: 66 additions & 1 deletion
@@ -179,7 +179,7 @@ func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf32> {
 
 // ALL-LABEL: func.func @test_extract_strided_slice_1_scalable(
 // ALL-SAME: %[[VAL_0:.*]]: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
-func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> { 
+func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
   // ALL-NOT: vector.shuffle
   // ALL-NOT: vector.shape_cast
   // ALL: %[[RES:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [1, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x[8]xf32> to vector<2x[8]xf32>
@@ -318,3 +318,68 @@ func.func @test_vector_extract_scalar() {
   %0 = vector.extract %cst[0] : i32 from vector<4xi32>
   return
 }
+
+// -----
+
+// ALL-LABEL: test_vector_bitcast
+// ALL-SAME: %[[ARG_0:.*]]: vector<4x4xf32>
+func.func @test_vector_bitcast(%arg0: vector<4x4xf32>) -> vector<4x8xf16> {
+  // DEFAULT: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<4x4xf32> to vector<16xf32>
+  // DEFAULT: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<16xf32> to vector<32xf16>
+  // DEFAULT: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<32xf16> to vector<4x8xf16>
+
+  // BW-128: %[[UPCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<4x4xf32> to vector<4x8xf16>
+  // BW-0: %[[BITCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<4x4xf32> to vector<4x8xf16>
+  %1 = vector.bitcast %arg0 : vector<4x4xf32> to vector<4x8xf16>
+  return %1 : vector<4x8xf16>
+}
+
+// -----
+
+// ALL-LABEL: test_vector_bitcast
+// ALL-SAME: %[[ARG_0:.*]]: vector<4x2xf32>
+func.func @test_vector_bitcast(%arg0: vector<4x2xf32>) -> vector<4x4xf16> {
+  // DEFAULT: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<4x2xf32> to vector<8xf32>
+  // DEFAULT: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<8xf32> to vector<16xf16>
+  // DEFAULT: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<16xf16> to vector<4x4xf16>
+  // BW-128: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<4x2xf32> to vector<8xf32>
+  // BW-128: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<8xf32> to vector<16xf16>
+  // BW-128: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<16xf16> to vector<4x4xf16>
+
+  // BW-0: %[[BITCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<4x2xf32> to vector<4x4xf16>
+  %1 = vector.bitcast %arg0 : vector<4x2xf32> to vector<4x4xf16>
+  return %1 : vector<4x4xf16>
+}
+
+// -----
+
+// ALL-LABEL: test_vector_bitcast
+// ALL-SAME: %[[ARG_0:.*]]: vector<4x[2]xf32>
+func.func @test_vector_bitcast(%arg0: vector<4x[2]xf32>) -> vector<4x[4]xf16> {
+  // DEFAULT: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<4x[2]xf32> to vector<[8]xf32>
+  // DEFAULT: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<[8]xf32> to vector<[16]xf16>
+  // DEFAULT: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<[16]xf16> to vector<4x[4]xf16>
+  // BW-128: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<4x[2]xf32> to vector<[8]xf32>
+  // BW-128: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<[8]xf32> to vector<[16]xf16>
+  // BW-128: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<[16]xf16> to vector<4x[4]xf16>
+
+  // BW-0: %[[BITCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<4x[2]xf32> to vector<4x[4]xf16>
+  %1 = vector.bitcast %arg0 : vector<4x[2]xf32> to vector<4x[4]xf16>
+  return %1 : vector<4x[4]xf16>
+}
+
+// -----
+// ALL-LABEL: test_vector_bitcast
+// ALL-SAME: %[[ARG_0:.*]]: vector<[4]x2xf32>
+func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {
+  // DEFAULT: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<[4]x2xf32> to vector<[8]xf32>
+  // DEFAULT: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<[8]xf32> to vector<[16]xf16>
+  // DEFAULT: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<[16]xf16> to vector<[4]x4xf16>
+  // BW-128: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<[4]x2xf32> to vector<[8]xf32>
+  // BW-128: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<[8]xf32> to vector<[16]xf16>
+  // BW-128: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<[16]xf16> to vector<[4]x4xf16>
+
+  // BW-0: %[[BITCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<[4]x2xf32> to vector<[4]x4xf16>
+  %1 = vector.bitcast %arg0 : vector<[4]x2xf32> to vector<[4]x4xf16>
+  return %1 : vector<[4]x4xf16>
+}
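The ALL, DEFAULT, BW-128, and BW-0 check prefixes correspond to running the linearization with different target-vector-bitwidth settings (unbounded, 128, and 0). The RUN lines live at the top of linearize.mlir and are not part of this diff; a hedged sketch of what such invocations typically look like (pass name and option spelling assumed from the existing test harness):

  // Sketch only; the exact RUN lines are not shown in this diff.
  // RUN: mlir-opt %s -split-input-file -test-vector-linearize -verify-diagnostics | FileCheck %s --check-prefixes=ALL,DEFAULT
  // RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=128 -verify-diagnostics | FileCheck %s --check-prefixes=ALL,BW-128
  // RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=0 -verify-diagnostics | FileCheck %s --check-prefixes=ALL,BW-0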
