Skip to content

Commit eb5f6ee

Browse files
committed
[mlir] Vectorize unpack op given no vector sizes
Enables vectorization of unpack op in the case of unknown vector size. The vector sizes are determined by the result shape.
1 parent 2b2c66c commit eb5f6ee

File tree

2 files changed

+124
-16
lines changed

2 files changed

+124
-16
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 55 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1414,27 +1414,39 @@ static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
14141414
/// create an empty destination tensor and create a TransferWriteOp from the
14151415
/// input to the empty tensor. If the destination shape is not the same as the
14161416
/// inputVectorSizes for the first rank(inputVectorSizes) dims, then create a
1417-
/// mask for the write.
1417+
/// mask for the write. If `useInBoundsInsteadOfMasking` is set, then update the
1418+
/// inBounds attribute of the transfer write op instead of masking.
14181419
static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
14191420
Value input,
14201421
SmallVector<OpFoldResult> destSizes,
1421-
ArrayRef<int64_t> inputVectorSizes) {
1422+
ArrayRef<int64_t> inputVectorSizes,
1423+
bool useInBoundsInsteadOfMasking) {
1424+
14221425
auto inputType = cast<VectorType>(input.getType());
14231426
Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
14241427
inputType.getElementType());
14251428
int64_t rank = cast<ShapedType>(dest.getType()).getRank();
14261429
auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
1430+
auto destShape = cast<ShapedType>(dest.getType()).getShape();
1431+
SmallVector<bool> inBoundsVal(rank, true);
1432+
if (useInBoundsInsteadOfMasking) {
1433+
// Update the inBounds attribute.
1434+
for (unsigned i = 0; i < rank; i++)
1435+
inBoundsVal[i] = (destShape[i] == inputVectorSizes[i]) &&
1436+
!ShapedType::isDynamic(destShape[i]);
1437+
}
14271438
Operation *write = builder.create<vector::TransferWriteOp>(
14281439
loc,
14291440
/*vector=*/input,
14301441
/*source=*/dest,
14311442
/*indices=*/SmallVector<Value>(rank, zero),
1432-
/*inBounds=*/SmallVector<bool>(rank, true));
1433-
auto destShape = cast<ShapedType>(dest.getType()).getShape();
1443+
/*inBounds=*/inBoundsVal);
14341444
assert(llvm::none_of(
14351445
destShape.drop_front(inputVectorSizes.size()),
14361446
[](int64_t size) { return size == ShapedType::kDynamic; }) &&
14371447
"Only dims aligned with inputVectorSizes may be dynamic");
1448+
if (useInBoundsInsteadOfMasking)
1449+
return write;
14381450
bool needMaskForWrite = !llvm::equal(
14391451
inputVectorSizes, destShape.take_front(inputVectorSizes.size()));
14401452
if (needMaskForWrite) {
@@ -1535,9 +1547,9 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
15351547
loc, shapeCastOp.getResult(), destPermutation);
15361548

15371549
// Create TransferWriteOp.
1538-
Operation *write =
1539-
createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(),
1540-
reifiedReturnShapes[0], inputVectorSizes);
1550+
Operation *write = createWriteOrMaskedWrite(
1551+
rewriter, loc, transposeOp.getResult(), reifiedReturnShapes[0],
1552+
inputVectorSizes, /*useInBoundsInsteadOfMasking=*/false);
15411553
newResults.push_back(write->getResult(0));
15421554
return success();
15431555
}
@@ -1547,7 +1559,9 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
15471559
/// vector::TransposeOp - Transpose the Source tensor
15481560
/// ShapeCastOp - Reshape the data based on the target.
15491561
/// vector::TransferWriteOp. - Write the result vector back to the destination
1550-
/// tensor
1562+
/// tensor. If the vector sizes are not provided, then the vector sizes are
1563+
/// determined by the result tensor shape. In case the vector sizes aren't
1564+
/// provided, we update the inBounds attribute instead of masking.
15511565
static LogicalResult
15521566
vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
15531567
ArrayRef<int64_t> inputVectorSizes,
@@ -1560,11 +1574,32 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
15601574

15611575
ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
15621576
ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
1577+
ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
1578+
bool useInBoundsInsteadOfMasking = false;
1579+
ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
1580+
1581+
auto destSize = unpackOp.getDestRank();
1582+
1583+
// initVectorShape is the shape of the vector that will be read from the
1584+
// source tensor. It is set like this: Let's say the sourceShape is 'M' and
1585+
// the vectorSize (VS) array is size 'N' where N <= M. Thus:
1586+
// - initVectorShape = sourceShape.take_front(N)
1587+
// - if outer_dims_perms is present: do that permutation on initVectorShape.
1588+
// - Multiply all the locations pointed by innerDimPos by the innerTileSize
1589+
// attribute value.
1590+
SmallVector<int64_t> initVectorShape{sourceShape.take_front(destSize)};
1591+
if (inputVectorSizes.empty()) {
1592+
if (!outerDimsPerm.empty())
1593+
applyPermutationToVector(initVectorShape, outerDimsPerm);
1594+
for (auto [i, pos] : llvm::enumerate(innerDimPos))
1595+
initVectorShape[pos] *= innerTiles[i];
1596+
1597+
inputVectorSizes = initVectorShape;
1598+
useInBoundsInsteadOfMasking = true;
1599+
}
15631600

15641601
SmallVector<int64_t> readMaskShape(inputVectorSizes.begin(),
15651602
inputVectorSizes.end());
1566-
ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
1567-
ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
15681603

