Commit 39ad84e

[mlir][linalg] Split GenericPadOpVectorizationPattern into two patterns (#111349)
At the moment, `GenericPadOpVectorizationPattern` implements two orthogonal transformations:

1. Rewrites `tensor::PadOp` into a sequence of `tensor::EmptyOp`, `linalg::FillOp` and `tensor::InsertSliceOp`.
2. Vectorizes (where possible) `tensor::InsertSliceOp` (see `tryVectorizeCopy`).

This patch splits `GenericPadOpVectorizationPattern` into two separate patterns:

1. `GeneralizePadOpPattern` for the first transformation (note that currently `GenericPadOpVectorizationPattern` inherits from `GeneralizePadOpPattern`).
2. `InsertSliceVectorizePattern` to vectorize `tensor::InsertSliceOp`.

With this change, we gain the following:

* a clear separation between pre-processing and vectorization transformations/stages,
* a path to support masked vectorisation for `tensor.insert_slice` (with a dedicated pattern for vectorization, it is much easier to specify the input vector sizes used in masking),
* more opportunities to vectorize `tensor.insert_slice`.

Note for downstream users:
--------------------------
If you were using `populatePadOpVectorizationPatterns`, following this change you will also have to add `populateInsertSliceVectorizationPatterns`.

Finer implementation details:
-----------------------------
1. The majority of changes in this patch are copy & paste plus some edits.
   1.1. The only functional change is that the vectorization of `tensor.insert_slice` is now broadly available (as opposed to being constrained to the pad vectorization pattern, `GenericPadOpVectorizationPattern`).
   1.2. Following on from the above, `@pad_and_insert_slice_dest` is updated. As expected, the input `tensor.insert_slice` Op is no longer "preserved" and instead gets vectorized successfully.
2. The `linalg.fill` case in `getConstantPadVal` works under the assumption that only _scalar_ source values can be used. That's consistent with the definition of the Op, but it's not tested at the moment. Hence a test case in Linalg/invalid.mlir is added.
3. The behaviour of the two TD vectorization Ops, `transform.structured.vectorize_children_and_apply_patterns` and `transform.structured.vectorize`, is preserved.
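As a sketch (with hypothetical shapes and values), the first transformation rewrites a constant-padded `tensor.pad` roughly as follows:

```mlir
// Before: pad a 5x6 tensor up to 7x9 with a constant value %cst.
%padded = tensor.pad %src low[0, 0] high[2, 3] {
^bb0(%i: index, %j: index):
  tensor.yield %cst : f32
} : tensor<5x6xf32> to tensor<7x9xf32>

// After GeneralizePadOpPattern: materialise the destination, fill it with
// the pad value, then copy the source into place.
%empty = tensor.empty() : tensor<7x9xf32>
%fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<7x9xf32>) -> tensor<7x9xf32>
%res = tensor.insert_slice %src into %fill[0, 0] [5, 6] [1, 1]
    : tensor<5x6xf32> into tensor<7x9xf32>
```

The resulting `tensor.insert_slice` is exactly the Op that the new, separate vectorization pattern can then pick up.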
1 parent 2a9dd8a commit 39ad84e

File tree

8 files changed: +327 / -141 lines


mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 7 additions & 7 deletions

@@ -1503,18 +1503,13 @@ using OptimizeCopyFn =
 
 /// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and
 /// InsertSliceOp. For now, only constant padding values are supported.
-/// `OptimizeCopyFn` can be used to customize copying step optimization.
 struct GeneralizePadOpPattern : public OpRewritePattern<tensor::PadOp> {
-  GeneralizePadOpPattern(MLIRContext *context,
-                         OptimizeCopyFn optimizeCopyFn = nullptr,
-                         PatternBenefit benefit = 1)
-      : OpRewritePattern<tensor::PadOp>(context, benefit),
-        optimizeCopyFn(std::move(optimizeCopyFn)) {}
+  GeneralizePadOpPattern(MLIRContext *context, PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::PadOp>(context, benefit) {}
   LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                 PatternRewriter &rewriter) const override;
 
 protected:
-  OptimizeCopyFn optimizeCopyFn;
   Value createFillOrGenerateOp(RewriterBase &rewriter, tensor::PadOp padOp,
                                Value dest,
                                const SmallVector<Value> &dynSizes) const;

@@ -1663,6 +1658,11 @@ void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
 /// \see rewriteInIm2Col for more details.
 void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns);
 
+/// Populates `patterns` with vectorisation patterns for tensor.insert_slice.
+/// TODO: Avoid having a dedicated `populate{}` for one pattern. Instead, either
+/// expand or merge with other `populate{}`.
+void populateInsertSliceVectorizationPatterns(RewritePatternSet &patterns);
+
 /// Populates `patterns` with patterns that vectorize tensor.pad.
 /// These patterns are meant to apply in a complementary fashion. Benefits
 /// are used to encode a certain ordering of pattern application. To avoid

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 4 additions & 0 deletions

@@ -256,6 +256,7 @@ void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
 void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   linalg::populatePadOpVectorizationPatterns(patterns);
+  linalg::populateInsertSliceVectorizationPatterns(patterns);
 }
 
 //===----------------------------------------------------------------------===//

@@ -3482,6 +3483,9 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
 
   patterns.add<CopyVectorizationPattern>(ctx);
 
+  // Add misc. vectorization patterns (e.g. for tensor.insert_slice)
+  linalg::populateInsertSliceVectorizationPatterns(patterns);
+
   if (getVectorizePadding())
     linalg::populatePadOpVectorizationPatterns(patterns);
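For context, a transform-dialect script that exercises `ApplyPadVectorizationPatternsOp` (which, after this change, also pulls in the insert_slice patterns) might look like the sketch below. The enclosing named sequence and the exact pattern-op mnemonic are illustrative assumptions, not taken from this commit:

```mlir
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
    %f = transform.structured.match ops{["func.func"]} in %root
        : (!transform.any_op) -> !transform.any_op
    // Applies populatePadOpVectorizationPatterns and, following this commit,
    // populateInsertSliceVectorizationPatterns as well.
    transform.apply_patterns to %f {
      transform.apply_patterns.linalg.pad_vectorization
    } : !transform.any_op
    transform.yield
  }
}
```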

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 1 addition & 6 deletions

@@ -973,12 +973,7 @@ GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
       padOp.getLoc(), staticSizes, resultType.getElementType(), dynSizes);
   Value fill = createFillOrGenerateOp(rewriter, padOp, emptyTensor, dynSizes);
 
-  // Try optimize the copy of source.
-  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
-    return success();
-
-  // tensor::PadOps cannot be optimized. Generate a InsertSliceOp instead
-  // for copying the PadOp source.
+  // Generate an InsertSliceOp for copying the PadOp source.
   auto sourceType = padOp.getSourceType();
   // Compute size of source of tensor::PadOp.
   SmallVector<OpFoldResult> srcSizes =

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 166 additions & 113 deletions

@@ -2281,115 +2281,6 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
 //----------------------------------------------------------------------------//
 // Misc. vectorization patterns.
 //----------------------------------------------------------------------------//
-
-/// Helper function that retrieves the value of an IntegerAttr.
-static int64_t getIntFromAttr(Attribute attr) {
-  return cast<IntegerAttr>(attr).getInt();
-}
-
-/// Given an ArrayRef of OpFoldResults, return a vector of Values.
-/// IntegerAttrs are converted to ConstantIndexOps. Other attribute types are
-/// not supported.
-static SmallVector<Value> ofrToIndexValues(RewriterBase &rewriter, Location loc,
-                                           ArrayRef<OpFoldResult> ofrs) {
-  SmallVector<Value> result;
-  for (auto o : ofrs) {
-    if (auto val = llvm::dyn_cast_if_present<Value>(o)) {
-      result.push_back(val);
-    } else {
-      result.push_back(rewriter.create<arith::ConstantIndexOp>(
-          loc, getIntFromAttr(o.template get<Attribute>())));
-    }
-  }
-  return result;
-}
-
-/// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and
-/// InsertSliceOp. For now, only constant padding values are supported.
-/// If there is enough static type information, TransferReadOps and
-/// TransferWriteOps may be generated instead of InsertSliceOps.
-struct GenericPadOpVectorizationPattern : public GeneralizePadOpPattern {
-  GenericPadOpVectorizationPattern(MLIRContext *context,
-                                   PatternBenefit benefit = 1)
-      : GeneralizePadOpPattern(context, tryVectorizeCopy, benefit) {}
-  /// Vectorize the copying of a tensor::PadOp's source. This is possible if
-  /// each dimension size is statically known in the source type or the result
-  /// type (or both).
-  static LogicalResult tryVectorizeCopy(RewriterBase &rewriter,
-                                        tensor::PadOp padOp, Value dest) {
-    auto sourceType = padOp.getSourceType();
-    auto resultType = padOp.getResultType();
-    if (!VectorType::isValidElementType(sourceType.getElementType()))
-      return failure();
-
-    // Copy cannot be vectorized if pad value is non-constant and source shape
-    // is dynamic. In case of a dynamic source shape, padding must be appended
-    // by TransferReadOp, but TransferReadOp supports only constant padding.
-    auto padValue = padOp.getConstantPaddingValue();
-    if (!padValue) {
-      if (!sourceType.hasStaticShape())
-        return failure();
-      // Create dummy padding value.
-      auto elemType = sourceType.getElementType();
-      padValue = rewriter.create<arith::ConstantOp>(
-          padOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
-    }
-
-    SmallVector<int64_t> vecShape;
-    SmallVector<bool> readInBounds;
-    SmallVector<bool> writeInBounds;
-    for (unsigned i = 0; i < sourceType.getRank(); ++i) {
-      if (!sourceType.isDynamicDim(i)) {
-        vecShape.push_back(sourceType.getDimSize(i));
-        // Source shape is statically known: Neither read nor write are
-        // out-of-bounds.
-        readInBounds.push_back(true);
-        writeInBounds.push_back(true);
-      } else if (!resultType.isDynamicDim(i)) {
-        // Source shape is not statically known, but result shape is.
-        // Vectorize with size of result shape. This may be larger than the
-        // source size.
-        vecShape.push_back(resultType.getDimSize(i));
-        // Read may be out-of-bounds because the result size could be larger
-        // than the source size.
-        readInBounds.push_back(false);
-        // Write is out-of-bounds if low padding > 0.
-        writeInBounds.push_back(
-            getConstantIntValue(padOp.getMixedLowPad()[i]) ==
-            static_cast<int64_t>(0));
-      } else {
-        // Neither source nor result dim of padOp is static. Cannot vectorize
-        // the copy.
-        return failure();
-      }
-    }
-    auto vecType = VectorType::get(vecShape, sourceType.getElementType());
-
-    // Generate TransferReadOp.
-    SmallVector<Value> readIndices(
-        vecType.getRank(),
-        rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
-    auto read = rewriter.create<vector::TransferReadOp>(
-        padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue,
-        ArrayRef<bool>{readInBounds});
-
-    // If `dest` is a FillOp and the TransferWriteOp would overwrite the
-    // entire tensor, write directly to the FillOp's operand.
-    if (llvm::equal(vecShape, resultType.getShape()) &&
-        llvm::all_of(writeInBounds, [](bool b) { return b; }))
-      if (auto fill = dest.getDefiningOp<FillOp>())
-        dest = fill.output();
-
-    // Generate TransferWriteOp.
-    auto writeIndices =
-        ofrToIndexValues(rewriter, padOp.getLoc(), padOp.getMixedLowPad());
-    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-        padOp, read, dest, writeIndices, ArrayRef<bool>{writeInBounds});
-
-    return success();
-  }
-};
-
 /// Base pattern for rewriting tensor::PadOps whose result is consumed by a
 /// given operation type OpTy.
 template <typename OpTy>

@@ -2623,6 +2514,163 @@ struct PadOpVectorizationWithTransferWritePattern
   }
 };
 
+/// Returns the effective Pad value for the input op, provided it's a scalar.
+///
+/// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
+/// this Op performs padding, retrieve the padding value provided that it's
+/// a scalar and static/fixed for all the padded values. Returns an empty value
+/// otherwise.
+static Value getStaticPadVal(Operation *op) {
+  if (!op)
+    return {};
+
+  // 1. vector.broadcast (f32 -> vector<...xf32>) - return the value that's
+  // being broadcast, provided that it's a scalar.
+  if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
+    auto source = bcast.getSource();
+    if (llvm::dyn_cast<VectorType>(source.getType()))
+      return {};
+
+    return source;
+  }
+
+  // 2. linalg.fill - use the scalar input value that's used to fill the
+  // output tensor.
+  if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
+    return fill.getInputs()[0];
+  }
+
+  // 3. tensor.generate - can't guarantee the value is fixed without
+  // analysing it, bail out.
+  if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
+    return {};
+  }
+
+  // 4. vector.transfer_write - inspect the input vector that's written from.
+  // If it contains a single value that has been broadcast (e.g. via
+  // vector.broadcast), extract it, fail otherwise.
+  if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
+    return getStaticPadVal(xferWrite.getVector().getDefiningOp());
+
+  // 5. tensor.insert_slice - inspect the destination tensor. If it's larger
+  // than the input tensor, then, provided it's constant, we'll extract the
+  // value that was used to generate it (via e.g. linalg.fill), fail otherwise.
+  // TODO: Clarify the semantics when the input tensor is larger than the
+  // destination.
+  if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
+    return getStaticPadVal(slice.getDest().getDefiningOp());
+
+  return {};
+}
+
+/// Rewrite tensor.insert_slice as a vector.transfer_read +
+/// vector.transfer_write pair. The vector size is inferred from the static
+/// dims in the input and output tensors. If a dim is dynamic in both the
+/// input and output tensors, bail out.
+///
+/// Before:
+///     !t_in_type = tensor<1x2x3xf32>
+///     !t_out_type = tensor<9x8x7x1x2x3xf32>
+///     !v_type = vector<1x2x3xf32>
+///     %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
+///     into !t_out_type
+/// After:
+///     %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
+///     %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
+///
+/// TODO: Support masking
+struct InsertSliceVectorizePattern
+    : public OpRewritePattern<tensor::InsertSliceOp> {
+  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::InsertSliceOp sliceOp,
+                                PatternRewriter &rewriter) const final {
+    auto sourceType = sliceOp.getSource().getType();
+    if (!VectorType::isValidElementType(sourceType.getElementType()))
+      return failure();
+
+    auto resultType = sliceOp.getResultType();
+
+    // 1. Get the pad value.
+    // TransferReadOp requires a scalar padding value. Note that:
+    // * for in-bounds accesses, the value is actually irrelevant.
+    // There are 2 cases in which xfer.read accesses are known to be in-bounds:
+    //  1. The source shape is static (output vector sizes would be based on
+    //     the source shape and hence all memory accesses would be in-bounds),
+    //  2. Masking is used (output vector sizes would be user-provided, in
+    //     which case it is assumed that all memory accesses are in-bounds).
+    //     This remains a TODO.
+    //
+    // When the value is not known and not needed, use 0. Otherwise, bail out.
+    Value padValue = getStaticPadVal(sliceOp);
+    bool isOutOfBoundsRead = !sourceType.hasStaticShape();
+
+    if (!padValue && isOutOfBoundsRead) {
+      LDBG("Failed to get a pad value for out-of-bounds read access\n");
+      return failure();
+    }
+
+    if (!padValue) {
+      auto elemType = sourceType.getElementType();
+      padValue = rewriter.create<arith::ConstantOp>(
+          sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
+    }
+
+    // 2. Get the vector shape and in-bounds attributes.
+    SmallVector<int64_t> vecShape;
+    SmallVector<bool> readInBounds;
+    SmallVector<bool> writeInBounds;
+    size_t rankDiff = resultType.getRank() - sourceType.getRank();
+    for (unsigned i = 0; i < sourceType.getRank(); ++i) {
+      if (!sourceType.isDynamicDim(i)) {
+        vecShape.push_back(sourceType.getDimSize(i));
+        // Source shape is statically known: Neither read nor write are
+        // out-of-bounds.
+        readInBounds.push_back(true);
+        writeInBounds.push_back(true);
+      } else if (!resultType.isDynamicDim(i)) {
+        // Source shape is not statically known, but result shape is.
+        // Vectorize with size of result shape. This may be larger than the
+        // source size.
+        // FIXME: Using rankDiff implies that the source tensor is inserted at
+        // the end of the destination tensor. However, that's not required.
+        vecShape.push_back(resultType.getDimSize(rankDiff + i));
+        // Read may be out-of-bounds because the result size could be larger
+        // than the source size.
+        readInBounds.push_back(false);
+        // Write will be in-bounds provided that the corresponding write idx
+        // is 0. To keep this logic simple, conservatively mark as
+        // out-of-bounds.
+        writeInBounds.push_back(false);
+      } else {
+        // Neither source nor result dim is static. Cannot vectorize the copy.
+        // TODO: Add support for masking.
+        return failure();
+      }
+    }
+    auto vecType = VectorType::get(vecShape, sourceType.getElementType());
+
+    // 3. Generate TransferReadOp.
+    SmallVector<Value> readIndices(
+        vecType.getRank(),
+        rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
+    auto read = rewriter.create<vector::TransferReadOp>(
+        sliceOp.getLoc(), vecType, sliceOp.getSource(), readIndices, padValue,
+        ArrayRef<bool>{readInBounds});
+
+    // 4. Generate TransferWriteOp.
+    auto writeIndices = getValueOrCreateConstantIndexOp(
+        rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());
+
+    // 5. Finalize.
+    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+        sliceOp, read, sliceOp.getDest(), writeIndices,
+        ArrayRef<bool>{writeInBounds});
+
+    return success();
+  }
+};
+
 /// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
 /// ```
 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>

@@ -2699,8 +2747,8 @@ struct PadOpVectorizationWithInsertSlicePattern
     // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
     // specified offsets. Write is fully in-bounds because an InsertSliceOp's
     // source must fit into the destination at the specified offsets.
-    auto writeIndices =
-        ofrToIndexValues(rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
+    auto writeIndices = getValueOrCreateConstantIndexOp(
+        rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
     SmallVector<bool> inBounds(vecRank, true);
     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        insertOp, read, insertOp.getDest(), writeIndices,

@@ -2710,13 +2758,18 @@ struct PadOpVectorizationWithInsertSlicePattern
   }
 };
 
+void mlir::linalg::populateInsertSliceVectorizationPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<InsertSliceVectorizePattern>(patterns.getContext());
+}
+
 void mlir::linalg::populatePadOpVectorizationPatterns(
     RewritePatternSet &patterns, PatternBenefit baseBenefit) {
   // TODO: The following pattern implements "decomposition" and
   // optional "vectorization". Separate "decomposition" into a separate
   // pre-processing pattern group.
-  patterns.add<GenericPadOpVectorizationPattern>(patterns.getContext(),
-                                                 baseBenefit);
+  patterns.add<GeneralizePadOpPattern>(patterns.getContext(), baseBenefit);
+
   // Try these specialized patterns first before resorting to the generic one.
   patterns.add<PadOpVectorizationWithTransferReadPattern,
                PadOpVectorizationWithTransferWritePattern,
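At the IR level, the new `InsertSliceVectorizePattern` performs roughly the following rewrite (the 2-D shapes, the offsets `%x, %y`, and the pad value `%pad` are hypothetical, chosen to match the static-source case in the code above):

```mlir
// Before: insert a statically-shaped slice into a larger destination.
%res = tensor.insert_slice %src into %dest[%x, %y] [2, 3] [1, 1]
    : tensor<2x3xf32> into tensor<9x8xf32>

// After: a transfer_read/transfer_write pair. The vector shape is taken from
// the static source dims, so both accesses are in-bounds; the write lands at
// the original slice offsets.
%c0 = arith.constant 0 : index
%read = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]}
    : tensor<2x3xf32>, vector<2x3xf32>
%res = vector.transfer_write %read, %dest[%x, %y] {in_bounds = [true, true]}
    : vector<2x3xf32>, tensor<9x8xf32>
```

When the source shape is dynamic but the result shape is static, the read is marked out-of-bounds instead and a pad value recovered by `getStaticPadVal` is required.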

mlir/test/Dialect/Linalg/invalid.mlir

Lines changed: 9 additions & 0 deletions

@@ -352,6 +352,15 @@ func.func @illegal_fill_tensor_with_memref_return
 
 // -----
 
+func.func @illegal_fill_value_type(%arg0 : tensor<2x2xf32>, %arg1 : tensor<2xf32>) -> tensor<2x2xf32>
+{
+  // expected-error @+1 {{expected op with scalar input}}
+  %0 = linalg.fill ins(%arg1 : tensor<2xf32>) outs(%arg0 : tensor<2x2xf32>) -> tensor<2x2xf32>
+  return %0 : tensor<2x2xf32>
+}
+
+// -----
+
 func.func @invalid_static_matmul(%arg0: memref<2x4xf32>, %arg1: memref<3x4xf32>, %arg2: memref<2x4xf32>) {
   // expected-error @+1 {{inferred input/output operand #1 has shape's dimension #0 to be 4, but found 3}}
   linalg.matmul ins(%arg0, %arg1 : memref<2x4xf32>, memref<3x4xf32>)
