@@ -2262,115 +2262,6 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
// ----------------------------------------------------------------------------//
// Misc. vectorization patterns.
// ----------------------------------------------------------------------------//
-
-/// Helper function that retrieves the value of an IntegerAttr.
-static int64_t getIntFromAttr(Attribute attr) {
-  return cast<IntegerAttr>(attr).getInt();
-}
-
-/// Given an ArrayRef of OpFoldResults, return a vector of Values.
-/// IntegerAttrs are converted to ConstantIndexOps. Other attribute types are
-/// not supported.
-static SmallVector<Value> ofrToIndexValues(RewriterBase &rewriter, Location loc,
-                                           ArrayRef<OpFoldResult> ofrs) {
-  SmallVector<Value> result;
-  for (auto o : ofrs) {
-    if (auto val = llvm::dyn_cast_if_present<Value>(o)) {
-      result.push_back(val);
-    } else {
-      result.push_back(rewriter.create<arith::ConstantIndexOp>(
-          loc, getIntFromAttr(o.template get<Attribute>())));
-    }
-  }
-  return result;
-}
-
-/// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and
-/// InsertSliceOp. For now, only constant padding values are supported.
-/// If there is enough static type information, TransferReadOps and
-/// TransferWriteOps may be generated instead of InsertSliceOps.
-struct GenericPadOpVectorizationPattern : public GeneralizePadOpPattern {
-  GenericPadOpVectorizationPattern(MLIRContext *context,
-                                   PatternBenefit benefit = 1)
-      : GeneralizePadOpPattern(context, tryVectorizeCopy, benefit) {}
-  /// Vectorize the copying of a tensor::PadOp's source. This is possible if
-  /// each dimension size is statically known in the source type or the result
-  /// type (or both).
-  static LogicalResult tryVectorizeCopy(RewriterBase &rewriter,
-                                        tensor::PadOp padOp, Value dest) {
-    auto sourceType = padOp.getSourceType();
-    auto resultType = padOp.getResultType();
-    if (!VectorType::isValidElementType(sourceType.getElementType()))
-      return failure();
-
-    // The copy cannot be vectorized if the pad value is non-constant and the
-    // source shape is dynamic. In case of a dynamic source shape, padding
-    // must be appended by the TransferReadOp, but TransferReadOp supports
-    // only constant padding.
-    auto padValue = padOp.getConstantPaddingValue();
-    if (!padValue) {
-      if (!sourceType.hasStaticShape())
-        return failure();
-      // Create a dummy padding value.
-      auto elemType = sourceType.getElementType();
-      padValue = rewriter.create<arith::ConstantOp>(
-          padOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
-    }
-
-    SmallVector<int64_t> vecShape;
-    SmallVector<bool> readInBounds;
-    SmallVector<bool> writeInBounds;
-    for (unsigned i = 0; i < sourceType.getRank(); ++i) {
-      if (!sourceType.isDynamicDim(i)) {
-        vecShape.push_back(sourceType.getDimSize(i));
-        // Source shape is statically known: Neither read nor write is
-        // out-of-bounds.
-        readInBounds.push_back(true);
-        writeInBounds.push_back(true);
-      } else if (!resultType.isDynamicDim(i)) {
-        // Source shape is not statically known, but the result shape is.
-        // Vectorize with the size of the result shape. This may be larger
-        // than the source size.
-        vecShape.push_back(resultType.getDimSize(i));
-        // The read may be out-of-bounds because the result size could be
-        // larger than the source size.
-        readInBounds.push_back(false);
-        // The write is out-of-bounds if low padding > 0.
-        writeInBounds.push_back(
-            getConstantIntValue(padOp.getMixedLowPad()[i]) ==
-            static_cast<int64_t>(0));
-      } else {
-        // Neither the source nor the result dim of padOp is static. Cannot
-        // vectorize the copy.
-        return failure();
-      }
-    }
-    auto vecType = VectorType::get(vecShape, sourceType.getElementType());
-
-    // Generate the TransferReadOp.
-    SmallVector<Value> readIndices(
-        vecType.getRank(),
-        rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
-    auto read = rewriter.create<vector::TransferReadOp>(
-        padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue,
-        ArrayRef<bool>{readInBounds});
-
-    // If `dest` is a FillOp and the TransferWriteOp would overwrite the
-    // entire tensor, write directly to the FillOp's operand.
-    if (llvm::equal(vecShape, resultType.getShape()) &&
-        llvm::all_of(writeInBounds, [](bool b) { return b; }))
-      if (auto fill = dest.getDefiningOp<FillOp>())
-        dest = fill.output();
-
-    // Generate the TransferWriteOp.
-    auto writeIndices =
-        ofrToIndexValues(rewriter, padOp.getLoc(), padOp.getMixedLowPad());
-    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-        padOp, read, dest, writeIndices, ArrayRef<bool>{writeInBounds});
-
-    return success();
-  }
-};
-
/// Base pattern for rewriting tensor::PadOps whose result is consumed by a
/// given operation type OpTy.
template <typename OpTy>
@@ -2604,6 +2495,177 @@ struct PadOpVectorizationWithTransferWritePattern
  }
};

+/// Given an ArrayRef of OpFoldResults, return a vector of Values.
+/// IntegerAttrs are converted to ConstantIndexOps. Other attribute types are
+/// not supported.
+static SmallVector<Value> ofrToIndexValues(RewriterBase &rewriter, Location loc,
+                                           ArrayRef<OpFoldResult> ofrs) {
+  SmallVector<Value> result;
+  for (auto o : ofrs) {
+    if (auto val = llvm::dyn_cast_if_present<Value>(o)) {
+      result.push_back(val);
+    } else {
+      result.push_back(rewriter.create<arith::ConstantIndexOp>(
+          loc, cast<IntegerAttr>(cast<Attribute>(o)).getInt()));
+    }
+  }
+  return result;
+}
+
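As a quick illustration of this helper, here is a hedged sketch; `rewriter`, `loc`, and `dynVal` are assumed to be in scope and are not part of the patch:

  // {%dynVal, 3} -> {%dynVal, arith.constant 3 : index}: the Value is
  // passed through unchanged; the IntegerAttr is materialized at `loc`.
  SmallVector<OpFoldResult> ofrs = {dynVal, rewriter.getIndexAttr(3)};
  SmallVector<Value> indices = ofrToIndexValues(rewriter, loc, ofrs);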
+/// Returns the effective Pad value for the input op, provided it's a scalar.
+///
+/// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
+/// this Op performs padding, retrieve the padding value provided that it's
+/// a scalar and static/fixed for all the padded values. Returns an empty
+/// value otherwise.
+static Value getStaticPadVl(Operation *op) {
+  if (!op)
+    return {};
+
+  // 1. vector.broadcast - return the value that's being broadcast,
+  // provided that it's a scalar.
+  if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
+    auto source = bcast.getSource();
+    if (llvm::dyn_cast<VectorType>(source.getType()))
+      return {};
+
+    return source;
+  }
+
+  // 2. linalg.fill - use the scalar input value that's used to fill the
+  // output tensor.
+  if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
+    return fill.getInputs()[0];
+  }
+
+  // 3. tensor.generate - we can't guarantee that the value is fixed without
+  // analysing the body, so bail out.
+  if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
+    return {};
+  }
+
+  // 4. vector.transfer_write - inspect the input vector that's written from.
+  // If it contains a single value that has been broadcast (e.g. via
+  // vector.broadcast), extract it, fail otherwise.
+  if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
+    return getStaticPadVl(xferWrite.getVector().getDefiningOp());
+
+  // 5. tensor.insert_slice - inspect the destination tensor. If it's larger
+  // than the input tensor, then, provided it's constant, we'll extract the
+  // value that was used to generate it (via e.g. linalg.fill), fail
+  // otherwise.
+  // TODO: Clarify the semantics when the input tensor is larger than the
+  // destination.
+  if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
+    return getStaticPadVl(slice.getDest().getDefiningOp());
+
+  return {};
+}
+
+/// Rewrite tensor.insert_slice as a vector.transfer_read +
+/// vector.transfer_write pair. The vector size is inferred from the static
+/// dims in the input and output tensors. If a dim is dynamic in both the
+/// input and output tensors, the pattern bails out.
+///
+/// Before:
+///   !t_in_type = tensor<1x2x3xf32>
+///   !t_out_type = tensor<9x8x7x1x2x3xf32>
+///   !v_type = vector<1x2x3xf32>
+///   %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
+///   into !t_out_type
+/// After:
+///   %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
+///   %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
+///
+/// TODO: Support masking
+struct InsertSliceVectorizePattern
+    : public OpRewritePattern<tensor::InsertSliceOp> {
+  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::InsertSliceOp sliceOp,
+                                PatternRewriter &rewriter) const final {
+    auto sourceType = sliceOp.getSource().getType();
+    if (!VectorType::isValidElementType(sourceType.getElementType()))
+      return failure();
+
+    auto resultType = sliceOp.getResultType();
+
+    // 1. Get the pad value.
+    // TransferReadOp requires a scalar padding value. Note that:
+    //   * for in-bounds accesses, the value is actually irrelevant.
+    // There are 2 cases in which xfer.read accesses are known to be in-bounds:
+    //   1. The source shape is static (output vector sizes would be based on
+    //      the source shape and hence all memory accesses would be in-bounds),
+    //   2. Masking is used (output vector sizes would be user-provided, in
+    //      which case it is assumed that all memory accesses are in-bounds).
+    //      This remains a TODO.
+    //
+    // When the value is not known and not needed, use 0. Otherwise, bail out.
+    Value padValue = getStaticPadVl(sliceOp);
+    bool isOutOfBoundsRead = !sourceType.hasStaticShape();
+
+    if (!padValue && isOutOfBoundsRead) {
+      LDBG("Failed to get a pad value for out-of-bounds read access\n");
+      return failure();
+    }
+
+    if (!padValue) {
+      auto elemType = sourceType.getElementType();
+      padValue = rewriter.create<arith::ConstantOp>(
+          sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
+    }
+
+    // 2. Get the vector shape and the in-bounds attributes.
+    SmallVector<int64_t> vecShape;
+    SmallVector<bool> readInBounds;
+    SmallVector<bool> writeInBounds;
+    for (unsigned i = 0; i < sourceType.getRank(); ++i) {
+      if (!sourceType.isDynamicDim(i)) {
+        vecShape.push_back(sourceType.getDimSize(i));
+        // Source shape is statically known: Neither read nor write is
+        // out-of-bounds.
+        readInBounds.push_back(true);
+        writeInBounds.push_back(true);
+      } else if (!resultType.isDynamicDim(i)) {
+        // Source shape is not statically known, but the result shape is.
+        // Vectorize with the size of the result shape. This may be larger
+        // than the source size.
+        vecShape.push_back(resultType.getDimSize(i));
+        // The read may be out-of-bounds because the result size could be
+        // larger than the source size.
+        readInBounds.push_back(false);
+        // The write will be in-bounds provided that the corresponding write
+        // index is 0. To keep this logic simple, conservatively mark it as
+        // out-of-bounds.
+        writeInBounds.push_back(false);
+      } else {
+        // Neither the source nor the result dim of sliceOp is static. Cannot
+        // vectorize the copy.
+        // TODO: Add support for masking.
+        return failure();
+      }
+    }
+    auto vecType = VectorType::get(vecShape, sourceType.getElementType());
+
+    // 3. Generate the TransferReadOp.
+    SmallVector<Value> readIndices(
+        vecType.getRank(),
+        rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
+    auto read = rewriter.create<vector::TransferReadOp>(
+        sliceOp.getLoc(), vecType, sliceOp.getSource(), readIndices, padValue,
+        ArrayRef<bool>{readInBounds});
+
+    // 4. Generate the TransferWriteOp.
+    auto writeIndices =
+        ofrToIndexValues(rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());
+
+    // 5. Finalize.
+    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+        sliceOp, read, sliceOp.getDest(), writeIndices,
+        ArrayRef<bool>{writeInBounds});
+
+    return success();
+  }
+};
+
2669
// / Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
2608
2670
// / ```
2609
2671
// / %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
@@ -2691,10 +2753,14 @@ struct PadOpVectorizationWithInsertSlicePattern
  }
};

+void mlir::linalg::populateInsertSliceVectorizationPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<InsertSliceVectorizePattern>(patterns.getContext());
+}
+
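A hedged usage sketch for the new entry point: the surrounding pass boilerplate and the greedy driver call are typical-use assumptions, not part of this commit.

  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

  // Inside runOnOperation() of some pass: collect the new pattern and
  // apply it greedily to the current operation.
  RewritePatternSet patterns(&getContext());
  linalg::populateInsertSliceVectorizationPatterns(patterns);
  if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
    signalPassFailure();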
void mlir::linalg::populatePadOpVectorizationPatterns(
    RewritePatternSet &patterns, PatternBenefit baseBenefit) {
-  patterns.add<GenericPadOpVectorizationPattern>(patterns.getContext(),
-                                                 baseBenefit);
+  patterns.add<GeneralizePadOpPattern>(patterns.getContext(), baseBenefit);
  // Try these specialized patterns first before resorting to the generic one.
  patterns.add<PadOpVectorizationWithTransferReadPattern,
               PadOpVectorizationWithTransferWritePattern,