Skip to content

Commit 30d4f6a

Browse files
authored
Make createReadOrMaskedRead and isValidMaskedInputVector vector utilities (llvm#89119)
Made the createReadOrMaskedRead and isValidMaskedInputVector utility functions accessible outside of the compilation unit (CU). Needed by the new IREE TopK implementation.
1 parent 28cea99 commit 30d4f6a

File tree

3 files changed

+113
-87
lines changed

3 files changed

+113
-87
lines changed

mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,30 @@ struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
180180
/// are not linearizable.
181181
bool isLinearizableVector(VectorType type);
182182

183+
/// Create a TransferReadOp from `source` with static shape `readShape`. If the
184+
/// vector type for the read is not the same as the type of `source`, then a
185+
/// mask is created on the read, if use of mask is specified or the bounds on a
186+
/// dimension are different.
187+
///
188+
/// `useInBoundsInsteadOfMasking` if false, the inBoundsVal values are set
189+
/// properly, based on
190+
/// the rank dimensions of the source and destination tensors. And that is
191+
/// what determines if masking is done.
192+
///
193+
/// Note that the internal `vector::TransferReadOp` always reads at indices zero
194+
/// for each dimension of the passed in tensor.
195+
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
196+
ArrayRef<int64_t> readShape, Value padValue,
197+
bool useInBoundsInsteadOfMasking = true);
198+
199+
/// Returns success if `inputVectorSizes` is a valid masking configuration for
200+
/// given `shape`, i.e., it meets:
201+
/// 1. The numbers of elements in both arrays are equal.
202+
/// 2. `inputVectorSizes` does not have dynamic dimensions.
203+
/// 3. All the values in `inputVectorSizes` are greater than or equal to
204+
/// static sizes in `shape`.
205+
LogicalResult isValidMaskedInputVector(ArrayRef<int64_t> shape,
206+
ArrayRef<int64_t> inputVectorSizes);
183207
} // namespace vector
184208

185209
/// Constructs a permutation map of invariant memref indices to vector

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 14 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -1410,46 +1410,6 @@ static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
14101410
return applyPermutation(destShape, tensor::getPackInverseDestPerm(packOp));
14111411
}
14121412

1413-
/// Create a TransferReadOp from `source` with static shape `readShape`. If the
1414-
/// vector type for the read is not the same as the type of `source`, then a
1415-
/// mask is created on the read. If `doMasking` parameter is set to false we
1416-
/// update the `inBounds` attribute instead of masking.
1417-
static Value createReadOrMaskedRead(OpBuilder &builder, Location loc,
1418-
Value source, ArrayRef<int64_t> readShape,
1419-
Value padValue, bool doMasking = true) {
1420-
assert(llvm::none_of(readShape,
1421-
[](int64_t s) { return s == ShapedType::kDynamic; }));
1422-
auto sourceShape = dyn_cast<ShapedType>(source.getType()).getShape();
1423-
assert(sourceShape.size() == readShape.size());
1424-
auto maskType = VectorType::get(readShape, builder.getI1Type());
1425-
auto vectorType = VectorType::get(readShape, padValue.getType());
1426-
int64_t readRank = readShape.size();
1427-
auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
1428-
SmallVector<bool> inBoundsVal(readRank, true);
1429-
if (!doMasking) {
1430-
// Update the inBounds attribute.
1431-
for (unsigned i = 0; i < readRank; i++)
1432-
inBoundsVal[i] = sourceShape[i] == readShape[i];
1433-
}
1434-
auto transferReadOp = builder.create<vector::TransferReadOp>(
1435-
loc,
1436-
/*vectorType=*/vectorType,
1437-
/*source=*/source,
1438-
/*indices=*/SmallVector<Value>(readRank, zero),
1439-
/*padding=*/padValue,
1440-
/*inBounds=*/inBoundsVal);
1441-
1442-
if (llvm::equal(readShape, sourceShape) || !doMasking) {
1443-
return transferReadOp;
1444-
}
1445-
SmallVector<OpFoldResult> mixedSourceDims =
1446-
tensor::getMixedSizes(builder, loc, source);
1447-
Value mask =
1448-
builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
1449-
return mlir::vector::maskOperation(builder, transferReadOp, mask)
1450-
->getResult(0);
1451-
}
1452-
14531413
/// Given an input, the mixed destSizes, and the vector sizes for vectorization,
14541414
/// create an empty destination tensor and create a TransferWriteOp from the
14551415
/// input to the empty tensor. If the destination shape is not the same as the
@@ -1539,11 +1499,11 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
15391499
// If the input vector sizes are not provided, then the vector sizes are
15401500
// determined by the result tensor shape. In case the vector sizes aren't
15411501
// provided, we update the inBounds attribute instead of masking.
1542-
bool doMasking = true;
1502+
bool useInBoundsInsteadOfMasking = true;
15431503
if (inputVectorSizes.empty()) {
15441504
ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
15451505
inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
1546-
doMasking = false;
1506+
useInBoundsInsteadOfMasking = false;
15471507
}
15481508

