Skip to content

Commit 8e7d415

Browse files
committed
[mlir] Vectorize unpack op given no vector sizes
Enables vectorization of the unpack op when no vector sizes are provided. In that case, the vector sizes are determined by the shape of the result tensor.
1 parent 06eedff commit 8e7d415

File tree

2 files changed

+43
-3
lines changed

2 files changed

+43
-3
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1597,6 +1597,16 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
15971597

15981598
RankedTensorType unpackTensorType = unpackOp.getSourceType();
15991599

1600+
// If the input vector sizes are not provided, then the vector sizes are
// determined by the result tensor shape; in that case, we update the
// inBounds attribute instead of masking.
1603+
bool doMasking = true;
1604+
if (inputVectorSizes.empty()) {
1605+
ArrayRef<int64_t> resultTensorShape = unpackOp.getDestType().getShape();
1606+
inputVectorSizes = resultTensorShape.take_front(unpackOp.getSourceRank());
1607+
doMasking = false;
1608+
}
1609+
16001610
ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
16011611
ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
16021612

@@ -1651,7 +1661,8 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
16511661
// to shape of source, then a mask is necessary.
16521662
Value readResult = createReadOrMaskedRead(
16531663
rewriter, loc, unpackOp.getSource(),
1654-
ArrayRef<int64_t>(readMaskShape.begin(), readMaskShape.end()), padValue);
1664+
ArrayRef<int64_t>(readMaskShape.begin(), readMaskShape.end()), padValue,
1665+
doMasking);
16551666

16561667
PackingMetadata packMetadata;
16571668
SmallVector<int64_t> lastDimToInsertPosPerm =
@@ -1827,8 +1838,14 @@ vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
18271838
LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
18281839
return failure();
18291840
}
1830-
llvm::ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
1831-
if (!inputVectorSizes.empty() &&
1841+
ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
1842+
bool satisfyEmptyCond = true;
1843+
if (inputVectorSizes.empty()) {
1844+
if (!unpackOp.getDestType().hasStaticShape() ||
1845+
!unpackOp.getSourceType().hasStaticShape())
1846+
satisfyEmptyCond = false;
1847+
}
1848+
if (!satisfyEmptyCond &&
18321849
failed(isValidMaskedInputVector(resultShape, inputVectorSizes)))
18331850
return failure();
18341851

mlir/test/Dialect/Linalg/vectorization.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -985,3 +985,26 @@ module attributes {transform.with_named_sequence} {
985985
transform.yield
986986
}
987987
}
988+
989+
// -----
990+
991+
func.func @test_vectorize_unpack_no_vector_sizes(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
992+
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
993+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
994+
// CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
995+
// CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [0, 2, 1, 3] : vector<8x8x32x16xf32> to vector<8x32x8x16xf32>
996+
// CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<8x32x8x16xf32> to vector<256x128xf32>
997+
// CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32>
998+
// CHECK: %[[C00:.*]] = arith.constant 0 : index
999+
// CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<256x128xf32>, tensor<256x128xf32>
1000+
// CHECK: return %[[WRIT]] : tensor<256x128xf32>
1001+
%0 = tensor.unpack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
1002+
return %0 : tensor<256x128xf32>
1003+
}
1004+
module attributes {transform.with_named_sequence} {
1005+
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
1006+
%0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
1007+
transform.structured.vectorize %0 : !transform.any_op
1008+
transform.yield
1009+
}
1010+
}

0 commit comments

Comments (0)