Skip to content

Commit 665995b

Browse files
committed
[mlir][Conversion] Allow lowering to fixed arrays of scalable vectors
This allows lowering vector types like: vector<3x[4]> or vector<3x2x[4]> to LLVM IR, i.e. vectors where the trailing dim is scalable. This is contingent on: https://discourse.llvm.org/t/rfc-enable-arrays-of-scalable-vector-types/72935 More tests will be added in later patches, however, some MLIR fixes are needed first. Depends on: D158517 Reviewed By: awarzynski Differential Revision: https://reviews.llvm.org/D158752
1 parent 4eafb5f commit 665995b

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -507,7 +507,8 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
507507
type.getScalableDims().back());
508508
assert(LLVM::isCompatibleVectorType(vectorType) &&
509509
"expected vector type compatible with the LLVM dialect");
510-
if (type.isScalable() && (type.getRank() > 1))
510+
// Only the trailing dimension can be scalable.
511+
if (llvm::is_contained(type.getScalableDims().drop_back(), true))
511512
return failure();
512513
auto shape = type.getShape();
513514
for (int i = shape.size() - 2; i >= 0; --i)

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2260,3 +2260,13 @@ func.func @vector_scalable_extract(%vec: vector<[4]xf32>) -> vector<8xf32> {
22602260
%0 = vector.scalable.extract %vec[0] : vector<8xf32> from vector<[4]xf32>
22612261
return %0 : vector<8xf32>
22622262
}
2263+
2264+
// -----
2265+
2266+
// CHECK-LABEL: @make_fixed_vector_of_scalable_vector
2267+
func.func @make_fixed_vector_of_scalable_vector(%f : f64) -> vector<3x[2]xf64>
2268+
{
2269+
// CHECK: %{{.*}} = llvm.mlir.undef : !llvm.array<3 x vector<[2]xf64>>
2270+
%res = vector.broadcast %f : f64 to vector<3x[2]xf64>
2271+
return %res : vector<3x[2]xf64>
2272+
}

0 commit comments

Comments
 (0)