15691604
// ReadMask is the size of tensor used to read and apply mask. It is
15701605
// set like this: Let's say the vectorSize (VS) array is size 'N' and
@@ -1642,9 +1677,9 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
16421677
unpackOp.getDestType().hasStaticShape()
16431678
? inputVectorSizes
16441679
: shapeCastOp.getResultVectorType().getShape());
1645-
Operation *write =
1646-
createWriteOrMaskedWrite(rewriter, loc, shapeCastOp.getResult(),
1647-
reifiedRetShapes[0], writeMaskShape);
1680+
Operation *write = createWriteOrMaskedWrite(
1681+
rewriter, loc, shapeCastOp.getResult(), reifiedRetShapes[0],
1682+
writeMaskShape, useInBoundsInsteadOfMasking);
16481683
newResults.push_back(write->getResult(0));
16491684
return success();
16501685
}
@@ -1673,7 +1708,8 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
16731708
rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
16741709
/*useInBoundsInsteadOfMasking=*/false);
16751710
Operation *write = createWriteOrMaskedWrite(
1676-
rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes);
1711+
rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes,
1712+
/*useInBoundsInsteadOfMasking=*/false);
16771713
newResults.push_back(write->getResult(0));
16781714
return success();
16791715
}
@@ -1755,8 +1791,11 @@ vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
17551791
LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
17561792
return failure();
17571793
}
1758-
llvm::ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
1759-
if (!inputVectorSizes.empty() &&
1794+
ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
1795+
bool satisfyEmptyCond = inputVectorSizes.empty() &&
1796+
unpackOp.getDestType().hasStaticShape() &&
1797+
unpackOp.getSourceType().hasStaticShape();
1798+
if (!satisfyEmptyCond &&
17601799
failed(vector::isValidMaskedInputVector(resultShape, inputVectorSizes)))
17611800
return failure();
17621801

mlir/test/Dialect/Linalg/vectorization.mlir

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -985,3 +985,72 @@ module attributes {transform.with_named_sequence} {
985985
transform.yield
986986
}
987987
}
988+
989+
// -----
990+
991+
func.func @test_vectorize_unpack_no_vector_sizes(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
992+
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
993+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
994+
// CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
995+
// CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [0, 2, 1, 3] : vector<8x8x32x16xf32> to vector<8x32x8x16xf32>
996+
// CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<8x32x8x16xf32> to vector<256x128xf32>
997+
// CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32>
998+
// CHECK: %[[C00:.*]] = arith.constant 0 : index
999+
// CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<256x128xf32>, tensor<256x128xf32>
1000+
// CHECK: return %[[WRIT]] : tensor<256x128xf32>
1001+
%0 = tensor.unpack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
1002+
return %0 : tensor<256x128xf32>
1003+
}
1004+
module attributes {transform.with_named_sequence} {
1005+
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
1006+
%0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
1007+
transform.structured.vectorize %0 : !transform.any_op
1008+
transform.yield
1009+
}
1010+
}
1011+
1012+
// -----
1013+
1014+
func.func @test_vectorize_unpack_no_vector_sizes_slice_output(%source: tensor<8x4x16x16xf32>, %dest: tensor<63x127xf32>) -> tensor<63x127xf32> {
1015+
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
1016+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
1017+
// CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<8x4x16x16xf32>, vector<8x4x16x16xf32>
1018+
// CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [1, 2, 0, 3] : vector<8x4x16x16xf32> to vector<4x16x8x16xf32>
1019+
// CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<4x16x8x16xf32> to vector<64x128xf32>
1020+
// CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<63x127xf32>
1021+
// CHECK: %[[C00:.*]] = arith.constant 0 : index
1022+
// CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<64x128xf32>, tensor<63x127xf32>
1023+
// CHECK: return %[[WRIT]] : tensor<63x127xf32>
1024+
%0 = tensor.unpack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %dest : tensor<8x4x16x16xf32> -> tensor<63x127xf32>
1025+
return %0 : tensor<63x127xf32>
1026+
}
1027+
module attributes {transform.with_named_sequence} {
1028+
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
1029+
%0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
1030+
transform.structured.vectorize %0 : !transform.any_op
1031+
transform.yield
1032+
}
1033+
}
1034+
1035+
// -----
1036+
1037+
func.func @test_vectorize_unpack_no_vector_sizes_permute(%source: tensor<4x7x4xf32>, %dest: tensor<7x16xf32>) -> tensor<7x16xf32> {
1038+
%0 = tensor.unpack %source outer_dims_perm=[1, 0] inner_dims_pos = [1] inner_tiles = [4] into %dest : tensor<4x7x4xf32> -> tensor<7x16xf32>
1039+
return %0 : tensor<7x16xf32>
1040+
}
1041+
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
1042+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
1043+
// CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<4x7x4xf32>, vector<4x7x4xf32>
1044+
// CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [1, 0, 2] : vector<4x7x4xf32> to vector<7x4x4xf32>
1045+
// CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<7x4x4xf32> to vector<7x16xf32>
1046+
// CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<7x16xf32>
1047+
// CHECK: %[[C00:.*]] = arith.constant 0 : index
1048+
// CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<7x16xf32>, tensor<7x16xf32>
1049+
// CHECK: return %[[WRIT]] : tensor<7x16xf32>
1050+
module attributes {transform.with_named_sequence} {
1051+
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
1052+
%0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
1053+
transform.structured.vectorize %0 : !transform.any_op
1054+
transform.yield
1055+
}
1056+
}

0 commit comments

Comments
 (0)