15491509
// Create masked TransferReadOp.
@@ -1556,8 +1516,9 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
15561516
invertPermutationVector(outerDimsPerm));
15571517
for (auto [idx, size] : enumerate(innerTiles))
15581518
inputShape[innerDimsPos[idx]] *= size;
1559-
auto maskedRead = createReadOrMaskedRead(rewriter, loc, packOp.getSource(),
1560-
inputShape, padValue, doMasking);
1519+
auto maskedRead = vector::createReadOrMaskedRead(
1520+
rewriter, loc, packOp.getSource(), inputShape, padValue,
1521+
useInBoundsInsteadOfMasking);
15611522

15621523
// Create ShapeCastOp.
15631524
SmallVector<int64_t> destShape(inputVectorSizes);
@@ -1649,7 +1610,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
16491610

16501611
// Read result, mask if necessary. If transferReadOp shape is not equal
16511612
// to shape of source, then a mask is necessary.
1652-
Value readResult = createReadOrMaskedRead(
1613+
Value readResult = vector::createReadOrMaskedRead(
16531614
rewriter, loc, unpackOp.getSource(),
16541615
ArrayRef<int64_t>(readMaskShape.begin(), readMaskShape.end()), padValue);
16551616

@@ -1707,8 +1668,8 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
17071668
.reifyResultShapes(rewriter, reifiedReturnShapes);
17081669
(void)status; // prevent unused variable warning on non-assert builds
17091670
assert(succeeded(status) && "failed to reify result shapes");
1710-
auto maskedRead = createReadOrMaskedRead(rewriter, loc, padOp.getSource(),
1711-
inputVectorSizes, padValue);
1671+
auto maskedRead = vector::createReadOrMaskedRead(
1672+
rewriter, loc, padOp.getSource(), inputVectorSizes, padValue);
17121673
Operation *write = createWriteOrMaskedWrite(
17131674
rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes);
17141675
newResults.push_back(write->getResult(0));
@@ -1781,41 +1742,6 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
17811742
return success();
17821743
}
17831744

1784-
/// Returns success if `inputVectorSizes` is a valid masking configuraion for
1785-
/// given `shape`, i.e., it meets:
1786-
/// 1. The numbers of elements in both array are equal.
1787-
/// 2. `inputVectorSizes` does not have dynamic dimensions.
1788-
/// 3. All the values in `inputVectorSizes` are greater than or equal to
1789-
/// static sizes in `shape`.
1790-
static LogicalResult
1791-
isValidMaskedInputVector(ArrayRef<int64_t> shape,
1792-
ArrayRef<int64_t> inputVectorSizes) {
1793-
LDBG("Iteration space static sizes:");
1794-
LLVM_DEBUG(llvm::interleaveComma(shape, llvm::dbgs()));
1795-
LLVM_DEBUG(llvm::dbgs() << "\n");
1796-
1797-
if (inputVectorSizes.size() != shape.size()) {
1798-
LDBG("Input vector sizes don't match the number of loops");
1799-
return failure();
1800-
}
1801-
if (ShapedType::isDynamicShape(inputVectorSizes)) {
1802-
LDBG("Input vector sizes can't have dynamic dimensions");
1803-
return failure();
1804-
}
1805-
if (!llvm::all_of(llvm::zip(shape, inputVectorSizes),
1806-
[](std::tuple<int64_t, int64_t> sizePair) {
1807-
int64_t staticSize = std::get<0>(sizePair);
1808-
int64_t inputSize = std::get<1>(sizePair);
1809-
return ShapedType::isDynamic(staticSize) ||
1810-
staticSize <= inputSize;
1811-
})) {
1812-
LDBG("Input vector sizes must be greater than or equal to iteration space "
1813-
"static sizes");
1814-
return failure();
1815-
}
1816-
return success();
1817-
}
1818-
18191745
/// Need to check if the inner-tiles are static/constant.
18201746
static LogicalResult
18211747
vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
@@ -1829,7 +1755,7 @@ vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
18291755
}
18301756
llvm::ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
18311757
if (!inputVectorSizes.empty() &&
1832-
failed(isValidMaskedInputVector(resultShape, inputVectorSizes)))
1758+
failed(vector::isValidMaskedInputVector(resultShape, inputVectorSizes)))
18331759
return failure();
18341760

18351761
return success();
@@ -1843,8 +1769,8 @@ static LogicalResult vectorizeLinalgOpPrecondition(
18431769
return failure();
18441770
// Check API contract for input vector sizes.
18451771
if (!inputVectorSizes.empty() &&
1846-
failed(isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
1847-
inputVectorSizes)))
1772+
failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
1773+
inputVectorSizes)))
18481774
return failure();
18491775

