Commit 9acc090 ("address comments")
1 parent c776290

File tree: 5 files changed (+164, -164 lines)

mlir/include/mlir/Dialect/Tensor/Utils/Utils.h

Lines changed: 8 additions & 0 deletions
@@ -32,6 +32,14 @@ FailureOr<RankedTensorType>
 computeTransposedType(RankedTensorType rankedTensorType,
                       ArrayRef<int64_t> transposeVector);
 
+/// Given a tensor::PackOp, compute the permutation vector to shuffle the
+/// packed shape into the shape before any outer or inner permutations have
+/// been applied.
+/// e.g. for a pack from an ABCD layout to an ABCDba layout:
+/// The packed shape would be ABCDba.
+/// The pre-permutation shape would be AaBbCD.
+SmallVector<int64_t> getPackInverseDestPermutation(PackOp packOp);
+
 /// A tensor.insert_slice is a cast-like operation if it merely rank-extends the
 /// source tensor or inserts the source tensor into a destination tensor with
 /// the same shape.
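
Note: a hand-worked reading of the new helper's contract for the ABCD -> ABCDba example above (illustrative only, not part of the patch; `packOp` stands for a pack with inner_dims_pos = [1, 0] on a rank-4 source):

// Packed dims:          A=0  B=1  C=2  D=3  b=4  a=5
// Pre-permutation dims: A    a    B    b    C    D
SmallVector<int64_t> perm = tensor::getPackInverseDestPermutation(packOp);
// perm == {0, 5, 1, 4, 2, 3}: each entry is the packed dim that lands in
// that slot, so applyPermutation(packedShape, perm) yields the AaBbCD order.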

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 7 additions & 29 deletions
@@ -233,31 +233,11 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
   rewriter.setInsertionPoint(packOp);
 
   // 2. Compute the permutation vector to shuffle packed shape into the shape
-  // before any outer or inner permutations have been applied. The permutation
-  // can be obtained from two permutations:
-  //   a) Compute the permutation vector to move the last `numPackedDims` into
-  //      the `innerPosDims` of a shape of rank `packedRank`.
-  //   b) Compute the permutation vector to move outer dims if the pack op
-  //      has outer_dims_perm.
-  // Apply (b) permutation on (a) permutation to get the final permutation.
-  int64_t numPackedDims = packOp.getInnerDimsPos().size();
-  int64_t packedRank = packedTensorType.getRank();
-  auto lastDims = llvm::to_vector(
-      llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
+  // before any outer or inner permutations have been applied.
   PackingMetadata packingMetadata = computePackingMetadata(
       packedTensorType.getRank(), packOp.getInnerDimsPos());
-  SmallVector<int64_t> innerPositionsPerm = computePermutationVector(
-      packedRank, lastDims, packingMetadata.insertPositions);
-
-  SmallVector<int64_t> outerPos = packingMetadata.outerPositions;
-  ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
-  if (!outerPerm.empty())
-    applyPermutationToVector(outerPos, outerPerm);
-  SmallVector<int64_t> outerPositionPerm = computePermutationVector(
-      packedRank, packingMetadata.outerPositions, outerPos);
-
-  SmallVector<int64_t> packedToStripMinedShapePerm = innerPositionsPerm;
-  applyPermutationToVector(packedToStripMinedShapePerm, outerPositionPerm);
+  SmallVector<int64_t> packedToStripMinedShapePerm =
+      tensor::getPackInverseDestPermutation(packOp);
 
   // 3. Compute the stripMinedShape: this is the packed shape before any outer
   // or inner permutations have been applied.
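
For intuition, a hand-computed sketch (mine, not output of the patch) using the pack from the vectorization example later in this commit:

// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
//     into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
// The strip-mined shape keeps each tile next to its source dim: 32x4x2x1x16.
SmallVector<int64_t> packedToStripMinedShapePerm =
    tensor::getPackInverseDestPermutation(packOp); // {0, 1, 4, 2, 3}
// applyPermutation({32, 4, 1, 16, 2}, {0, 1, 4, 2, 3}) == {32, 4, 2, 1, 16}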
@@ -304,10 +284,6 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
              DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
                                              DBGS() << "packedShape: ");
              DBGSNL();
-             llvm::interleaveComma(outerPositionPerm, DBGS() << "outerPositionPerm: ");
-             DBGSNL(); llvm::interleaveComma(innerPositionsPerm,
-                                             DBGS() << "innerPositionsPerm: ");
-             DBGSNL();
              llvm::interleaveComma(packedToStripMinedShapePerm,
                                    DBGS() << "packedToStripMinedShapePerm: ");
              DBGSNL(); llvm::interleaveComma(
@@ -332,9 +308,11 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
   auto emptyOp =
       rewriter.create<tensor::EmptyOp>(loc, packedTensorType, ValueRange{});
   // Offsets.
-  SmallVector<OpFoldResult> zeros(packedRank, rewriter.getIndexAttr(0));
+  SmallVector<OpFoldResult> zeros(packOp.getDestRank(),
+                                  rewriter.getIndexAttr(0));
   // Strides.
-  SmallVector<OpFoldResult> ones(packedRank, rewriter.getIndexAttr(1));
+  SmallVector<OpFoldResult> ones(packOp.getDestRank(),
+                                 rewriter.getIndexAttr(1));
   SmallVector<OpFoldResult> sizes =
       tensor::getMixedSizes(rewriter, loc, packOp.getDest());
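
The switch from the deleted `packedRank` local to `packOp.getDestRank()` reads as behavior-preserving; a one-line sanity sketch of the assumed invariant (mine, not in the patch):

// The dest of a tensor.pack is the packed tensor itself, so the two ranks
// coincide by construction.
assert(packOp.getDestRank() == packedTensorType.getRank());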

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 74 additions & 86 deletions
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -1400,83 +1401,47 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
   return success();
 }
 
-/// Given a tensor::PackOp, return the permutation from the "tiled"
-/// shape to the "packed" shape, defined as the following:
-/// The "packed" shape is the same as the `dest` shape of the pack op.
-/// The "tiled" shape is a permutation of the `dest` shape such that
-/// each outer dimension is in the original `source` order, and the
-/// inner_tile dimensions immediately follow their corresponding outer
-/// dimension.
-/// i.e. for the following tensor.pack:
-/// ```mlir
-/// %pack = tensor.pack %0 padding_value(%1)
-///   outer_dims_perm = [0, 2, 1]
-///   inner_dims_pos = [2, 1]
-///   inner_tiles = [16, 2]
-///   into %2 : tensor<32x8x16> -> tensor<32x1x4x16x2>
-/// ```
-/// The "packed" shape is `32x1x4x16x2`
-/// The "tiled" shape is `32x(4x2)x(1x16)`
-static SmallVector<int64_t>
-getTiledShapeToPackedShapePerm(tensor::PackOp packOp) {
-  auto innerTiles = packOp.getInnerTiles();
-  int64_t srcRank = packOp.getSourceRank();
-  auto innerDimsPos = packOp.getInnerDimsPos();
-  if (innerDimsPos.empty())
-    innerDimsPos = to_vector(llvm::seq<int64_t>(innerTiles.size()));
-  auto outerDimsPerm = packOp.getOuterDimsPerm();
-  if (outerDimsPerm.empty())
-    outerDimsPerm = to_vector(llvm::seq<int64_t>(srcRank));
-  auto packedIdxToTiledIdx = [&](int64_t idx) -> int64_t {
-    int64_t srcIdx;
-    if (idx >= srcRank)
-      srcIdx = innerDimsPos[idx - srcRank];
-    else
-      srcIdx = outerDimsPerm[idx];
-    int64_t tiledIdx = srcIdx;
-    for (int64_t pos : innerDimsPos)
-      if (pos < srcIdx)
-        tiledIdx++;
-    if (idx >= srcRank)
-      tiledIdx++;
-    return tiledIdx;
-  };
-  SmallVector<int64_t> perm;
-  for (size_t i = 0; i < packOp.getDestRank(); i++)
-    perm.push_back(packedIdxToTiledIdx(i));
-  return perm;
-}
-
-/// Given a tensor::PackOp, return the "tiled" `dest` shape as described
-/// above in `getTiledShapeToPackedShapePerm`.
+/// Given a tensor::PackOp, return the `dest` shape before any packing
+/// permutations.
 static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
                                               ArrayRef<int64_t> destShape) {
-  auto perm = getTiledShapeToPackedShapePerm(packOp);
-  return applyPermutation(destShape, invertPermutationVector(perm));
+  return applyPermutation(destShape,
+                          tensor::getPackInverseDestPermutation(packOp));
 }
 
-/// Create a masked TransferReadOp from `source` with shape `readShape`.
-static vector::MaskOp createMaskedTransferRead(OpBuilder &builder, Location loc,
-                                               Value source,
-                                               ArrayRef<int64_t> readShape,
-                                               Value padValue) {
+/// Create a TransferReadOp from `source` with static shape `readShape`. If the
+/// vector type for the read is not the same as the type of `source`, then a
+/// mask is created on the read.
+static Value createReadOrMaskedRead(OpBuilder &builder, Location loc,
+                                    Value source, ArrayRef<int64_t> readShape,
+                                    Value padValue) {
+  assert(llvm::none_of(readShape,
+                       [](int64_t s) { return s == ShapedType::kDynamic; }));
   auto maskType = VectorType::get(readShape, builder.getI1Type());
   auto vectorType = VectorType::get(readShape, padValue.getType());
-  SmallVector<OpFoldResult> mixedSourceDims =
-      tensor::getMixedSizes(builder, loc, source);
-  Value mask =
-      builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
-  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
   int64_t readRank = readShape.size();
+  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
   auto transferReadOp = builder.create<vector::TransferReadOp>(
       loc,
       /*vectorType=*/vectorType,
       /*source=*/source,
       /*indices=*/SmallVector<Value>(readRank, zero),
       /*padding=*/padValue,
       /*inBounds=*/SmallVector<bool>(readRank, true));
-  return cast<vector::MaskOp>(
-      mlir::vector::maskOperation(builder, transferReadOp, mask));
+  auto sourceShape = llvm::dyn_cast<ShapedType>(source.getType()).getShape();
+  if (sourceShape.size() == readShape.size() &&
+      llvm::all_of(llvm::zip_equal(readShape, sourceShape), [](auto it) {
+        return std::get<0>(it) != ShapedType::kDynamic &&
+               std::get<0>(it) == std::get<1>(it);
+      })) {
+    return transferReadOp;
+  }
+  SmallVector<OpFoldResult> mixedSourceDims =
+      tensor::getMixedSizes(builder, loc, source);
+  Value mask =
+      builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
+  return mlir::vector::maskOperation(builder, transferReadOp, mask)
+      ->getResult(0);
 }
 
 /// Given an input, the mixed destSizes, and the vector sizes for vectorization,
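
A usage sketch of the new unmasked fast path (hypothetical `builder`, `loc`, `padValue`, and source values; the helper is file-local, so this is illustrative rather than standalone-runnable):

// staticSrc : tensor<4x8xf32>; readShape matches and is fully static, so
// the helper now returns the bare vector.transfer_read.
Value unmasked = createReadOrMaskedRead(builder, loc, staticSrc,
                                        /*readShape=*/{4, 8}, padValue);
// paddedSrc : tensor<4x7xf32> read as {4, 8} (e.g. a pack padding 7 -> 8);
// the shapes differ, so the read is wrapped in vector.mask as before.
Value masked = createReadOrMaskedRead(builder, loc, paddedSrc,
                                      /*readShape=*/{4, 8}, padValue);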
@@ -1500,9 +1465,14 @@ static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
       /*indices=*/SmallVector<Value>(rank, zero),
       /*inBounds=*/SmallVector<bool>(rank, true));
   auto destShape = cast<ShapedType>(dest.getType()).getShape();
-  bool needMaskForWrite =
-      llvm::any_of(llvm::zip(inputVectorSizes, destShape),
-                   [](auto it) { return std::get<0>(it) != std::get<1>(it); });
+  assert(llvm::none_of(
+             destShape.drop_front(inputVectorSizes.size()),
+             [](int64_t size) { return size == ShapedType::kDynamic; }) &&
+         "Only dims aligned with inputVectorSizes may be dynamic");
+  bool needMaskForWrite = llvm::any_of(
+      llvm::zip_equal(inputVectorSizes,
+                      destShape.take_front(inputVectorSizes.size())),
+      [](auto it) { return std::get<0>(it) != std::get<1>(it); });
   if (needMaskForWrite) {
     SmallVector<int64_t> writeMaskShape;
     writeMaskShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
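
To see what the tightened contract means in numbers, a standalone mirror of the new decision (plain C++, my own re-creation; `kDynamic` stands in for `ShapedType::kDynamic`):

#include <cstdint>
#include <cstdio>
#include <vector>

constexpr int64_t kDynamic = INT64_MIN; // stand-in; illustrative value only

// Only the leading destShape dims are compared against inputVectorSizes;
// trailing (inner-tile) dims must be static, which the new assert enforces.
bool needMaskForWrite(const std::vector<int64_t> &inputVectorSizes,
                      const std::vector<int64_t> &destShape) {
  for (size_t i = 0; i < inputVectorSizes.size(); ++i)
    if (inputVectorSizes[i] != destShape[i])
      return true;
  return false;
}

int main() {
  // dest tensor<?x4x16x2>, vector sizes {8, 4}: the dynamic leading dim
  // differs from 8, so the write is masked with shape {8, 4, 16, 2}.
  std::printf("%d\n", needMaskForWrite({8, 4}, {kDynamic, 4, 16, 2})); // 1
  // dest tensor<8x4x16x2>: leading dims match, so no mask is needed.
  std::printf("%d\n", needMaskForWrite({8, 4}, {8, 4, 16, 2})); // 0
}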
@@ -1517,11 +1487,28 @@ static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
 }
 
 /// Vectorize tensor::PackOp with (1) static innerTiles and (2) constant
-/// padding value into
-/// transfer_write_in_bounds(
-///     transpose(
-///         shape_cast(
-///             transfer_read_masked(pack_source, pad_value))))
+/// padding value into:
+/// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
+/// As in the following example:
+/// ```mlir
+/// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
+///     into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
+/// ```
+/// This pack would be vectorized to:
+/// ```mlir
+/// %load = vector.mask %mask {
+///     vector.transfer_read %arg0[%c0, %c0, %c0], %cst
+///         {in_bounds = [true, true, true]} :
+///         tensor<32x7x16xf32>, vector<32x8x16xf32>
+/// } : vector<32x8x16xi1> -> vector<32x8x16xf32>
+/// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
+///     to vector<32x4x2x1x16xf32>
+/// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
+///     : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
+/// %write = vector.transfer_write %transpose,
+///     %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
+///     {in_bounds = [true, true, true, true, true]}
+///     : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
 static LogicalResult
 vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
                         ArrayRef<int64_t> inputVectorSizes,
@@ -1539,10 +1526,10 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
   LogicalResult status =
       cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
           .reifyResultShapes(rewriter, reifiedReturnShapes);
-  (void)status; // prevent unused variable warning on non-assert builds
+  (void)status; // prevent unused variable warning on non-assert builds.
   assert(succeeded(status) && "failed to reify result shapes");
 
-  // Create masked TransferReadOp
+  // Create masked TransferReadOp.
   SmallVector<int64_t> inputShape(inputVectorSizes);
   auto innerTiles = packOp.getStaticInnerTiles();
   auto innerDimsPos = packOp.getInnerDimsPos();
@@ -1552,23 +1539,24 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
                            invertPermutationVector(outerDimsPerm));
   for (auto [idx, size] : enumerate(innerTiles))
     inputShape[innerDimsPos[idx]] *= size;
-  auto maskedOp = createMaskedTransferRead(rewriter, loc, packOp.getSource(),
-                                           inputShape, padValue);
+  auto maskedRead = createReadOrMaskedRead(rewriter, loc, packOp.getSource(),
                                            inputShape, padValue);
 
-  // Create ShapeCastOp
+  // Create ShapeCastOp.
   SmallVector<int64_t> destShape(inputVectorSizes);
   destShape.append(innerTiles.begin(), innerTiles.end());
   auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
                                        packOp.getDestType().getElementType());
-  auto shapeCastOp = rewriter.create<vector::ShapeCastOp>(
-      loc, tiledPackType, maskedOp->getResult(0));
+  auto shapeCastOp =
+      rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead);
 
-  // Create TransposeOp
-  auto tiledShapeToPackedShapePerm = getTiledShapeToPackedShapePerm(packOp);
+  // Create TransposeOp.
+  auto destPermutation =
+      invertPermutationVector(tensor::getPackInverseDestPermutation(packOp));
   auto transposeOp = rewriter.create<vector::TransposeOp>(
-      loc, shapeCastOp.getResult(), tiledShapeToPackedShapePerm);
+      loc, shapeCastOp.getResult(), destPermutation);
 
-  // Create TransferWriteOp
+  // Create TransferWriteOp.
   Operation *write =
       createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(),
                                reifiedReturnShapes[0], inputVectorSizes);
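
Tying this back to the doc-comment example above (hand-derived, assuming the helper behaves as documented): for inner_dims_pos = [2, 1] the inverse-dest permutation is {0, 1, 4, 2, 3}, and the transpose uses its inverse:

auto destPermutation =
    invertPermutationVector(tensor::getPackInverseDestPermutation(packOp));
// invert({0, 1, 4, 2, 3}) == {0, 1, 3, 4, 2}: exactly the [0, 1, 3, 4, 2]
// map on the vector.transpose in the example,
// vector<32x4x2x1x16xf32> -> vector<32x4x1x16x2xf32>.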
@@ -1596,11 +1584,10 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
           .reifyResultShapes(rewriter, reifiedReturnShapes);
   (void)status; // prevent unused variable warning on non-assert builds
   assert(succeeded(status) && "failed to reify result shapes");
-  auto maskedOp = createMaskedTransferRead(rewriter, loc, padOp.getSource(),
+  auto maskedRead = createReadOrMaskedRead(rewriter, loc, padOp.getSource(),
                                            inputVectorSizes, padValue);
-  Operation *write =
-      createWriteOrMaskedWrite(rewriter, loc, maskedOp->getResult(0),
-                               reifiedReturnShapes[0], inputVectorSizes);
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1740,11 +1727,12 @@ vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
   return success();
 }
 
+/// TODO: Use a matcher to check for a constant padding value.
 static LogicalResult
 vectorizePackOpPrecondition(tensor::PackOp packOp,
                             ArrayRef<int64_t> inputVectorSizes) {
   auto padValue = packOp.getPaddingValue();
-  if (padValue && !getConstantIntValue(padValue).has_value()) {
+  if (padValue && !padValue.getDefiningOp<arith::ConstantOp>()) {
     LDBG("pad value is not constant: " << packOp << "\n");
     return failure();
   }
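
The practical effect of the relaxed check, as I read it: `getConstantIntValue` only recognizes integer constants, so float padding was rejected before, while matching on the defining `arith::ConstantOp` accepts any constant. A sketch with hypothetical values:

// %pad_i = arith.constant 0 : i32    -> accepted by old and new check
// %pad_f = arith.constant 0.0 : f32  -> rejected by old check (not an
//                                       integer), accepted by new check
// %pad_d = <runtime-computed value>  -> rejected by both
bool isConstantPad = !padValue || padValue.getDefiningOp<arith::ConstantOp>();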

mlir/lib/Dialect/Tensor/Utils/Utils.cpp

Lines changed: 29 additions & 0 deletions
@@ -73,6 +73,35 @@ mlir::tensor::computeTransposedType(RankedTensorType rankedTensorType,
   return transposedTensorType;
 }
 
+SmallVector<int64_t>
+mlir::tensor::getPackInverseDestPermutation(PackOp packOp) {
+  // The permutation can be obtained from two permutations:
+  //   a) Compute the permutation vector to move the last `numPackedDims` into
+  //      the `innerPosDims` of a shape of rank `packedRank`.
+  //   b) Compute the permutation vector to move outer dims if the pack op
+  //      has outer_dims_perm.
+  // Apply (b) permutation on (a) permutation to get the final permutation.
+  int64_t numPackedDims = packOp.getInnerDimsPos().size();
+  int64_t packedRank = packOp.getDestType().getRank();
+  auto lastDims = llvm::to_vector(
+      llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
+  PackingMetadata packingMetadata = computePackingMetadata(
+      packOp.getDestType().getRank(), packOp.getInnerDimsPos());
+  SmallVector<int64_t> innerPositionsPerm = computePermutationVector(
+      packedRank, lastDims, packingMetadata.insertPositions);
+
+  SmallVector<int64_t> outerPos = packingMetadata.outerPositions;
+  ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
+  if (!outerPerm.empty())
+    applyPermutationToVector(outerPos, outerPerm);
+  SmallVector<int64_t> outerPositionPerm = computePermutationVector(
+      packedRank, packingMetadata.outerPositions, outerPos);
+
+  SmallVector<int64_t> packInverseDestPermutation = innerPositionsPerm;
+  applyPermutationToVector(packInverseDestPermutation, outerPositionPerm);
+  return packInverseDestPermutation;
+}
+
 bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
   llvm::SmallBitVector droppedDims = op.getDroppedDims();
   int64_t srcDim = 0;
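
A hand trace of the two-step composition for the running example (inner_dims_pos = [2, 1], inner_tiles = [16, 2], no outer_dims_perm, dest tensor<32x4x1x16x2xf32>); the intermediate values are my own working under the documented semantics of computePackingMetadata and computePermutationVector:

// numPackedDims = 2, packedRank = 5, lastDims = {3, 4}
// packingMetadata.insertPositions = {4, 2}  (tile of dim 2 lands at slot 4,
//                                            tile of dim 1 at slot 2)
// innerPositionsPerm = {0, 1, 4, 2, 3}; with no outer_dims_perm the outer
// step is the identity, so the final result is {0, 1, 4, 2, 3}.

A standalone check of that traced result (plain C++; result[i] = shape[perm[i]] mirrors mlir::applyPermutation):

#include <array>
#include <cassert>

int main() {
  std::array<int, 5> packedShape{32, 4, 1, 16, 2};
  std::array<int, 5> perm{0, 1, 4, 2, 3}; // traced result from above
  std::array<int, 5> stripMined{};
  for (int i = 0; i < 5; ++i)
    stripMined[i] = packedShape[perm[i]];
  assert((stripMined == std::array<int, 5>{32, 4, 2, 1, 16}));
}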
