Skip to content

Commit 1edfb4b

Browse files
[mlir][Linalg] Allow linalg.copy to be vectorized with masking
Differential Revision: https://reviews.llvm.org/D148095
1 parent 380b6a1 commit 1edfb4b

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1282,7 +1282,7 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
12821282

12831283
static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op) {
12841284
// TODO: Masking only supports dynamic generic ops for now.
1285-
if (!isa<linalg::GenericOp, linalg::FillOp>(op))
1285+
if (!isa<linalg::GenericOp, linalg::FillOp, linalg::CopyOp>(op))
12861286
return failure();
12871287

12881288
LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");

mlir/test/Dialect/Linalg/vectorization.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2737,3 +2737,23 @@ transform.sequence failures(propagate) {
27372737
transform.structured.masked_vectorize %0 vector_sizes [8, 16]
27382738
}
27392739

2740+
// -----
2741+
2742+
// CHECK-LABEL: func @test_masked_vectorize_linalg_copy
2743+
func.func @test_masked_vectorize_linalg_copy(%A : memref<?x?xf32>, %B : memref<?x?xf32>) {
2744+
// CHECK: %[[c0:.*]] = arith.constant 0 : index
2745+
// CHECK: %[[d0:.*]] = memref.dim %{{.*}}, %[[c0]] : memref<?x?xf32>
2746+
// CHECK: %[[c1:.*]] = arith.constant 1 : index
2747+
// CHECK: %[[d1:.*]] = memref.dim %{{.*}}, %[[c1]] : memref<?x?xf32>
2748+
// CHECK: %[[mask:.*]] = vector.create_mask %[[d0]], %[[d1]] : vector<2x4xi1>
2749+
// CHECK: vector.mask %[[mask]] {{.*}} vector.transfer_read %{{.*}} {in_bounds = [true, true]} : memref<?x?xf32>, vector<2x4xf32> } : vector<2x4xi1> -> vector<2x4xf32>
2750+
// CHECK: vector.mask %[[mask]] {{.*}} vector.transfer_write %{{.*}} {in_bounds = [true, true]} : vector<2x4xf32>, memref<?x?xf32> } : vector<2x4xi1>
2751+
linalg.copy ins(%A : memref<?x?xf32>) outs(%B : memref<?x?xf32>)
2752+
return
2753+
}
2754+
2755+
transform.sequence failures(propagate) {
2756+
^bb1(%arg1: !pdl.operation):
2757+
%0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!pdl.operation) -> !pdl.operation
2758+
transform.structured.masked_vectorize %0 vector_sizes [2, 4]
2759+
}

0 commit comments

Comments
 (0)