Skip to content

Commit dc5d541

Browse files
authored
[mlir][vector] Support scalable vectors when unrolling vector.bitcast (#94197)
Follow up to #94064.
1 parent 747f9da commit dc5d541

File tree

3 files changed

+44
-12
lines changed

3 files changed

+44
-12
lines changed

mlir/include/mlir/Dialect/Utils/IndexingUtils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,8 @@ class TileOffsetRangeImpl {
287287
return getDynamicTileOffsets(linearIndex);
288288
}
289289

290+
size_t getRank() const { return tileShape.size(); }
291+
290292
private:
291293
/// The sub-shape that divides the larger outer shape (which is provided to
292294
/// the constructor).
@@ -388,6 +390,9 @@ class StaticTileOffsetRange {
388390
/// Returns the total number of tiles that fit in the larger shape.
389391
size_t size() const { return params.getMaxLinearIndex(); }
390392

393+
/// Returns rank of the iterator's shape.
394+
size_t getRank() const { return params.getRank(); }
395+
391396
private:
392397
const ParamsTy params;
393398
IteratorTy beginValue;

mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -56,17 +56,12 @@ class UnrollBitCastOp final : public OpRewritePattern<vector::BitCastOp> {
5656
if (!unrollIterator)
5757
return failure();
5858

59-
// TODO: Support the scalable vector cases. It is not supported because
60-
// the final rank could be values other than `targetRank`. It makes creating
61-
// the result type of new vector.bitcast ops much harder.
62-
if (resultType.isScalable()) {
63-
return rewriter.notifyMatchFailure(op,
64-
"unrolling vector.bitcast on scalable "
65-
"vectors is not yet implemented");
66-
}
67-
68-
ArrayRef<int64_t> shape = resultType.getShape().take_back(targetRank);
69-
auto bitcastResType = VectorType::get(shape, resultType.getElementType());
59+
auto unrollRank = unrollIterator->getRank();
60+
ArrayRef<int64_t> shape = resultType.getShape().drop_front(unrollRank);
61+
ArrayRef<bool> scalableDims =
62+
resultType.getScalableDims().drop_front(unrollRank);
63+
auto bitcastResType =
64+
VectorType::get(shape, resultType.getElementType(), scalableDims);
7065

7166
Location loc = op.getLoc();
7267
Value result = rewriter.create<arith::ConstantOp>(

mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,39 @@ func.func @vector_bitcast_4d_with_scalable_dim(%arg0: vector<1x2x[3]x4xi64>) ->
3838
return %0 : vector<1x2x[3]x8xi32>
3939
}
4040
// CHECK-LABEL: func.func @vector_bitcast_4d_with_scalable_dim
41-
// CHECK: vector.bitcast {{.+}} : vector<1x2x[3]x4xi64> to vector<1x2x[3]x8xi32>
41+
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]
42+
// CHECK: %[[INIT:.+]] = arith.constant dense<0> : vector<1x2x[3]x8xi32>
43+
// CHECK: %[[V1:.+]] = vector.extract %[[IN]][0, 0] : vector<[3]x4xi64> from vector<1x2x[3]x4xi64>
44+
// CHECK: %[[B1:.+]] = vector.bitcast %[[V1]] : vector<[3]x4xi64> to vector<[3]x8xi32>
45+
// CHECK: %[[R1:.+]] = vector.insert %[[B1]], %[[INIT]] [0, 0] : vector<[3]x8xi32> into vector<1x2x[3]x8xi32>
46+
// CHECK: %[[V2:.+]] = vector.extract %[[IN]][0, 1] : vector<[3]x4xi64> from vector<1x2x[3]x4xi64>
47+
// CHECK: %[[B2:.+]] = vector.bitcast %[[V2]] : vector<[3]x4xi64> to vector<[3]x8xi32>
48+
// CHECK: %[[R2:.+]] = vector.insert %[[B2]], %[[R1]] [0, 1] : vector<[3]x8xi32> into vector<1x2x[3]x8xi32>
49+
// CHECK: return %[[R2]] : vector<1x2x[3]x8xi32>
50+
51+
func.func @vector_bitcast_2d_trailing_scalable_dim(%arg0: vector<2x[2]xi64>) -> vector<2x[4]xi32> {
52+
%0 = vector.bitcast %arg0 : vector<2x[2]xi64> to vector<2x[4]xi32>
53+
return %0 : vector<2x[4]xi32>
54+
}
55+
// CHECK-LABEL: func.func @vector_bitcast_2d_trailing_scalable_dim
56+
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]
57+
// CHECK: %[[INIT:.+]] = arith.constant dense<0> : vector<2x[4]xi32>
58+
// CHECK: %[[V1:.+]] = vector.extract %[[IN]][0] : vector<[2]xi64> from vector<2x[2]xi64>
59+
// CHECK: %[[B1:.+]] = vector.bitcast %[[V1]] : vector<[2]xi64> to vector<[4]xi32>
60+
// CHECK: %[[R1:.+]] = vector.insert %[[B1]], %[[INIT]] [0] : vector<[4]xi32> into vector<2x[4]xi32>
61+
// CHECK: %[[V2:.+]] = vector.extract %[[IN]][1] : vector<[2]xi64> from vector<2x[2]xi64>
62+
// CHECK: %[[B2:.+]] = vector.bitcast %[[V2]] : vector<[2]xi64> to vector<[4]xi32>
63+
// CHECK: %[[R2:.+]] = vector.insert %[[B2]], %[[R1]] [1] : vector<[4]xi32> into vector<2x[4]xi32>
64+
// CHECK: return %[[R2]] : vector<2x[4]xi32>
65+
66+
func.func @negative_vector_bitcast_2d_leading_scalable_dim(%arg0: vector<[2]x2xi64>) -> vector<[2]x4xi32>
67+
{
68+
%0 = vector.bitcast %arg0 : vector<[2]x2xi64> to vector<[2]x4xi32>
69+
return %0 : vector<[2]x4xi32>
70+
}
71+
// CHECK-LABEL: func.func @negative_vector_bitcast_2d_leading_scalable_dim
72+
// CHECK-NOT: vector.extract
73+
// CHECK-NOT: vector.insert
4274

4375
module attributes {transform.with_named_sequence} {
4476
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {

0 commit comments

Comments
 (0)