Skip to content

Commit d88293d

Browse files
authored
[mlir][vector] Disable BreakDownVectorBitCast for scalable vectors (#122725)
`BreakDownVectorBitCast` leverages * `vector.extract_strided_slices` + `vector.insert_strided_slices` As these Ops do not support extracting scalable sub-vectors (i.e. extracting/inserting a fraction of a scalable dim), it's best to bail out.
1 parent ba6774f commit d88293d

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -906,6 +906,13 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
906906
VectorType castDstType = bitcastOp.getResultVectorType();
907907
assert(castSrcType.getRank() == castDstType.getRank());
908908

909+
// This transformation builds on top of
910+
// vector.{extract|insert}_strided_slice, which do not support
911+
// extracting/inserting "scallable sub-vectors". Bail out.
912+
if (castSrcType.isScalable())
913+
return rewriter.notifyMatchFailure(bitcastOp,
914+
"Scalable vectors are not supported");
915+
909916
// Only support rank 1 case for now.
910917
if (castSrcType.getRank() != 1)
911918
return failure();

mlir/test/Dialect/Vector/vector-break-down-bitcast.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,14 @@ func.func @bitcast_i8_to_i32(%input: vector<16xi8>) -> vector<4xi32> {
3939
// CHECK: %[[CAST3:.+]] = vector.bitcast %[[EXTRACT3]] : vector<4xi8> to vector<1xi32>
4040
// CHECK: %[[INSERT3:.+]] = vector.insert_strided_slice %[[CAST3]], %[[INSERT2]] {offsets = [3], strides = [1]} : vector<1xi32> into vector<4xi32>
4141
// CHECK: return %[[INSERT3]]
42+
43+
// -----
44+
45+
// Scalable vectors are not supported!
46+
47+
// CHECK-LABEL: func.func @bitcast_scalable_negative
48+
// CHECK: vector.bitcast
49+
func.func @bitcast_scalable_negative(%input: vector<[8]xf16>) -> vector<[4]xf32> {
50+
%0 = vector.bitcast %input : vector<[8]xf16> to vector<[4]xf32>
51+
return %0: vector<[4]xf32>
52+
}

0 commit comments

Comments
 (0)