Skip to content

Commit 5bc4819

Browse files
committed
[mlir] Vectorize unpack op given no vector sizes
Enables vectorization of the unpack op when no vector sizes are provided. In that case, the vector sizes are determined from the shape of the result tensor.
1 parent 03b1a0c commit 5bc4819

File tree

2 files changed

+43
-3
lines changed

2 files changed

+43
-3
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1558,6 +1558,16 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
15581558

15591559
RankedTensorType unpackTensorType = unpackOp.getSourceType();
15601560

1561+
// If the input vector sizes are not provided, then the vector sizes are
1562+
// determined by the result tensor shape. In case the vector sizes aren't
1563+
// provided, we update the inBounds attribute instead of masking.
1564+
bool useInBoundsInsteadOfMasking = true;
1565+
if (inputVectorSizes.empty()) {
1566+
ArrayRef<int64_t> resultTensorShape = unpackOp.getDestType().getShape();
1567+
inputVectorSizes = resultTensorShape.take_front(unpackOp.getSourceRank());
1568+
useInBoundsInsteadOfMasking = false;
1569+
}
1570+
15611571
ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
15621572
ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
15631573

@@ -1612,7 +1622,8 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
16121622
// to shape of source, then a mask is necessary.
16131623
Value readResult = vector::createReadOrMaskedRead(
16141624
rewriter, loc, unpackOp.getSource(),
1615-
ArrayRef<int64_t>(readMaskShape.begin(), readMaskShape.end()), padValue);
1625+
ArrayRef<int64_t>(readMaskShape.begin(), readMaskShape.end()), padValue,
1626+
doMasking);
16161627

16171628
PackingMetadata packMetadata;
16181629
SmallVector<int64_t> lastDimToInsertPosPerm =
@@ -1753,8 +1764,14 @@ vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
17531764
LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
17541765
return failure();
17551766
}
1756-
llvm::ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
1757-
if (!inputVectorSizes.empty() &&
1767+
ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
1768+
bool satisfyEmptyCond = true;
1769+
if (inputVectorSizes.empty()) {
1770+
if (!unpackOp.getDestType().hasStaticShape() ||
1771+
!unpackOp.getSourceType().hasStaticShape())
1772+
satisfyEmptyCond = false;
1773+
}
1774+
if (!satisfyEmptyCond &&
17581775
failed(vector::isValidMaskedInputVector(resultShape, inputVectorSizes)))
17591776
return failure();
17601777

mlir/test/Dialect/Linalg/vectorization.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -985,3 +985,26 @@ module attributes {transform.with_named_sequence} {
985985
transform.yield
986986
}
987987
}
988+
989+
// -----
990+
991+
func.func @test_vectorize_unpack_no_vector_sizes(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
992+
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
993+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
994+
// CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
995+
// CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [0, 2, 1, 3] : vector<8x8x32x16xf32> to vector<8x32x8x16xf32>
996+
// CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<8x32x8x16xf32> to vector<256x128xf32>
997+
// CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32>
998+
// CHECK: %[[C00:.*]] = arith.constant 0 : index
999+
// CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<256x128xf32>, tensor<256x128xf32>
1000+
// CHECK: return %[[WRIT]] : tensor<256x128xf32>
1001+
%0 = tensor.unpack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
1002+
return %0 : tensor<256x128xf32>
1003+
}
1004+
module attributes {transform.with_named_sequence} {
1005+
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
1006+
%0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
1007+
transform.structured.vectorize %0 : !transform.any_op
1008+
transform.yield
1009+
}
1010+
}

0 commit comments

Comments (0)