Skip to content

Commit 98c73d5

Browse files
authored
[mlir][vector] Restrict vector.shape_cast (scalable vectors) (llvm#100331)
Updates the verifier for `vector.shape_cast` so that incorrect cases where "scalability" is dropped are immediately rejected. For example: ```mlir vector.shape_cast %vec : vector<1x1x[4]xindex> to vector<4xindex> ``` Also, as a separate PR, I've prepared a fix for the Linalg vectorizer to avoid generating such shape casts (*): * llvm#100325 (*) Note, that's just one specific case that I've identified so far.
1 parent c7a3346 commit 98c73d5

File tree

3 files changed

+29
-0
lines changed

3 files changed

+29
-0
lines changed

mlir/include/mlir/IR/BuiltinTypes.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1168,6 +1168,11 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
11681168
return !llvm::is_contained(getScalableDims(), false);
11691169
}
11701170

1171+
/// Get the number of scalable dimensions.
1172+
int64_t getNumScalableDims() const {
1173+
return llvm::count(getScalableDims(), true);
1174+
}
1175+
11711176
/// Get or create a new VectorType with the same shape as `this` and an
11721177
/// element type of bitwidth scaled by `scale`.
11731178
/// Return null if the scaled element type cannot be represented.

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5238,6 +5238,16 @@ static LogicalResult verifyVectorShapeCast(Operation *op,
52385238
if (!isValidShapeCast(resultShape, sourceShape))
52395239
return op->emitOpError("invalid shape cast");
52405240
}
5241+
5242+
// Check that (non-)scalability is preserved
5243+
int64_t sourceNScalableDims = sourceVectorType.getNumScalableDims();
5244+
int64_t resultNScalableDims = resultVectorType.getNumScalableDims();
5245+
if (sourceNScalableDims != resultNScalableDims)
5246+
return op->emitOpError("different number of scalable dims at source (")
5247+
<< sourceNScalableDims << ") and result (" << resultNScalableDims
5248+
<< ")";
5249+
sourceVectorType.getNumDynamicDims();
5250+
52415251
return success();
52425252
}
52435253

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1182,6 +1182,20 @@ func.func @shape_cast_invalid_rank_expansion(%arg0 : vector<15x2xf32>) {
11821182

11831183
// -----
11841184

1185+
func.func @shape_cast_scalability_flag_is_dropped(%arg0 : vector<15x[2]xf32>) {
1186+
// expected-error@+1 {{different number of scalable dims at source (1) and result (0)}}
1187+
%0 = vector.shape_cast %arg0 : vector<15x[2]xf32> to vector<30xf32>
1188+
}
1189+
1190+
// -----
1191+
1192+
func.func @shape_cast_scalability_flag_is_dropped(%arg0 : vector<2x[15]x[2]xf32>) {
1193+
// expected-error@+1 {{different number of scalable dims at source (2) and result (1)}}
1194+
%0 = vector.shape_cast %arg0 : vector<2x[15]x[2]xf32> to vector<30x[2]xf32>
1195+
}
1196+
1197+
// -----
1198+
11851199
func.func @bitcast_not_vector(%arg0 : vector<5x1x3x2xf32>) {
11861200
// expected-error@+1 {{'vector.bitcast' invalid kind of type specified}}
11871201
%0 = vector.bitcast %arg0 : vector<5x1x3x2xf32> to f32

0 commit comments

Comments
 (0)