18501776
if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
@@ -1920,7 +1846,7 @@ vectorizePackOpPrecondition(tensor::PackOp packOp,
19201846
}
19211847

19221848
if (!satisfyEmptyCond &&
1923-
failed(isValidMaskedInputVector(
1849+
failed(vector::isValidMaskedInputVector(
19241850
resultTensorShape.take_front(packOp.getSourceRank()),
19251851
inputVectorSizes)))
19261852
return failure();
@@ -1945,7 +1871,8 @@ vectorizePadOpPrecondition(tensor::PadOp padOp,
19451871
}
19461872

19471873
ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
1948-
if (failed(isValidMaskedInputVector(resultTensorShape, inputVectorSizes)))
1874+
if (failed(vector::isValidMaskedInputVector(resultTensorShape,
1875+
inputVectorSizes)))
19491876
return failure();
19501877

19511878
if (llvm::any_of(padOp.getLow(), [](Value v) {

mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,11 @@
3030
#include "llvm/ADT/DenseSet.h"
3131
#include "llvm/ADT/SetVector.h"
3232

33+
#define DEBUG_TYPE "vector-utils"
34+
35+
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
36+
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
37+
3338
using namespace mlir;
3439

3540
/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
@@ -322,3 +327,73 @@ bool vector::isLinearizableVector(VectorType type) {
322327
auto numScalableDims = llvm::count(type.getScalableDims(), true);
323328
return (type.getRank() > 1) && (numScalableDims <= 1);
324329
}
330+
331+
Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
332+
Value source, ArrayRef<int64_t> readShape,
333+
Value padValue,
334+
bool useInBoundsInsteadOfMasking) {
335+
assert(llvm::none_of(readShape,
336+
[](int64_t s) { return s == ShapedType::kDynamic; }) &&
337+
"expected static shape");
338+
auto sourceShapedType = cast<ShapedType>(source.getType());
339+
auto sourceShape = sourceShapedType.getShape();
340+
assert(sourceShape.size() == readShape.size() && "expected same ranks.");
341+
auto maskType = VectorType::get(readShape, builder.getI1Type());
342+
auto vectorType = VectorType::get(readShape, padValue.getType());
343+
assert(padValue.getType() == sourceShapedType.getElementType() &&
344+
"expected same pad element type to match source element type");
345+
int64_t readRank = readShape.size();
346+
auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
347+
SmallVector<bool> inBoundsVal(readRank, true);
348+
if (!useInBoundsInsteadOfMasking) {
349+
// Update the inBounds attribute.
350+
for (unsigned i = 0; i < readRank; i++)
351+
inBoundsVal[i] = (sourceShape[i] == readShape[i]) &&
352+
!ShapedType::isDynamic(sourceShape[i]);
353+
}
354+
auto transferReadOp = builder.create<vector::TransferReadOp>(
355+
loc,
356+
/*vectorType=*/vectorType,
357+
/*source=*/source,
358+
/*indices=*/SmallVector<Value>(readRank, zero),
359+
/*padding=*/padValue,
360+
/*inBounds=*/inBoundsVal);
361+
362+
if (llvm::equal(readShape, sourceShape) || !useInBoundsInsteadOfMasking)
363+
return transferReadOp;
364+
SmallVector<OpFoldResult> mixedSourceDims =
365+
tensor::getMixedSizes(builder, loc, source);
366+
Value mask =
367+
builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
368+
return mlir::vector::maskOperation(builder, transferReadOp, mask)
369+
->getResult(0);
370+
}
371+
372+
LogicalResult
373+
vector::isValidMaskedInputVector(ArrayRef<int64_t> shape,
374+
ArrayRef<int64_t> inputVectorSizes) {
375+
LDBG("Iteration space static sizes:");
376+
LLVM_DEBUG(llvm::interleaveComma(shape, llvm::dbgs()));
377+
LLVM_DEBUG(llvm::dbgs() << "\n");
378+
379+
if (inputVectorSizes.size() != shape.size()) {
380+
LDBG("Input vector sizes don't match the number of loops");
381+
return failure();
382+
}
383+
if (ShapedType::isDynamicShape(inputVectorSizes)) {
384+
LDBG("Input vector sizes can't have dynamic dimensions");
385+
return failure();
386+
}
387+
if (!llvm::all_of(llvm::zip(shape, inputVectorSizes),
388+
[](std::tuple<int64_t, int64_t> sizePair) {
389+
int64_t staticSize = std::get<0>(sizePair);
390+
int64_t inputSize = std::get<1>(sizePair);
391+
return ShapedType::isDynamic(staticSize) ||
392+
staticSize <= inputSize;
393+
})) {
394+
LDBG("Input vector sizes must be greater than or equal to iteration space "
395+
"static sizes");
396+
return failure();
397+
}
398+
return success();
399+
}

0 commit comments

Comments
 (0)