Commit 2755c69

[mlir][linalg] Vectorize unpack op without masking (llvm#89067)
Enables vectorization of the unpack op when no vector sizes are provided; in that case the vector sizes are determined by the result's shape.
1 parent 6c4dedd commit 2755c69
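
As a worked illustration of how the missing vector sizes are inferred, here is a minimal standalone sketch in plain C++ (not the MLIR API; the helper name inferWriteVectorSizes and the use of std::vector are illustrative assumptions). It follows the steps added in this commit: take the leading destRank dims of the source shape, apply outer_dims_perm if present, then scale the positions named by inner_dims_pos by the corresponding inner_tiles.

// Standalone sketch (plain C++, not the MLIR API) of inferring the
// write-vector sizes when the caller passes no vector sizes.
#include <cstdint>
#include <iostream>
#include <vector>

static std::vector<int64_t>
inferWriteVectorSizes(const std::vector<int64_t> &sourceShape, int64_t destRank,
                      const std::vector<int64_t> &outerDimsPerm,
                      const std::vector<int64_t> &innerDimsPos,
                      const std::vector<int64_t> &innerTiles) {
  // 1. vectorSizes = sourceShape.take_front(destRank)
  std::vector<int64_t> sizes(sourceShape.begin(),
                             sourceShape.begin() + destRank);
  // 2. If outer_dims_perm is present, permute the sizes accordingly.
  if (!outerDimsPerm.empty()) {
    std::vector<int64_t> permuted(sizes.size());
    for (size_t i = 0; i < sizes.size(); ++i)
      permuted[i] = sizes[outerDimsPerm[i]];
    sizes = permuted;
  }
  // 3. Scale each tiled dimension by its inner tile size.
  for (size_t i = 0; i < innerDimsPos.size(); ++i)
    sizes[innerDimsPos[i]] *= innerTiles[i];
  return sizes;
}

int main() {
  // Matches the first new test below: source tensor<8x8x32x16> with
  // inner_dims_pos = [0, 1], inner_tiles = [32, 16], rank-2 dest.
  auto sizes = inferWriteVectorSizes({8, 8, 32, 16}, /*destRank=*/2,
                                     /*outerDimsPerm=*/{},
                                     /*innerDimsPos=*/{0, 1},
                                     /*innerTiles=*/{32, 16});
  for (int64_t s : sizes)
    std::cout << s << ' '; // prints: 256 128
  std::cout << '\n';
}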

2 files changed: +145 −32 lines

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 75 additions & 32 deletions
@@ -1414,27 +1414,39 @@ static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
 /// create an empty destination tensor and create a TransferWriteOp from the
 /// input to the empty tensor. If the destination shape is not the same as the
 /// inputVectorSizes for the first rank(inputVectorSizes) dims, then create a
-/// mask for the write.
+/// mask for the write. If `useInBoundsInsteadOfMasking` is set, then update the
+/// inBounds attribute of the transfer write op instead of masking.
 static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
                                            Value input,
                                            SmallVector<OpFoldResult> destSizes,
-                                           ArrayRef<int64_t> inputVectorSizes) {
+                                           ArrayRef<int64_t> inputVectorSizes,
+                                           bool useInBoundsInsteadOfMasking) {
+
   auto inputType = cast<VectorType>(input.getType());
   Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
                                                inputType.getElementType());
   int64_t rank = cast<ShapedType>(dest.getType()).getRank();
   auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  auto destShape = cast<ShapedType>(dest.getType()).getShape();
+  SmallVector<bool> inBoundsVal(rank, true);
+  if (useInBoundsInsteadOfMasking) {
+    // Update the inBounds attribute.
+    for (unsigned i = 0; i < rank; i++)
+      inBoundsVal[i] = (destShape[i] == inputVectorSizes[i]) &&
+                       !ShapedType::isDynamic(destShape[i]);
+  }
   Operation *write = builder.create<vector::TransferWriteOp>(
       loc,
       /*vector=*/input,
       /*source=*/dest,
       /*indices=*/SmallVector<Value>(rank, zero),
-      /*inBounds=*/SmallVector<bool>(rank, true));
-  auto destShape = cast<ShapedType>(dest.getType()).getShape();
+      /*inBounds=*/inBoundsVal);
   assert(llvm::none_of(
              destShape.drop_front(inputVectorSizes.size()),
              [](int64_t size) { return size == ShapedType::kDynamic; }) &&
          "Only dims aligned with inputVectorSizes may be dynamic");
+  if (useInBoundsInsteadOfMasking)
+    return write;
   bool needMaskForWrite = !llvm::equal(
       inputVectorSizes, destShape.take_front(inputVectorSizes.size()));
   if (needMaskForWrite) {
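
To make the new in-bounds path concrete, here is a minimal standalone sketch (plain C++, not the MLIR builder API; computeInBounds and the kDynamic stand-in are illustrative assumptions) of the per-dimension in_bounds values computed when useInBoundsInsteadOfMasking is set: a dimension is in-bounds only if it is static and equal to the corresponding vector size.

// Standalone sketch of the in_bounds computation used instead of masking.
#include <cstdint>
#include <iostream>
#include <limits>
#include <vector>

constexpr int64_t kDynamic =
    std::numeric_limits<int64_t>::min(); // stand-in for ShapedType::kDynamic

static std::vector<bool>
computeInBounds(const std::vector<int64_t> &destShape,
                const std::vector<int64_t> &vectorSizes) {
  std::vector<bool> inBounds(destShape.size(), true);
  for (size_t i = 0; i < destShape.size(); ++i)
    inBounds[i] = destShape[i] == vectorSizes[i] && destShape[i] != kDynamic;
  return inBounds;
}

int main() {
  // Dest tensor<64x127> written with a vector<64x128>.
  for (bool b : computeInBounds({64, 127}, {64, 128}))
    std::cout << (b ? "true" : "false") << ' '; // prints: true false
  std::cout << '\n';
}

This reproduces the in_bounds = [true, false] attribute checked by the slice-output test added below.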
@@ -1535,9 +1547,9 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
       loc, shapeCastOp.getResult(), destPermutation);
 
   // Create TransferWriteOp.
-  Operation *write =
-      createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(),
-                               reifiedReturnShapes[0], inputVectorSizes);
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, transposeOp.getResult(), reifiedReturnShapes[0],
+      inputVectorSizes, /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1547,7 +1559,10 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
 ///   vector::TransposeOp - Transpose the Source tensor
 ///   ShapeCastOp - Reshape the data based on the target.
 ///   vector::TransferWriteOp. - Write the result vector back to the destination
-///   tensor
+///   tensor.
+///   If the vector sizes are not provided:
+///   * the vector sizes are determined by the input operand and attributes,
+///   * update the inBounds attribute instead of masking.
 static LogicalResult
 vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
                           ArrayRef<int64_t> inputVectorSizes,
@@ -1560,40 +1575,65 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
 
   ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
   ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
-
-  SmallVector<int64_t> readMaskShape(inputVectorSizes.begin(),
-                                     inputVectorSizes.end());
-  ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
   ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
+  bool useInBoundsInsteadOfMasking = false;
+  ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
+
+  auto destSize = unpackOp.getDestRank();
+
+  if (!inputVectorSizes.empty())
+    assert(inputVectorSizes.size() == destSize &&
+           "Incorrect number of input vector sizes");
 
-  // ReadMask is the size of tensor used to read and apply mask. It is
+  // vectorSizes is the shape of the vector that will be used to do final
+  // write on the destination tensor. It is set like this: Let's say the
+  // source tensor is rank 'M' and the dest tensor rank 'N', where N <= M.
+  // Thus:
+  // 1. vectorSizes = sourceShape.take_front(N)
+  // 2. if outer_dims_perms is present: do that permutation on vectorSizes.
+  // 3. multiply all the locations in vectorSize pointed by innerDimPos by the
+  //    innerTiles attribute value.
+  SmallVector<int64_t> vectorSizes(inputVectorSizes);
+  if (vectorSizes.empty()) {
+    llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
+    if (!outerDimsPerm.empty())
+      applyPermutationToVector(vectorSizes, outerDimsPerm);
+    for (auto [i, pos] : llvm::enumerate(innerDimPos))
+      vectorSizes[pos] *= innerTiles[i];
+
+    useInBoundsInsteadOfMasking = true;
+  }
+
+  // readVectorSizes is the size of tensor used to read and apply mask. It is
   // set like this: Let's say the vectorSize (VS) array is size 'N' and
   // the sourceShape(SS) is 'M' where M >= N and InnerTileSizes (IT) of
   // size M-N
   // Thus:
-  // - initially: ReadMaskShape = vectorInputSizes
+  // - initially: readVectorSizes = vectorInputSizes
  // - Divide all the readMaskShape locations pointed by innerDimPos
  //   by the innerTileSize attribute value.
-  // - if outer_dims_perms is present: do that permutation on readMaskShape.
+  // - if outer_dims_perms is present: do that permutation on readVectorSizes.
   // - Append the remaining shape from SS
   // E.g. let's say let's say unpackTensorType.getShape() = <8x8x32x16>
   // inner Dim Pos = [0, 1] and Inner Tiles = [32, 16], vector_sizes are [512,
   // 128] and outer_dims_perm is [1, 0] then read shape is:
-  //   ReadMaskShape(initial): [512, 128]
+  //   ReadVectorSizes(initial): [512, 128]
   //   Final Value(after innerDim Adjustment): [512/32, 128/16]
   //                                           = [16, 8]
   //   After applying outer_dims_perm: [8, 16]
   //   After appending the rest of the sourceShape: [8, 16, 32, 16]
 
+  SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());
+
   for (auto [index, size] : enumerate(innerTiles)) {
-    readMaskShape[innerDimPos[index]] =
-        llvm::divideCeil(readMaskShape[innerDimPos[index]], size);
+    readVectorSizes[innerDimPos[index]] =
+        llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
   }
   if (!outerDimsPerm.empty()) {
-    applyPermutationToVector(readMaskShape, outerDimsPerm);
+    applyPermutationToVector(readVectorSizes, outerDimsPerm);
   }
-  readMaskShape.append(sourceShape.begin() + inputVectorSizes.size(),
-                       sourceShape.end());
+  readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
+                         sourceShape.end());
 
   ReifiedRankedShapedTypeDims reifiedRetShapes;
   LogicalResult status =
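
The read-side sizes can be reproduced the same way. The sketch below (plain C++, not the MLIR API; deriveReadVectorSizes is a hypothetical helper) follows the steps in the readVectorSizes comment above and reproduces its [512, 128] to [8, 16, 32, 16] example.

// Standalone sketch: derive the read-vector sizes from the write-vector sizes.
#include <cstdint>
#include <iostream>
#include <vector>

static std::vector<int64_t>
deriveReadVectorSizes(std::vector<int64_t> vectorSizes,
                      const std::vector<int64_t> &sourceShape,
                      const std::vector<int64_t> &outerDimsPerm,
                      const std::vector<int64_t> &innerDimsPos,
                      const std::vector<int64_t> &innerTiles) {
  // Divide the positions named by inner_dims_pos by the inner tile sizes.
  for (size_t i = 0; i < innerTiles.size(); ++i) {
    int64_t &s = vectorSizes[innerDimsPos[i]];
    s = (s + innerTiles[i] - 1) / innerTiles[i]; // divideCeil
  }
  // Apply outer_dims_perm, if present.
  if (!outerDimsPerm.empty()) {
    std::vector<int64_t> permuted(vectorSizes.size());
    for (size_t i = 0; i < vectorSizes.size(); ++i)
      permuted[i] = vectorSizes[outerDimsPerm[i]];
    vectorSizes = permuted;
  }
  // Append the remaining dims of the source shape.
  vectorSizes.insert(vectorSizes.end(),
                     sourceShape.begin() + vectorSizes.size(),
                     sourceShape.end());
  return vectorSizes;
}

int main() {
  // Source <8x8x32x16>, inner_dims_pos = [0, 1], inner_tiles = [32, 16],
  // vector sizes [512, 128], outer_dims_perm = [1, 0]:
  // [512/32, 128/16] = [16, 8] -> permuted [8, 16] -> appended [8, 16, 32, 16].
  for (int64_t s : deriveReadVectorSizes({512, 128}, {8, 8, 32, 16}, {1, 0},
                                         {0, 1}, {32, 16}))
    std::cout << s << ' '; // prints: 8 16 32 16
  std::cout << '\n';
}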
@@ -1611,8 +1651,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
   // Read result, mask if necessary. If transferReadOp shape is not equal
   // to shape of source, then a mask is necessary.
   Value readResult = vector::createReadOrMaskedRead(
-      rewriter, loc, unpackOp.getSource(),
-      ArrayRef<int64_t>(readMaskShape.begin(), readMaskShape.end()), padValue,
+      rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
       /*useInBoundsInsteadOfMasking=*/false);
 
   PackingMetadata packMetadata;
@@ -1636,15 +1675,15 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
   vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
       loc, vecCollapsedType, transposeOp->getResult(0));
 
-  // WriteMaskShape had to match the shapecast shape for dynamic sizes,
+  // writeVectorSizes had to match the shapecast shape for dynamic sizes,
   // otherwise the validator complains that the mask size is invalid.
-  SmallVector<int64_t> writeMaskShape(
+  SmallVector<int64_t> writeVectorSizes(
       unpackOp.getDestType().hasStaticShape()
-          ? inputVectorSizes
+          ? vectorSizes
           : shapeCastOp.getResultVectorType().getShape());
-  Operation *write =
-      createWriteOrMaskedWrite(rewriter, loc, shapeCastOp.getResult(),
-                               reifiedRetShapes[0], writeMaskShape);
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, shapeCastOp.getResult(), reifiedRetShapes[0],
+      writeVectorSizes, useInBoundsInsteadOfMasking);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1673,7 +1712,8 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
       rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
       /*useInBoundsInsteadOfMasking=*/false);
   Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes);
+      rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes,
+      /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1755,8 +1795,11 @@ vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
     LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
     return failure();
   }
-  llvm::ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
-  if (!inputVectorSizes.empty() &&
+  ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
+  bool satisfyEmptyCond = inputVectorSizes.empty() &&
+                          unpackOp.getDestType().hasStaticShape() &&
+                          unpackOp.getSourceType().hasStaticShape();
+  if (!satisfyEmptyCond &&
       failed(vector::isValidMaskedInputVector(resultShape, inputVectorSizes)))
     return failure();
 
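A minimal sketch of the relaxed precondition (plain C++, not the MLIR API; canSkipSizeValidation and the kDynamic stand-in are illustrative assumptions): with no explicit vector sizes, vectorization is only attempted when both the unpack source and destination shapes are fully static; otherwise the caller must still supply sizes that pass the existing validation.

// Standalone sketch of the satisfyEmptyCond check added above.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <limits>
#include <vector>

constexpr int64_t kDynamic =
    std::numeric_limits<int64_t>::min(); // stand-in for ShapedType::kDynamic

static bool isFullyStatic(const std::vector<int64_t> &shape) {
  return std::none_of(shape.begin(), shape.end(),
                      [](int64_t d) { return d == kDynamic; });
}

// Mirrors satisfyEmptyCond: when it holds, the masked-input-vector
// validation is skipped entirely.
static bool canSkipSizeValidation(const std::vector<int64_t> &sourceShape,
                                  const std::vector<int64_t> &destShape,
                                  const std::vector<int64_t> &inputVectorSizes) {
  return inputVectorSizes.empty() && isFullyStatic(destShape) &&
         isFullyStatic(sourceShape);
}

int main() {
  // Static source and dest, no sizes: sizes can be inferred from the shapes.
  std::cout << canSkipSizeValidation({8, 8, 32, 16}, {256, 128}, {}) << '\n'; // 1
  // Dynamic source dim, no sizes: vector sizes must be provided instead.
  std::cout << canSkipSizeValidation({kDynamic, 8, 32, 16}, {256, 128}, {})
            << '\n'; // 0
}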
mlir/test/Dialect/Linalg/vectorization.mlir

Lines changed: 70 additions & 0 deletions
@@ -985,3 +985,73 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+func.func @test_vectorize_unpack_no_vector_sizes(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
+  // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
+  // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [0, 2, 1, 3] : vector<8x8x32x16xf32> to vector<8x32x8x16xf32>
+  // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<8x32x8x16xf32> to vector<256x128xf32>
+  // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32>
+  // CHECK: %[[C00:.*]] = arith.constant 0 : index
+  // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<256x128xf32>, tensor<256x128xf32>
+  // CHECK: return %[[WRIT]] : tensor<256x128xf32>
+  %0 = tensor.unpack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
+  return %0 : tensor<256x128xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @test_vectorize_unpack_no_vector_sizes_slice_output(%source: tensor<8x4x16x16xf32>, %dest: tensor<64x127xf32>) -> tensor<64x127xf32> {
+  // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<8x4x16x16xf32>, vector<8x4x16x16xf32>
+  // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [1, 2, 0, 3] : vector<8x4x16x16xf32> to vector<4x16x8x16xf32>
+  // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<4x16x8x16xf32> to vector<64x128xf32>
+  // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<64x127xf32>
+  // CHECK: %[[C00:.*]] = arith.constant 0 : index
+  // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], %[[EMPT]]{{\[}}%[[C00]], %[[C00]]]
+  // CHECK-SAME: {in_bounds = [true, false]} : vector<64x128xf32>, tensor<64x127xf32>
+  // CHECK: return %[[WRIT]] : tensor<64x127xf32>
+  %0 = tensor.unpack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %dest : tensor<8x4x16x16xf32> -> tensor<64x127xf32>
+  return %0 : tensor<64x127xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @test_vectorize_unpack_no_vector_sizes_permute(%source: tensor<4x7x4xf32>, %dest: tensor<7x16xf32>) -> tensor<7x16xf32> {
+  %0 = tensor.unpack %source outer_dims_perm=[1, 0] inner_dims_pos = [1] inner_tiles = [4] into %dest : tensor<4x7x4xf32> -> tensor<7x16xf32>
+  return %0 : tensor<7x16xf32>
+}
+// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<4x7x4xf32>, vector<4x7x4xf32>
+// CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [1, 0, 2] : vector<4x7x4xf32> to vector<7x4x4xf32>
+// CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<7x4x4xf32> to vector<7x16xf32>
+// CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<7x16xf32>
+// CHECK: %[[C00:.*]] = arith.constant 0 : index
+// CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<7x16xf32>, tensor<7x16xf32>
+// CHECK: return %[[WRIT]] : tensor<7x16xf32>
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  }
+}

0 commit comments

Comments
 (0)