
Commit e7aa443

[mlir][linalg] Split GenericPadOpVectorizationPattern into two patterns
At the moment, `GenericPadOpVectorizationPattern` implements two orthogonal
transformations:

  1. Rewrites `tensor::PadOp` into a sequence of `tensor::EmptyOp`,
     `linalg::FillOp` and `tensor::InsertSliceOp`.
  2. Vectorizes (where possible) `tensor::InsertSliceOp` (see
     `tryVectorizeCopy`).

This patch splits `GenericPadOpVectorizationPattern` into two separate
patterns:

  1. `GeneralizePadOpPattern` for the first transformation (note that
     currently `GenericPadOpVectorizationPattern` inherits from
     `GeneralizePadOpPattern`).
  2. `InsertSliceVectorizePattern` to vectorize `tensor::InsertSliceOp`.

With this change, we gain the following:

  * a clear separation between pre-processing and vectorization
    transformations/stages,
  * a path to support masked vectorisation for `tensor.insert_slice` (with a
    dedicated vectorization pattern, it is much easier to specify the input
    vector sizes used in masking),
  * more opportunities to vectorize `tensor.insert_slice`.

Note for downstream users:
--------------------------
If you were using `populatePadOpVectorizationPatterns`, following this change
you will also have to add `populateInsertSliceVectorizationPatterns`.

Finer implementation details:
-----------------------------
1. The majority of changes in this patch are copy & paste plus some edits.
   1.1. The only functional change is that the vectorization of
        `tensor.insert_slice` is now broadly available (as opposed to being
        constrained to the pad vectorization pattern,
        `GenericPadOpVectorizationPattern`).
   1.2. Following on from the above, the `@pad_and_insert_slice_dest` test is
        updated. As expected, the input `tensor.insert_slice` op is no longer
        "preserved" and instead gets vectorized successfully.
2. The `linalg.fill` case in `getConstantPadVal` works under the assumption
   that only _scalar_ source values can be used. That is consistent with the
   definition of the op, but it is not tested at the moment. Hence, a test
   case is added to Linalg/invalid.mlir.
3. The behaviour of the two TD vectorization ops,
   `transform.structured.vectorize_children_and_apply_patterns` and
   `transform.structured.vectorize`, is preserved.
1 parent 89d2a9d commit e7aa443
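
For downstream users, a minimal migration sketch in C++ (the wrapper function
and its `op` argument are illustrative assumptions; only the two `populate*`
entry points and the greedy driver come from MLIR itself):

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Registers the pad vectorization patterns and, after this patch, also the
// now-separate tensor.insert_slice vectorization pattern.
static void applyPadAndInsertSliceVectorization(mlir::Operation *op) {
  mlir::RewritePatternSet patterns(op->getContext());
  mlir::linalg::populatePadOpVectorizationPatterns(patterns);
  // New requirement after this split: register this explicitly.
  mlir::linalg::populateInsertSliceVectorizationPatterns(patterns);
  (void)mlir::applyPatternsAndFoldGreedily(op, std::move(patterns));
}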

File tree: 6 files changed (+239 −134 lines)

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 7 additions & 7 deletions
@@ -1503,18 +1503,13 @@ using OptimizeCopyFn =
 
 /// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and
 /// InsertSliceOp. For now, only constant padding values are supported.
-/// `OptimizeCopyFn` can be used to customize copying step optimization.
 struct GeneralizePadOpPattern : public OpRewritePattern<tensor::PadOp> {
-  GeneralizePadOpPattern(MLIRContext *context,
-                         OptimizeCopyFn optimizeCopyFn = nullptr,
-                         PatternBenefit benefit = 1)
-      : OpRewritePattern<tensor::PadOp>(context, benefit),
-        optimizeCopyFn(std::move(optimizeCopyFn)) {}
+  GeneralizePadOpPattern(MLIRContext *context, PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::PadOp>(context, benefit) {}
   LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                 PatternRewriter &rewriter) const override;
 
 protected:
-  OptimizeCopyFn optimizeCopyFn;
   Value createFillOrGenerateOp(RewriterBase &rewriter, tensor::PadOp padOp,
                                Value dest,
                                const SmallVector<Value> &dynSizes) const;
@@ -1663,6 +1658,11 @@ void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
 /// \see rewriteInIm2Col for more details.
 void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns);
 
+/// Populates `patterns` with vectorisation patterns for tensor.insert_slice.
+/// TODO: Avoid having a dedicated `populate{}` for one pattern. Instead, either
+/// expand or merge with other `populate{}`.
+void populateInsertSliceVectorizationPatterns(RewritePatternSet &patterns);
+
 /// Populates `patterns` with patterns that vectorize tensor.pad.
 /// These patterns are meant to apply in a complementary fashion. Benefits
 /// are used to encode a certain ordering of pattern application. To avoid

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 3 additions & 0 deletions
@@ -3477,6 +3477,9 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
 
   patterns.add<CopyVectorizationPattern>(ctx);
 
+  // Add misc. vectorization patterns (e.g. for tensor.insert_slice).
+  linalg::populateInsertSliceVectorizationPatterns(patterns);
+
   if (getVectorizePadding())
     linalg::populatePadOpVectorizationPatterns(patterns);
 
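Because the pattern is now registered unconditionally here, a transform
dialect script along these lines picks it up with or without
`vectorize_padding` (a sketch; the sequence wrapper and match target are
illustrative assumptions):

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %func = transform.structured.match ops(["func.func"]) in %arg0
        : (!transform.any_op) -> !transform.any_op
    // tensor.insert_slice vectorization now applies regardless of the
    // vectorize_padding unit attribute below.
    %vectorized = transform.structured.vectorize_children_and_apply_patterns
        %func { vectorize_padding } : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}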

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 1 addition & 6 deletions
@@ -973,12 +973,7 @@ GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
       padOp.getLoc(), staticSizes, resultType.getElementType(), dynSizes);
   Value fill = createFillOrGenerateOp(rewriter, padOp, emptyTensor, dynSizes);
 
-  // Try optimize the copy of source.
-  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
-    return success();
-
-  // tensor::PadOps cannot be optimized. Generate a InsertSliceOp instead
-  // for copying the PadOp source.
+  // Generate a InsertSliceOp for copying the PadOp source.
   auto sourceType = padOp.getSourceType();
   // Compute size of source of tensor::PadOp.
   SmallVector<OpFoldResult> srcSizes =
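
For reference, the decomposition performed by `GeneralizePadOpPattern` (kept
intact by this patch) looks roughly like this; the shapes and value names are
illustrative:

// Input: pad a 2x3 tensor to 3x5 with a constant.
%cst = arith.constant 0.0 : f32
%padded = tensor.pad %src low[0, 0] high[1, 2] {
^bb0(%i: index, %j: index):
  tensor.yield %cst : f32
} : tensor<2x3xf32> to tensor<3x5xf32>

// After GeneralizePadOpPattern: empty + fill + insert_slice.
%empty = tensor.empty() : tensor<3x5xf32>
%filled = linalg.fill ins(%cst : f32) outs(%empty : tensor<3x5xf32>) -> tensor<3x5xf32>
%res = tensor.insert_slice %src into %filled[0, 0] [2, 3] [1, 1]
    : tensor<2x3xf32> into tensor<3x5xf32>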

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 177 additions & 111 deletions
@@ -2262,115 +2262,6 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
 //----------------------------------------------------------------------------//
 // Misc. vectorization patterns.
 //----------------------------------------------------------------------------//
-
-/// Helper function that retrieves the value of an IntegerAttr.
-static int64_t getIntFromAttr(Attribute attr) {
-  return cast<IntegerAttr>(attr).getInt();
-}
-
-/// Given an ArrayRef of OpFoldResults, return a vector of Values.
-/// IntegerAttrs are converted to ConstantIndexOps. Other attribute types are
-/// not supported.
-static SmallVector<Value> ofrToIndexValues(RewriterBase &rewriter, Location loc,
-                                           ArrayRef<OpFoldResult> ofrs) {
-  SmallVector<Value> result;
-  for (auto o : ofrs) {
-    if (auto val = llvm::dyn_cast_if_present<Value>(o)) {
-      result.push_back(val);
-    } else {
-      result.push_back(rewriter.create<arith::ConstantIndexOp>(
-          loc, getIntFromAttr(o.template get<Attribute>())));
-    }
-  }
-  return result;
-}
-
-/// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and
-/// InsertSliceOp. For now, only constant padding values are supported.
-/// If there is enough static type information, TransferReadOps and
-/// TransferWriteOps may be generated instead of InsertSliceOps.
-struct GenericPadOpVectorizationPattern : public GeneralizePadOpPattern {
-  GenericPadOpVectorizationPattern(MLIRContext *context,
-                                   PatternBenefit benefit = 1)
-      : GeneralizePadOpPattern(context, tryVectorizeCopy, benefit) {}
-  /// Vectorize the copying of a tensor::PadOp's source. This is possible if
-  /// each dimension size is statically known in the source type or the result
-  /// type (or both).
-  static LogicalResult tryVectorizeCopy(RewriterBase &rewriter,
-                                        tensor::PadOp padOp, Value dest) {
-    auto sourceType = padOp.getSourceType();
-    auto resultType = padOp.getResultType();
-    if (!VectorType::isValidElementType(sourceType.getElementType()))
-      return failure();
-
-    // Copy cannot be vectorized if pad value is non-constant and source shape
-    // is dynamic. In case of a dynamic source shape, padding must be appended
-    // by TransferReadOp, but TransferReadOp supports only constant padding.
-    auto padValue = padOp.getConstantPaddingValue();
-    if (!padValue) {
-      if (!sourceType.hasStaticShape())
-        return failure();
-      // Create dummy padding value.
-      auto elemType = sourceType.getElementType();
-      padValue = rewriter.create<arith::ConstantOp>(
-          padOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
-    }
-
-    SmallVector<int64_t> vecShape;
-    SmallVector<bool> readInBounds;
-    SmallVector<bool> writeInBounds;
-    for (unsigned i = 0; i < sourceType.getRank(); ++i) {
-      if (!sourceType.isDynamicDim(i)) {
-        vecShape.push_back(sourceType.getDimSize(i));
-        // Source shape is statically known: Neither read nor write are
-        // out-of-bounds.
-        readInBounds.push_back(true);
-        writeInBounds.push_back(true);
-      } else if (!resultType.isDynamicDim(i)) {
-        // Source shape is not statically known, but result shape is.
-        // Vectorize with size of result shape. This may be larger than the
-        // source size.
-        vecShape.push_back(resultType.getDimSize(i));
-        // Read may be out-of-bounds because the result size could be larger
-        // than the source size.
-        readInBounds.push_back(false);
-        // Write is out-of-bounds if low padding > 0.
-        writeInBounds.push_back(
-            getConstantIntValue(padOp.getMixedLowPad()[i]) ==
-            static_cast<int64_t>(0));
-      } else {
-        // Neither source nor result dim of padOp is static. Cannot vectorize
-        // the copy.
-        return failure();
-      }
-    }
-    auto vecType = VectorType::get(vecShape, sourceType.getElementType());
-
-    // Generate TransferReadOp.
-    SmallVector<Value> readIndices(
-        vecType.getRank(),
-        rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
-    auto read = rewriter.create<vector::TransferReadOp>(
-        padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue,
-        ArrayRef<bool>{readInBounds});
-
-    // If `dest` is a FillOp and the TransferWriteOp would overwrite the
-    // entire tensor, write directly to the FillOp's operand.
-    if (llvm::equal(vecShape, resultType.getShape()) &&
-        llvm::all_of(writeInBounds, [](bool b) { return b; }))
-      if (auto fill = dest.getDefiningOp<FillOp>())
-        dest = fill.output();
-
-    // Generate TransferWriteOp.
-    auto writeIndices =
-        ofrToIndexValues(rewriter, padOp.getLoc(), padOp.getMixedLowPad());
-    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-        padOp, read, dest, writeIndices, ArrayRef<bool>{writeInBounds});
-
-    return success();
-  }
-};
-
 /// Base pattern for rewriting tensor::PadOps whose result is consumed by a
 /// given operation type OpTy.
 template <typename OpTy>
@@ -2604,6 +2495,177 @@ struct PadOpVectorizationWithTransferWritePattern
   }
 };
 
+/// Given an ArrayRef of OpFoldResults, return a vector of Values.
+/// IntegerAttrs are converted to ConstantIndexOps. Other attribute types are
+/// not supported.
+static SmallVector<Value> ofrToIndexValues(RewriterBase &rewriter, Location loc,
+                                           ArrayRef<OpFoldResult> ofrs) {
+  SmallVector<Value> result;
+  for (auto o : ofrs) {
+    if (auto val = llvm::dyn_cast_if_present<Value>(o)) {
+      result.push_back(val);
+    } else {
+      result.push_back(rewriter.create<arith::ConstantIndexOp>(
+          loc, cast<IntegerAttr>(cast<Attribute>(o)).getInt()));
+    }
+  }
+  return result;
+}
+
+/// Returns the effective Pad value for the input op, provided it's a scalar.
+///
+/// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
+/// this Op performs padding, retrieve the padding value provided that it's
+/// a scalar and static/fixed for all the padded values. Returns an empty value
+/// otherwise.
+static Value getStaticPadVl(Operation *op) {
+  if (!op)
+    return {};
+
+  // 1. vector.broadcast - return the value that's being broadcast,
+  // provided that it's a scalar.
+  if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
+    auto source = bcast.getSource();
+    if (llvm::dyn_cast<VectorType>(source.getType()))
+      return {};
+
+    return source;
+  }
+
+  // 2. linalg.fill - use the scalar input value that was used to fill the
+  // output tensor.
+  if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
+    return fill.getInputs()[0];
+  }
+
+  // 3. tensor.generate - can't guarantee that the value is fixed without
+  // analysing the op, bail out.
+  if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
+    return {};
+  }
+
+  // 4. vector.transfer_write - inspect the input vector that's written from.
+  // If it contains a single value that has been broadcast (e.g. via
+  // vector.broadcast), extract it, fail otherwise.
+  if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
+    return getStaticPadVl(xferWrite.getVector().getDefiningOp());
+
+  // 5. tensor.insert_slice - inspect the destination tensor. If it's larger
+  // than the input tensor, then, provided it's constant, we'll extract the
+  // value that was used to generate it (via e.g. linalg.fill), fail otherwise.
+  // TODO: Clarify the semantics when the input tensor is larger than the
+  // destination.
+  if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
+    return getStaticPadVl(slice.getDest().getDefiningOp());
+
+  return {};
+}
+
+/// Rewrite tensor.insert_slice as a vector.transfer_read +
+/// vector.transfer_write pair. The vector size is inferred from the static
+/// dims in the input and output tensors. If a dim is dynamic in both the
+/// input and output tensors, the pattern bails out.
+///
+/// Before:
+///     !t_in_type = tensor<1x2x3xf32>
+///     !t_out_type = tensor<9x8x7x1x2x3xf32>
+///     !v_type = vector<1x2x3xf32>
+///     %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
+///         into !t_out_type
+/// After:
+///     %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
+///     %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
+///
+/// TODO: Support masking.
+struct InsertSliceVectorizePattern
+    : public OpRewritePattern<tensor::InsertSliceOp> {
+  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::InsertSliceOp sliceOp,
+                                PatternRewriter &rewriter) const final {
+    auto sourceType = sliceOp.getSource().getType();
+    if (!VectorType::isValidElementType(sourceType.getElementType()))
+      return failure();
+
+    auto resultType = sliceOp.getResultType();
+
+    // 1. Get the pad value.
+    // TransferReadOp requires a scalar padding value. Note that:
+    //   * for in-bounds access, the value is actually irrelevant.
+    // There are 2 cases in which xfer.read accesses are known to be in-bounds:
+    //   1. The source shape is static (output vector sizes would be based on
+    //      the source shape and hence all memory accesses would be in-bounds),
+    //   2. Masking is used (output vector sizes would be user-provided, in
+    //      which case it is assumed that all memory accesses are in-bounds).
+    //      This remains a TODO.
+    //
+    // When the value is not known and not needed, use 0. Otherwise, bail out.
+    Value padValue = getStaticPadVl(sliceOp);
+    bool isOutOfBoundsRead = !sourceType.hasStaticShape();
+
+    if (!padValue && isOutOfBoundsRead) {
+      LDBG("Failed to get a pad value for out-of-bounds read access\n");
+      return failure();
+    }
+
+    if (!padValue) {
+      auto elemType = sourceType.getElementType();
+      padValue = rewriter.create<arith::ConstantOp>(
+          sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
+    }
+
+    // 2. Get the vector shape and in-bounds attributes.
+    SmallVector<int64_t> vecShape;
+    SmallVector<bool> readInBounds;
+    SmallVector<bool> writeInBounds;
+    for (unsigned i = 0; i < sourceType.getRank(); ++i) {
+      if (!sourceType.isDynamicDim(i)) {
+        vecShape.push_back(sourceType.getDimSize(i));
+        // Source shape is statically known: Neither read nor write are
+        // out-of-bounds.
+        readInBounds.push_back(true);
+        writeInBounds.push_back(true);
+      } else if (!resultType.isDynamicDim(i)) {
+        // Source shape is not statically known, but result shape is.
+        // Vectorize with size of result shape. This may be larger than the
+        // source size.
+        vecShape.push_back(resultType.getDimSize(i));
+        // Read may be out-of-bounds because the result size could be larger
+        // than the source size.
+        readInBounds.push_back(false);
+        // The write will be in-bounds provided that the corresponding write
+        // index is 0. To keep the logic simple, conservatively mark it as
+        // out-of-bounds.
+        writeInBounds.push_back(false);
+      } else {
+        // Neither the source nor the result dim is static. Cannot vectorize
+        // the copy.
+        // TODO: Add support for masking.
+        return failure();
+      }
+    }
+    auto vecType = VectorType::get(vecShape, sourceType.getElementType());
+
+    // 3. Generate the TransferReadOp.
+    SmallVector<Value> readIndices(
+        vecType.getRank(),
+        rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
+    auto read = rewriter.create<vector::TransferReadOp>(
+        sliceOp.getLoc(), vecType, sliceOp.getSource(), readIndices, padValue,
+        ArrayRef<bool>{readInBounds});
+
+    // 4. Generate the TransferWriteOp.
+    auto writeIndices =
+        ofrToIndexValues(rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());
+
+    // 5. Finalize.
+    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+        sliceOp, read, sliceOp.getDest(), writeIndices,
+        ArrayRef<bool>{writeInBounds});
+
+    return success();
+  }
+};
+
 /// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
 /// ```
 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
@@ -2691,10 +2753,14 @@ struct PadOpVectorizationWithInsertSlicePattern
   }
 };
 
+void mlir::linalg::populateInsertSliceVectorizationPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<InsertSliceVectorizePattern>(patterns.getContext());
+}
+
 void mlir::linalg::populatePadOpVectorizationPatterns(
     RewritePatternSet &patterns, PatternBenefit baseBenefit) {
-  patterns.add<GenericPadOpVectorizationPattern>(patterns.getContext(),
-                                                 baseBenefit);
+  patterns.add<GeneralizePadOpPattern>(patterns.getContext(), baseBenefit);
   // Try these specialized patterns first before resorting to the generic one.
   patterns.add<PadOpVectorizationWithTransferReadPattern,
                PadOpVectorizationWithTransferWritePattern,
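
Concretely, for a statically-shaped slice the new `InsertSliceVectorizePattern`
rewrites along these lines (a hypothetical example; the value names, offsets
and shapes are illustrative):

// Before:
%0 = tensor.insert_slice %src into %dest[0, 1] [2, 3] [1, 1]
    : tensor<2x3xf32> into tensor<8x8xf32>

// After (static source shape, so all accesses are in-bounds; the pad value
// is a dummy zero because it is never observed):
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%pad = arith.constant 0.0 : f32
%read = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]}
    : tensor<2x3xf32>, vector<2x3xf32>
%0 = vector.transfer_write %read, %dest[%c0, %c1] {in_bounds = [true, true]}
    : vector<2x3xf32>, tensor<8x8xf32>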

mlir/test/Dialect/Linalg/invalid.mlir

Lines changed: 9 additions & 0 deletions
@@ -352,6 +352,15 @@ func.func @illegal_fill_tensor_with_memref_return
 
 // -----
 
+func.func @illegal_fill_value_type(%arg0 : tensor<2x2xf32>, %arg1 : tensor<2xf32>) -> tensor<2x2xf32>
+{
+  // expected-error @+1 {{expected op with scalar input}}
+  %0 = linalg.fill ins(%arg1 : tensor<2xf32>) outs(%arg0 : tensor<2x2xf32>) -> tensor<2x2xf32>
+  return %0 : tensor<2x2xf32>
+}
+
+// -----
+
 func.func @invalid_static_matmul(%arg0: memref<2x4xf32>, %arg1: memref<3x4xf32>, %arg2: memref<2x4xf32>) {
   // expected-error @+1 {{inferred input/output operand #1 has shape's dimension #0 to be 4, but found 3}}
   linalg.matmul ins(%arg0, %arg1 : memref<2x4xf32>, memref<3x4xf32>)
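
For contrast, the legal form of `linalg.fill` takes a scalar input (the
function name here is illustrative):

func.func @legal_fill(%arg0 : tensor<2x2xf32>, %cst : f32) -> tensor<2x2xf32> {
  // A scalar (f32) input is the only form accepted by the op's verifier.
  %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<2x2xf32>) -> tensor<2x2xf32>
  return %0 : tensor<2x2xf32>
}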
