Commit c776290

add dynamic test
1 parent 70d3705

2 files changed (+46, −4)

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (6 additions & 4 deletions)
@@ -1442,16 +1442,16 @@ getTiledShapeToPackedShapePerm(tensor::PackOp packOp) {
     return tiledIdx;
   };
   SmallVector<int64_t> perm;
-  for (int i = 0; i < packOp.getDestRank(); i++)
+  for (size_t i = 0; i < packOp.getDestRank(); i++)
     perm.push_back(packedIdxToTiledIdx(i));
   return perm;
 }

 /// Given a tensor::PackOp, return the "tiled" `dest` shape as described
 /// above in `getTiledShapeToPackedShapePerm`.
-static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp) {
+static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
+                                              ArrayRef<int64_t> destShape) {
   auto perm = getTiledShapeToPackedShapePerm(packOp);
-  auto destShape = packOp.getDestType().getShape();
   return applyPermutation(destShape, invertPermutationVector(perm));
 }

@@ -1556,7 +1556,9 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
                         inputShape, padValue);

   // Create ShapeCastOp
-  auto tiledPackType = VectorType::get(getTiledPackShape(packOp),
+  SmallVector<int64_t> destShape(inputVectorSizes);
+  destShape.append(innerTiles.begin(), innerTiles.end());
+  auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
                                        packOp.getDestType().getElementType());
   auto shapeCastOp = rewriter.create<vector::ShapeCastOp>(
       loc, tiledPackType, maskedOp->getResult(0));
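
The functional change is small but worth spelling out: getTiledPackShape used to read the destination shape from the op's result type, which for a dynamically shaped dest such as tensor<?x?x16x2xf32> yields dynamic dimensions, so no static shape-cast target could be formed. The caller now assembles a fully static shape from the user-supplied vector sizes plus the inner tiles and threads it through. Below is a minimal, self-contained C++ sketch of that plumbing; it uses plain std::vector stand-ins rather than the MLIR API, and the concrete values and the tiled-to-packed permutation [0, 2, 3, 1] are assumptions taken from the new test.

#include <cstdint>
#include <iostream>
#include <vector>

using Shape = std::vector<int64_t>;

// result[i] = shape[perm[i]], mirroring MLIR's applyPermutation semantics.
static Shape applyPermutation(const Shape &shape, const Shape &perm) {
  Shape result(perm.size());
  for (size_t i = 0; i < perm.size(); ++i)
    result[i] = shape[perm[i]];
  return result;
}

// inv[perm[i]] = i, mirroring invertPermutationVector.
static Shape invertPermutationVector(const Shape &perm) {
  Shape inv(perm.size());
  for (size_t i = 0; i < perm.size(); ++i)
    inv[perm[i]] = static_cast<int64_t>(i);
  return inv;
}

int main() {
  // With a tensor<?x?x16x2xf32> dest, reading the shape off the result
  // type would give [?, ?, 16, 2]. Instead, build a static dest shape
  // from the vector sizes plus the inner tiles, as the patch does:
  Shape vectorSizes = {4, 1}; // vector_sizes [4, 1] from the test
  Shape innerTiles = {16, 2}; // inner_tiles = [16, 2] from the test
  Shape destShape(vectorSizes);
  destShape.insert(destShape.end(), innerTiles.begin(), innerTiles.end());
  // destShape = [4, 1, 16, 2]

  Shape perm = {0, 2, 3, 1}; // assumed tiled-to-packed permutation
  Shape tiledShape = applyPermutation(destShape, invertPermutationVector(perm));
  for (int64_t d : tiledShape)
    std::cout << d << ' '; // prints: 4 2 1 16
  std::cout << '\n';
}

Run as-is it prints 4 2 1 16, i.e. the vector<4x2x1x16xf32> shape that the shape-cast in the new test produces.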

mlir/test/Dialect/Linalg/vectorization.mlir (40 additions & 0 deletions)
@@ -566,6 +566,46 @@ module attributes {transform.with_named_sequence} {

 // -----

+func.func @test_vectorize_dynamic_result_pack(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?x16x2xf32>) -> tensor<?x?x16x2xf32> {
+  %pack = tensor.pack %arg0 inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %arg1 : tensor<?x?xf32> -> tensor<?x?x16x2xf32>
+  return %pack : tensor<?x?x16x2xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [4, 1] : !transform.any_op
+    transform.yield
+  }
+}
+// CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[d0:.*]] = tensor.dim {{.*}} %[[c0]] : tensor<?x?x16x2xf32>
+// CHECK-DAG: %[[d1:.*]] = tensor.dim {{.*}} %[[c1]] : tensor<?x?x16x2xf32>
+// CHECK-DAG: %[[c0_0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[c1_0:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[d0_0:.*]] = tensor.dim {{.*}} %[[c0_0]] : tensor<?x?xf32>
+// CHECK-DAG: %[[d1_0:.*]] = tensor.dim {{.*}} %[[c1_0]] : tensor<?x?xf32>
+// CHECK: %[[mask:.*]] = vector.create_mask %[[d0_0]], %[[d1_0]] : vector<8x16xi1>
+// CHECK-DAG: %[[c0_1:.*]] = arith.constant 0 : index
+// CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
+// CHECK-SAME: vector.transfer_read %{{.*}}[%[[c0_1]], %[[c0_1]]], %[[cst]]
+// CHECK-SAME: {in_bounds = [true, true]} : tensor<?x?xf32>, vector<8x16xf32>
+// CHECK-SAME: } : vector<8x16xi1> -> vector<8x16xf32>
+// CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[masked_read]] : vector<8x16xf32> to vector<4x2x1x16xf32>
+// CHECK: %[[transpose:.*]] = vector.transpose %[[shape_cast]], [0, 2, 3, 1] : vector<4x2x1x16xf32> to vector<4x1x16x2xf32>
+// CHECK-DAG: %[[c0_2:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
+// CHECK-DAG: %[[c2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[empty:.*]] = tensor.empty(%[[d0]], %[[d1]]) : tensor<?x?x16x2xf32>
+// CHECK: %[[mask_0:.*]] = vector.create_mask %[[d0]], %[[d1]], %[[c16]], %[[c2]] : vector<4x1x16x2xi1>
+// CHECK: %[[masked_write:.*]] = vector.mask %[[mask_0]] {
+// CHECK-SAME: vector.transfer_write %[[transpose]], %[[empty]][%[[c0_2]], %[[c0_2]], %[[c0_2]], %[[c0_2]]]
+// CHECK-SAME: {in_bounds = [true, true, true, true]} : vector<4x1x16x2xf32>, tensor<?x?x16x2xf32>
+// CHECK: return %[[masked_write]] : tensor<?x?x16x2xf32>
+
+// -----
+
 func.func @matmul(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
   linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
                 outs(%C: memref<?x?xf32>)
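
The shapes the CHECK lines expect follow from a little arithmetic: with inner_dims_pos = [1, 0] and inner_tiles = [16, 2], source dim 0 is tiled by 2 and source dim 1 by 16, so with vector_sizes [4, 1] the masked read of the tensor<?x?xf32> source is vector<8x16xf32> (4·2 by 1·16); that value is then shape-cast to 4x2x1x16 and transposed to the dest layout 4x1x16x2. A small sketch of that computation, assuming the vectorizer scales each source dimension by its tile factor (consistent with the shapes above):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Assumed inputs, copied from the test case above.
  std::vector<int64_t> vectorSizes = {4, 1};  // vector_sizes [4, 1]
  std::vector<int64_t> innerDimsPos = {1, 0}; // inner_dims_pos = [1, 0]
  std::vector<int64_t> innerTiles = {16, 2};  // inner_tiles = [16, 2]

  // Each source dim reads vectorSizes[d] * tile-factor(d) elements.
  std::vector<int64_t> readShape(vectorSizes);
  for (size_t k = 0; k < innerDimsPos.size(); ++k)
    readShape[innerDimsPos[k]] *= innerTiles[k];

  for (int64_t d : readShape)
    std::cout << d << ' '; // prints: 8 16 -> vector<8x16xf32> in the CHECKs
  std::cout << '\n';
}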
