Commit 7880b2c

[mlir] Add direct vectorization lowering for tensor.pack ops (#78660)
This PR adds a direct vectorization lowering of `tensor.pack` into `mask(vector.transfer_read)`->`vector.shape_cast`->`vector.transpose`->`vector.transfer_write`.
1 parent 347ab99 commit 7880b2c
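
For illustration, this is the vectorized form for the example documented in the new
`vectorizeAsTensorPackOp` below (a pack with inner_dims_pos = [2, 1] and
inner_tiles = [16, 2] into tensor<32x4x1x16x2xf32>, where the masked read pads the
32x7x16 source up to the 32x8x16 read vector); the SSA names (%mask, %cst, %empty, ...)
are placeholders:

  %load = vector.mask %mask {
    vector.transfer_read %src[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]}
        : tensor<32x7x16xf32>, vector<32x8x16xf32>
  } : vector<32x8x16xi1> -> vector<32x8x16xf32>
  %shape_cast = vector.shape_cast %load
      : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
  %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
      : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
  %write = vector.transfer_write %transpose, %empty[%c0, %c0, %c0, %c0, %c0]
      {in_bounds = [true, true, true, true, true]}
      : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>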

6 files changed: +357 -70 lines

mlir/include/mlir/Dialect/Tensor/Utils/Utils.h

Lines changed: 8 additions & 0 deletions
@@ -32,6 +32,14 @@ FailureOr<RankedTensorType>
 computeTransposedType(RankedTensorType rankedTensorType,
                       ArrayRef<int64_t> transposeVector);
 
+/// Given a tensor::PackOp, compute the permutation vector to shuffle the
+/// packed shape into the shape before any outer or inner permutations have
+/// been applied.
+/// i.e. for a pack from an ABCD layout to an ABCDba:
+///      The packed shape would be ABCDba.
+///      The pre-permutation shape would be AaBbCD.
+SmallVector<int64_t> getPackInverseDestPermutation(PackOp packOp);
+
 /// A tensor.insert_slice is a cast-like operation if it merely rank-extends the
 /// source tensor or inserts the source tensor into a destination tensor with
 /// the same shape.
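
As a worked example (derived from the vectorization example later in this diff, so
the concrete values are illustrative): for a pack with inner_dims_pos = [2, 1] and
inner_tiles = [16, 2], the packed shape is 32x4x1x16x2 and the pre-permutation
(strip-mined) shape is 32x4x2x1x16, so getPackInverseDestPermutation would return
[0, 1, 4, 2, 3]; its inverse, [0, 1, 3, 4, 2], is the transpose applied during
vectorization.

  // Packed (dest) shape : 32x4x1x16x2
  // Strip-mined shape   : 32x4x2x1x16  (each inner tile next to its outer dim)
  // Permutation         : [0, 1, 4, 2, 3]
  %pack = tensor.pack %src padding_value(%pad : f32)
      inner_dims_pos = [2, 1] inner_tiles = [16, 2]
      into %dst : tensor<32x7x16xf32> -> tensor<32x4x1x16x2xf32>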

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 1 addition & 1 deletion
@@ -3152,7 +3152,7 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
 
   // TODO: Check that the correct number of vectorSizes was provided.
   for (Operation *target : targets) {
-    if (!isa<linalg::LinalgOp, tensor::PadOp>(target)) {
+    if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp>(target)) {
       return mlir::emitSilenceableFailure(target->getLoc())
              << "Unsupported Op, cannot vectorize";
     }
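
With this change, a transform script can apply structured vectorization directly to
tensor.pack targets. A minimal sketch; the exact transform.structured.match /
transform.structured.vectorize syntax below is an assumption based on the transform
dialect at this revision:

  %pack = transform.structured.match ops{["tensor.pack"]} in %module
      : (!transform.any_op) -> !transform.any_op
  transform.structured.vectorize %pack vector_sizes [32, 4, 1]
      : !transform.any_op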

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 7 additions & 29 deletions
@@ -233,31 +233,11 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
   rewriter.setInsertionPoint(packOp);
 
   // 2. Compute the permutation vector to shuffle packed shape into the shape
-  // before any outer or inner permutations have been applied. The permutation
-  // can be obtained from two permutations:
-  //   a) Compute the permutation vector to move the last `numPackedDims` into
-  //      the `innerPosDims` of a shape of rank `packedRank`.
-  //   b) Compute the permutation vector to move outer dims if the pack op
-  //      has outer_dims_perm.
-  // Apply (b) permutation on (a) permutation to get the final permutation.
-  int64_t numPackedDims = packOp.getInnerDimsPos().size();
-  int64_t packedRank = packedTensorType.getRank();
-  auto lastDims = llvm::to_vector(
-      llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
+  // before any outer or inner permutations have been applied.
   PackingMetadata packingMetadata = computePackingMetadata(
       packedTensorType.getRank(), packOp.getInnerDimsPos());
-  SmallVector<int64_t> innerPositionsPerm = computePermutationVector(
-      packedRank, lastDims, packingMetadata.insertPositions);
-
-  SmallVector<int64_t> outerPos = packingMetadata.outerPositions;
-  ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
-  if (!outerPerm.empty())
-    applyPermutationToVector(outerPos, outerPerm);
-  SmallVector<int64_t> outerPositionPerm = computePermutationVector(
-      packedRank, packingMetadata.outerPositions, outerPos);
-
-  SmallVector<int64_t> packedToStripMinedShapePerm = innerPositionsPerm;
-  applyPermutationToVector(packedToStripMinedShapePerm, outerPositionPerm);
+  SmallVector<int64_t> packedToStripMinedShapePerm =
+      tensor::getPackInverseDestPermutation(packOp);
 
   // 3. Compute the stripMinedShape: this is the packed shape before any outer
   // or inner permutations have been applied.

@@ -304,10 +284,6 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
       DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
                                       DBGS() << "packedShape: ");
       DBGSNL();
-      llvm::interleaveComma(outerPositionPerm, DBGS() << "outerPositionPerm: ");
-      DBGSNL(); llvm::interleaveComma(innerPositionsPerm,
-                                      DBGS() << "innerPositionsPerm: ");
-      DBGSNL();
       llvm::interleaveComma(packedToStripMinedShapePerm,
                             DBGS() << "packedToStripMinedShapePerm: ");
       DBGSNL(); llvm::interleaveComma(

@@ -332,9 +308,11 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
   auto emptyOp =
       rewriter.create<tensor::EmptyOp>(loc, packedTensorType, ValueRange{});
   // Offsets.
-  SmallVector<OpFoldResult> zeros(packedRank, rewriter.getIndexAttr(0));
+  SmallVector<OpFoldResult> zeros(packOp.getDestRank(),
+                                  rewriter.getIndexAttr(0));
   // Strides.
-  SmallVector<OpFoldResult> ones(packedRank, rewriter.getIndexAttr(1));
+  SmallVector<OpFoldResult> ones(packOp.getDestRank(),
+                                 rewriter.getIndexAttr(1));
   SmallVector<OpFoldResult> sizes =
       tensor::getMixedSizes(rewriter, loc, packOp.getDest());

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 203 additions & 33 deletions
@@ -19,18 +19,26 @@
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
 #include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/ADT/iterator_range.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
 #include <optional>
 #include <type_traits>
@@ -1393,6 +1401,164 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
   return success();
 }
 
+/// Given a tensor::PackOp, return the `dest` shape before any packing
+/// permutations.
+static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
+                                              ArrayRef<int64_t> destShape) {
+  return applyPermutation(destShape,
+                          tensor::getPackInverseDestPermutation(packOp));
+}
+
+/// Create a TransferReadOp from `source` with static shape `readShape`. If the
+/// vector type for the read is not the same as the type of `source`, then a
+/// mask is created on the read.
+static Value createReadOrMaskedRead(OpBuilder &builder, Location loc,
+                                    Value source, ArrayRef<int64_t> readShape,
+                                    Value padValue) {
+  assert(llvm::none_of(readShape,
+                       [](int64_t s) { return s == ShapedType::kDynamic; }));
+  auto sourceShape = dyn_cast<ShapedType>(source.getType()).getShape();
+  assert(sourceShape.size() == readShape.size());
+  auto maskType = VectorType::get(readShape, builder.getI1Type());
+  auto vectorType = VectorType::get(readShape, padValue.getType());
+  int64_t readRank = readShape.size();
+  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  auto transferReadOp = builder.create<vector::TransferReadOp>(
+      loc,
+      /*vectorType=*/vectorType,
+      /*source=*/source,
+      /*indices=*/SmallVector<Value>(readRank, zero),
+      /*padding=*/padValue,
+      /*inBounds=*/SmallVector<bool>(readRank, true));
+  if (llvm::equal(readShape, sourceShape)) {
+    return transferReadOp;
+  }
+  SmallVector<OpFoldResult> mixedSourceDims =
+      tensor::getMixedSizes(builder, loc, source);
+  Value mask =
+      builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
+  return mlir::vector::maskOperation(builder, transferReadOp, mask)
+      ->getResult(0);
+}
+
+/// Given an input, the mixed destSizes, and the vector sizes for vectorization,
+/// create an empty destination tensor and create a TransferWriteOp from the
+/// input to the empty tensor. If the destination shape is not the same as the
+/// inputVectorSizes for the first rank(inputVectorSizes) dims, then create a
+/// mask for the write.
+static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
+                                           Value input,
+                                           SmallVector<OpFoldResult> destSizes,
+                                           ArrayRef<int64_t> inputVectorSizes) {
+  auto inputType = cast<VectorType>(input.getType());
+  Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
+                                               inputType.getElementType());
+  int64_t rank = cast<ShapedType>(dest.getType()).getRank();
+  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  Operation *write = builder.create<vector::TransferWriteOp>(
+      loc,
+      /*vector=*/input,
+      /*source=*/dest,
+      /*indices=*/SmallVector<Value>(rank, zero),
+      /*inBounds=*/SmallVector<bool>(rank, true));
+  auto destShape = cast<ShapedType>(dest.getType()).getShape();
+  assert(llvm::none_of(
+             destShape.drop_front(inputVectorSizes.size()),
+             [](int64_t size) { return size == ShapedType::kDynamic; }) &&
+         "Only dims aligned with inputVectorSizes may be dynamic");
+  bool needMaskForWrite = !llvm::equal(
+      inputVectorSizes, destShape.take_front(inputVectorSizes.size()));
+  if (needMaskForWrite) {
+    SmallVector<int64_t> writeMaskShape;
+    writeMaskShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
+    writeMaskShape.append(destShape.begin() + inputVectorSizes.size(),
+                          destShape.end());
+    auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
+    Value maskForWrite =
+        builder.create<vector::CreateMaskOp>(loc, writeMaskType, destSizes);
+    write = mlir::vector::maskOperation(builder, write, maskForWrite);
+  }
+  return write;
+}
+
+/// Vectorize tensor::PackOp with (1) static innerTiles and (2) constant
+/// padding value into:
+/// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
+/// As in the following example:
+///
+/// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
+///     into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
+///
+/// This pack would be vectorized to:
+///
+/// %load = vector.mask %mask {
+///     vector.transfer_read %arg0[%c0, %c0, %c0], %cst
+///         {in_bounds = [true, true, true]} :
+///         tensor<32x7x16xf32>, vector<32x8x16xf32>
+/// } : vector<32x8x16xi1> -> vector<32x8x16xf32>
+/// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
+///                                         to vector<32x4x2x1x16xf32>
+/// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
+///     : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
+/// %write = vector.transfer_write %transpose,
+///     %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
+///     {in_bounds = [true, true, true, true, true]}
+///     : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
+static LogicalResult
+vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
+                        ArrayRef<int64_t> inputVectorSizes,
+                        SmallVectorImpl<Value> &newResults) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(packOp);
+
+  Location loc = packOp.getLoc();
+  auto padValue = packOp.getPaddingValue();
+  if (!padValue) {
+    padValue = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
+  }
+  ReifiedRankedShapedTypeDims reifiedReturnShapes;
+  LogicalResult status =
+      cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
+          .reifyResultShapes(rewriter, reifiedReturnShapes);
+  (void)status; // prevent unused variable warning on non-assert builds.
+  assert(succeeded(status) && "failed to reify result shapes");
+
+  // Create masked TransferReadOp.
+  SmallVector<int64_t> inputShape(inputVectorSizes);
+  auto innerTiles = packOp.getStaticInnerTiles();
+  auto innerDimsPos = packOp.getInnerDimsPos();
+  auto outerDimsPerm = packOp.getOuterDimsPerm();
+  if (!outerDimsPerm.empty())
+    applyPermutationToVector(inputShape,
+                             invertPermutationVector(outerDimsPerm));
+  for (auto [idx, size] : enumerate(innerTiles))
+    inputShape[innerDimsPos[idx]] *= size;
+  auto maskedRead = createReadOrMaskedRead(rewriter, loc, packOp.getSource(),
+                                           inputShape, padValue);
+
+  // Create ShapeCastOp.
+  SmallVector<int64_t> destShape(inputVectorSizes);
+  destShape.append(innerTiles.begin(), innerTiles.end());
+  auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
+                                       packOp.getDestType().getElementType());
+  auto shapeCastOp =
+      rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead);
+
+  // Create TransposeOp.
+  auto destPermutation =
+      invertPermutationVector(tensor::getPackInverseDestPermutation(packOp));
+  auto transposeOp = rewriter.create<vector::TransposeOp>(
+      loc, shapeCastOp.getResult(), destPermutation);
+
+  // Create TransferWriteOp.
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(),
+                               reifiedReturnShapes[0], inputVectorSizes);
+  newResults.push_back(write->getResult(0));
+  return success();
+}
+
 /// Vectorize a `padOp` with (1) static result type, (2) constant padding value
 /// and (3) all-zero lowPad to
 /// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
@@ -1402,9 +1568,6 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
                        SmallVectorImpl<Value> &newResults) {
   auto padValue = padOp.getConstantPaddingValue();
   Location loc = padOp.getLoc();
-  int64_t rank = inputVectorSizes.size();
-  auto maskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
-  auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
 
   // transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))
   OpBuilder::InsertionGuard g(rewriter);
@@ -1416,36 +1579,10 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
           .reifyResultShapes(rewriter, reifiedReturnShapes);
   (void)status; // prevent unused variable warning on non-assert builds
   assert(succeeded(status) && "failed to reify result shapes");
-  auto emptyOp = rewriter.create<tensor::EmptyOp>(loc, reifiedReturnShapes[0],
-                                                  padValue.getType());
-  SmallVector<OpFoldResult> mixedSourceDims =
-      tensor::getMixedSizes(rewriter, loc, padOp.getSource());
-  Value mask =
-      rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
-  auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
-      loc,
-      /*vectorType=*/vectorType,
-      /*source=*/padOp.getSource(),
-      /*indices=*/SmallVector<Value>(rank, zero),
-      /*padding=*/padValue,
-      /*inBounds=*/SmallVector<bool>(rank, true));
-  auto maskedOp = cast<vector::MaskOp>(
-      mlir::vector::maskOperation(rewriter, transferReadOp, mask));
-  Operation *write = rewriter.create<vector::TransferWriteOp>(
-      loc,
-      /*vector=*/maskedOp->getResult(0),
-      /*source=*/emptyOp,
-      /*indices=*/SmallVector<Value>(rank, zero),
-      /*inBounds=*/SmallVector<bool>(rank, true));
-  bool needMaskForWrite = llvm::any_of(
-      llvm::zip_equal(inputVectorSizes, padOp.getResultType().getShape()),
-      [](auto it) { return std::get<0>(it) != std::get<1>(it); });
-  if (needMaskForWrite) {
-    Value maskForWrite = rewriter.create<vector::CreateMaskOp>(
-        loc, maskType, reifiedReturnShapes[0]);
-    write = mlir::vector::maskOperation(rewriter, write, maskForWrite);
-  }
+  auto maskedRead = createReadOrMaskedRead(rewriter, loc, padOp.getSource(),
+                                           inputVectorSizes, padValue);
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes);
   newResults.push_back(write->getResult(0));
   return success();
 }
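
After this refactoring the pad path emits the same structure as before, now via
createReadOrMaskedRead and createWriteOrMaskedWrite: a masked transfer_read from
the (possibly dynamic) source and, when the result shape differs from the vector
sizes, a masked transfer_write into the empty destination. A rough sketch with
assumed shapes and placeholder names:

  %read_mask = vector.create_mask %d0, %d1 : vector<2x4xi1>
  %read = vector.mask %read_mask {
    vector.transfer_read %src[%c0, %c0], %pad_value {in_bounds = [true, true]}
        : tensor<?x?xf32>, vector<2x4xf32>
  } : vector<2x4xi1> -> vector<2x4xf32>
  %empty = tensor.empty() : tensor<2x3xf32>
  %write_mask = vector.create_mask %c2, %c3 : vector<2x4xi1>
  %result = vector.mask %write_mask {
    vector.transfer_write %read, %empty[%c0, %c0] {in_bounds = [true, true]}
        : vector<2x4xf32>, tensor<2x3xf32>
  } : vector<2x4xi1> -> tensor<2x3xf32>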
@@ -1585,6 +1722,32 @@ vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
   return success();
 }
 
+/// TODO: Use a matcher to check for a constant padding value.
+static LogicalResult
+vectorizePackOpPrecondition(tensor::PackOp packOp,
+                            ArrayRef<int64_t> inputVectorSizes) {
+  auto padValue = packOp.getPaddingValue();
+  if (padValue && !padValue.getDefiningOp<arith::ConstantOp>()) {
+    LDBG("pad value is not constant: " << packOp << "\n");
+    return failure();
+  }
+
+  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
+  if (failed(isValidMaskedInputVector(
+          resultTensorShape.take_front(packOp.getSourceRank()),
+          inputVectorSizes)))
+    return failure();
+
+  if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
+        return !getConstantIntValue(v).has_value();
+      })) {
+    LDBG("inner_tiles must be constant: " << packOp << "\n");
+    return failure();
+  }
+
+  return success();
+}
+
 static LogicalResult
 vectorizePadOpPrecondition(tensor::PadOp padOp,
                            ArrayRef<int64_t> inputVectorSizes) {
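
For instance, the precondition would reject a pack whose inner tile sizes are not
compile-time constants (hypothetical IR, names assumed):

  // Rejected: %tile is not constant, so the tiled read shape cannot be
  // computed statically.
  %pack = tensor.pack %src padding_value(%pad : f32)
      inner_dims_pos = [1] inner_tiles = [%tile]
      into %dst : tensor<32x?xf32> -> tensor<32x?x?xf32>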
@@ -1644,6 +1807,9 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
       .Case<tensor::PadOp>([&](auto padOp) {
         return vectorizePadOpPrecondition(padOp, inputVectorSizes);
       })
+      .Case<tensor::PackOp>([&](auto packOp) {
+        return vectorizePackOpPrecondition(packOp, inputVectorSizes);
+      })
       .Default([](auto) { return failure(); });
 }

@@ -1732,6 +1898,10 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
             return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
                                           results);
           })
+          .Case<tensor::PackOp>([&](auto packOp) {
+            return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
+                                           results);
+          })
           .Default([](auto) { return failure(); });
 
   if (failed(vectorizeResult)) {
