Skip to content

Commit 26006e1

Browse files
committed
fixup! fixup! [mlir][linalg] Add support for masked vectorization of tensor.insert_slice (1/N)
Incorporate suggestions from Diego
1 parent c44faa0 commit 26006e1

File tree

1 file changed

+37
-25
lines changed

1 file changed

+37
-25
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 37 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,14 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
8383
ArrayRef<int64_t> inputVectorSizes,
8484
SmallVectorImpl<Value> &newResults);
8585

86+
/// Returns the effective Pad value for the input op, provided it's a scalar.
87+
///
88+
/// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
89+
/// this Op performs padding, retrieve the padding value provided that it's
90+
/// a scalar and static/fixed for all the padded values. Returns an empty value
91+
/// otherwise.
92+
static Value getStaticPadVal(Operation *op);
93+
8694
/// Return the unique instance of OpType in `block` if it is indeed unique.
8795
/// Return null if none or more than 1 instances exist.
8896
template <typename OpType>
@@ -1904,8 +1912,31 @@ static LogicalResult
19041912
vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
19051913
ArrayRef<int64_t> inputVectorSizes) {
19061914

1907-
// TODO: Move pre-conditions from the vectorization logic, i.e.
1908-
// vectorizeAsInsertSliceOp.
1915+
TypedValue<RankedTensorType> source = sliceOp.getSource();
1916+
auto sourceType = source.getType();
1917+
if (!VectorType::isValidElementType(sourceType.getElementType()))
1918+
return failure();
1919+
1920+
// Get the pad value.
1921+
// TransferReadOp (which is used to vectorize InsertSliceOp, requires a scalar
1922+
// padding value. Note that:
1923+
// * for in-bounds access, the value is actually irrelevant.
1924+
// There are 2 cases in which xfer.read accesses are known to be in-bounds:
1925+
// 1. The source shape is static (output vector sizes would be based on
1926+
// the source shape and hence all memory accesses would be in-bounds),
1927+
// 2. Masking is used (output vector sizes would be user-provided, in which
1928+
// case it is assumed that all memory accesses are in-bounds). This
1929+
// remains a TODO.
1930+
//
1931+
// When the value is not known and not needed, use 0. Otherwise, bail out.
1932+
Value padValue = getStaticPadVal(sliceOp);
1933+
bool isOutOfBoundsRead =
1934+
!sourceType.hasStaticShape() && inputVectorSizes.empty();
1935+
1936+
if (!padValue && isOutOfBoundsRead) {
1937+
LDBG("Failed to get a pad value for out-of-bounds read access\n");
1938+
return failure();
1939+
}
19091940
return success();
19101941
}
19111942

@@ -2216,7 +2247,6 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
22162247
ArrayRef<bool> inputScalableVecDims,
22172248
bool vectorizeNDExtract,
22182249
bool flatten1DDepthwiseConv) {
2219-
rewriter.getInsertionPoint();
22202250
LDBG("Attempting to vectorize:\n" << *op << "\n");
22212251
LDBG("Input vector sizes: ");
22222252
LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2583,6 +2613,9 @@ struct PadOpVectorizationWithTransferWritePattern
25832613
/// this Op performs padding, retrieve the padding value provided that it's
25842614
/// a scalar and static/fixed for all the padded values. Returns an empty value
25852615
/// otherwise.
2616+
///
2617+
/// TODO: This is used twice (when checking vectorization pre-conditions and
2618+
/// when vectorizing). Cache results instead of re-running.
25862619
static Value getStaticPadVal(Operation *op) {
25872620
if (!op)
25882621
return {};
@@ -2636,30 +2669,9 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
26362669

26372670
TypedValue<RankedTensorType> source = sliceOp.getSource();
26382671
auto sourceType = source.getType();
2639-
if (!VectorType::isValidElementType(sourceType.getElementType()))
2640-
return failure();
2641-
26422672
auto resultType = sliceOp.getResultType();
26432673

2644-
// 1. Get the pad value.
2645-
// TransferReadOp requires a scalar padding value. Note that:
2646-
// * for in-bounds access, the value is actually irrelevant.
2647-
// There are 2 cases in which xfer.read accesses are known to be in-bounds:
2648-
// 1. The source shape is static (output vector sizes would be based on
2649-
// the source shape and hence all memory accesses would be in-bounds),
2650-
// 2. Masking is used (output vector sizes would be user-provided, in which
2651-
// case it is assumed that all memory accesses are in-bounds). This
2652-
// remains a TODO.
2653-
//
2654-
// When the value is not known and not needed, use 0. Otherwise, bail out.
26552674
Value padValue = getStaticPadVal(sliceOp);
2656-
bool isOutOfBoundsRead =
2657-
!sourceType.hasStaticShape() && inputVectorSizes.empty();
2658-
2659-
if (!padValue && isOutOfBoundsRead) {
2660-
LDBG("Failed to get a pad value for out-of-bounds read access\n");
2661-
return failure();
2662-
}
26632675

26642676
if (!padValue) {
26652677
auto elemType = sourceType.getElementType();
@@ -2672,7 +2684,7 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
26722684
SmallVector<bool> readInBounds;
26732685
SmallVector<bool> writeInBounds;
26742686
size_t rankDiff = resultType.getRank() - sourceType.getRank();
2675-
for (unsigned i = 0; i < sourceType.getRank(); ++i) {
2687+
for (int64_t i = 0, end = sourceType.getRank(); i < end; ++i) {
26762688
if (!inputVectorSizes.empty()) {
26772689
vecShape.push_back(inputVectorSizes[i]);
26782690
readInBounds.push_back(false);

0 commit comments

Comments
 (0)