@@ -2281,115 +2281,6 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
 //----------------------------------------------------------------------------//
 // Misc. vectorization patterns.
 //----------------------------------------------------------------------------//
-
-/// Helper function that retrieves the value of an IntegerAttr.
-static int64_t getIntFromAttr(Attribute attr) {
-  return cast<IntegerAttr>(attr).getInt();
-}
-
-/// Given an ArrayRef of OpFoldResults, return a vector of Values.
-/// IntegerAttrs are converted to ConstantIndexOps. Other attribute types are
-/// not supported.
-static SmallVector<Value> ofrToIndexValues(RewriterBase &rewriter, Location loc,
-                                           ArrayRef<OpFoldResult> ofrs) {
-  SmallVector<Value> result;
-  for (auto o : ofrs) {
-    if (auto val = llvm::dyn_cast_if_present<Value>(o)) {
-      result.push_back(val);
-    } else {
-      result.push_back(rewriter.create<arith::ConstantIndexOp>(
-          loc, getIntFromAttr(o.template get<Attribute>())));
-    }
-  }
-  return result;
-}
-
-/// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and
-/// InsertSliceOp. For now, only constant padding values are supported.
-/// If there is enough static type information, TransferReadOps and
-/// TransferWriteOps may be generated instead of InsertSliceOps.
-struct GenericPadOpVectorizationPattern : public GeneralizePadOpPattern {
-  GenericPadOpVectorizationPattern(MLIRContext *context,
-                                   PatternBenefit benefit = 1)
-      : GeneralizePadOpPattern(context, tryVectorizeCopy, benefit) {}
-  /// Vectorize the copying of a tensor::PadOp's source. This is possible if
-  /// each dimension size is statically know in the source type or the result
-  /// type (or both).
-  static LogicalResult tryVectorizeCopy(RewriterBase &rewriter,
-                                        tensor::PadOp padOp, Value dest) {
-    auto sourceType = padOp.getSourceType();
-    auto resultType = padOp.getResultType();
-    if (!VectorType::isValidElementType(sourceType.getElementType()))
-      return failure();
-
-    // Copy cannot be vectorized if pad value is non-constant and source shape
-    // is dynamic. In case of a dynamic source shape, padding must be appended
-    // by TransferReadOp, but TransferReadOp supports only constant padding.
-    auto padValue = padOp.getConstantPaddingValue();
-    if (!padValue) {
-      if (!sourceType.hasStaticShape())
-        return failure();
-      // Create dummy padding value.
-      auto elemType = sourceType.getElementType();
-      padValue = rewriter.create<arith::ConstantOp>(
-          padOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
-    }
-
-    SmallVector<int64_t> vecShape;
-    SmallVector<bool> readInBounds;
-    SmallVector<bool> writeInBounds;
-    for (unsigned i = 0; i < sourceType.getRank(); ++i) {
-      if (!sourceType.isDynamicDim(i)) {
-        vecShape.push_back(sourceType.getDimSize(i));
-        // Source shape is statically known: Neither read nor write are
-        // out-of-bounds.
-        readInBounds.push_back(true);
-        writeInBounds.push_back(true);
-      } else if (!resultType.isDynamicDim(i)) {
-        // Source shape is not statically known, but result shape is.
-        // Vectorize with size of result shape. This may be larger than the
-        // source size.
-        vecShape.push_back(resultType.getDimSize(i));
-        // Read may be out-of-bounds because the result size could be larger
-        // than the source size.
-        readInBounds.push_back(false);
-        // Write is out-of-bounds if low padding > 0.
-        writeInBounds.push_back(
-            getConstantIntValue(padOp.getMixedLowPad()[i]) ==
-            static_cast<int64_t>(0));
-      } else {
-        // Neither source nor result dim of padOp is static. Cannot vectorize
-        // the copy.
-        return failure();
-      }
-    }
-    auto vecType = VectorType::get(vecShape, sourceType.getElementType());
-
-    // Generate TransferReadOp.
-    SmallVector<Value> readIndices(
-        vecType.getRank(),
-        rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
-    auto read = rewriter.create<vector::TransferReadOp>(
-        padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue,
-        ArrayRef<bool>{readInBounds});
-
-    // If `dest` is a FillOp and the TransferWriteOp would overwrite the
-    // entire tensor, write directly to the FillOp's operand.
-    if (llvm::equal(vecShape, resultType.getShape()) &&
-        llvm::all_of(writeInBounds, [](bool b) { return b; }))
-      if (auto fill = dest.getDefiningOp<FillOp>())
-        dest = fill.output();
-
-    // Generate TransferWriteOp.
-    auto writeIndices =
-        ofrToIndexValues(rewriter, padOp.getLoc(), padOp.getMixedLowPad());
-    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-        padOp, read, dest, writeIndices, ArrayRef<bool>{writeInBounds});
-
-    return success();
-  }
-};
-
 /// Base pattern for rewriting tensor::PadOps whose result is consumed by a
 /// given operation type OpTy.
 template <typename OpTy>
@@ -2623,6 +2514,177 @@ struct PadOpVectorizationWithTransferWritePattern
   }
 };
 
+/// Given an ArrayRef of OpFoldResults, return a vector of Values.
+/// IntegerAttrs are converted to ConstantIndexOps. Other attribute types are
+/// not supported.
+static SmallVector<Value> ofrToIndexValues(RewriterBase &rewriter, Location loc,
+                                           ArrayRef<OpFoldResult> ofrs) {
+  SmallVector<Value> result;
+  for (auto o : ofrs) {
+    if (auto val = llvm::dyn_cast_if_present<Value>(o)) {
+      result.push_back(val);
+    } else {
+      result.push_back(rewriter.create<arith::ConstantIndexOp>(
+          loc, cast<IntegerAttr>(cast<Attribute>(o)).getInt()));
+    }
+  }
+  return result;
+}
+
+/// Returns the effective pad value for the input op, provided it's a scalar.
+///
+/// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
+/// this Op performs padding, retrieve the padding value provided that it's
+/// a scalar and static/fixed for all the padded values. Returns an empty value
+/// otherwise.
+static Value getStaticPadVl(Operation *op) {
+  if (!op)
+    return {};
+
+  // 1. vector.broadcast - return the value that's being broadcast,
+  // provided that it's a scalar.
+  if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
+    auto source = bcast.getSource();
+    if (llvm::dyn_cast<VectorType>(source.getType()))
+      return {};
+
+    return source;
+  }
+
+  // 2. linalg.fill - use the scalar input value that was used to fill the
+  // output tensor.
+  if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
+    return fill.getInputs()[0];
+  }
+
+  // 3. tensor.generate - can't guarantee the value is fixed without further
+  // analysis, bail out.
+  if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
+    return {};
+  }
+
+  // 4. vector.transfer_write - inspect the input vector that's written from.
+  // If it contains a single value that has been broadcast (e.g. via
+  // vector.broadcast), extract it, fail otherwise.
+  if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
+    return getStaticPadVl(xferWrite.getVector().getDefiningOp());
+
+  // 5. tensor.insert_slice - inspect the destination tensor. If it's larger
+  // than the input tensor, then, provided it's constant, we'll extract the
+  // value that was used to generate it (via e.g. linalg.fill), fail otherwise.
+  // TODO: Clarify the semantics when the input tensor is larger than the
+  // destination.
+  if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
+    return getStaticPadVl(slice.getDest().getDefiningOp());
+
+  return {};
+}
+
+/// Rewrite tensor.insert_slice as a vector.transfer_read +
+/// vector.transfer_write pair. The vector size is inferred from the static
+/// dims in the input and output tensors. If a dim is dynamic in both the
+/// input and output tensors, bail out.
+///
+/// Before:
+///     !t_in_type = tensor<1x2x3xf32>
+///     !t_out_type = tensor<9x8x7x1x2x3xf32>
+///     !v_type = vector<1x2x3xf32>
+///     %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
+///     into !t_out_type
+/// After:
+///     %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
+///     %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
+///
+/// TODO: Support masking
+struct InsertSliceVectorizePattern
+    : public OpRewritePattern<tensor::InsertSliceOp> {
+  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::InsertSliceOp sliceOp,
+                                PatternRewriter &rewriter) const final {
+    auto sourceType = sliceOp.getSource().getType();
+    if (!VectorType::isValidElementType(sourceType.getElementType()))
+      return failure();
+
+    auto resultType = sliceOp.getResultType();
+
+    // 1. Get the pad value.
+    // TransferReadOp requires a scalar padding value. Note that:
+    //    * for in-bounds access, the value is actually irrelevant.
+    // There are 2 cases in which xfer.read accesses are known to be in-bounds:
+    //  1. The source shape is static (output vector sizes would be based on
+    //     the source shape and hence all memory accesses would be in-bounds),
+    //  2. Masking is used (output vector sizes would be user-provided, in
+    //     which case it is assumed that all memory accesses are in-bounds).
+    //     This remains a TODO.
+    //
+    // When the value is not known and not needed, use 0. Otherwise, bail out.
+    Value padValue = getStaticPadVl(sliceOp);
+    bool isOutOfBoundsRead = !sourceType.hasStaticShape();
+
+    if (!padValue && isOutOfBoundsRead) {
+      LDBG("Failed to get a pad value for out-of-bounds read access\n");
+      return failure();
+    }
+
+    if (!padValue) {
+      auto elemType = sourceType.getElementType();
+      padValue = rewriter.create<arith::ConstantOp>(
+          sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
+    }
+
+    // 2. Get the vector shape and in-bounds attributes.
+    SmallVector<int64_t> vecShape;
+    SmallVector<bool> readInBounds;
+    SmallVector<bool> writeInBounds;
+    for (unsigned i = 0; i < sourceType.getRank(); ++i) {
+      if (!sourceType.isDynamicDim(i)) {
+        vecShape.push_back(sourceType.getDimSize(i));
+        // Source shape is statically known: Neither read nor write are
+        // out-of-bounds.
+        readInBounds.push_back(true);
+        writeInBounds.push_back(true);
+      } else if (!resultType.isDynamicDim(i)) {
+        // Source shape is not statically known, but result shape is.
+        // Vectorize with size of result shape. This may be larger than the
+        // source size.
+        vecShape.push_back(resultType.getDimSize(i));
+        // Read may be out-of-bounds because the result size could be larger
+        // than the source size.
+        readInBounds.push_back(false);
+        // Write will be in-bounds provided that the corresponding write idx
+        // is 0. To keep this logic simple, conservatively mark as
+        // out-of-bounds.
+        writeInBounds.push_back(false);
+      } else {
+        // Neither the source nor the result dim is static. Cannot vectorize
+        // the copy.
+        // TODO: Add support for masking.
+        return failure();
+      }
+    }
+    auto vecType = VectorType::get(vecShape, sourceType.getElementType());
+
+    // 3. Generate TransferReadOp.
+    SmallVector<Value> readIndices(
+        vecType.getRank(),
+        rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
+    auto read = rewriter.create<vector::TransferReadOp>(
+        sliceOp.getLoc(), vecType, sliceOp.getSource(), readIndices, padValue,
+        ArrayRef<bool>{readInBounds});
+
+    // 4. Generate TransferWriteOp.
+    auto writeIndices =
+        ofrToIndexValues(rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());
+
+    // 5. Finalize.
+    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+        sliceOp, read, sliceOp.getDest(), writeIndices,
+        ArrayRef<bool>{writeInBounds});
+
+    return success();
+  }
+};
+
 /// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
 /// ```
 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
@@ -2710,13 +2772,18 @@ struct PadOpVectorizationWithInsertSlicePattern
   }
 };
 
+void mlir::linalg::populateInsertSliceVectorizationPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<InsertSliceVectorizePattern>(patterns.getContext());
+}
+
 void mlir::linalg::populatePadOpVectorizationPatterns(
     RewritePatternSet &patterns, PatternBenefit baseBenefit) {
   // TODO: The following pattern implements "decomposition" and
   // optional "vectorization". Separate "decomposition" into a separate
   // pre-processing pattern group.
-  patterns.add<GenericPadOpVectorizationPattern>(patterns.getContext(),
-                                                 baseBenefit);
+  patterns.add<GeneralizePadOpPattern>(patterns.getContext(), baseBenefit);
+
   // Try these specialized patterns first before resorting to the generic one.
   patterns.add<PadOpVectorizationWithTransferReadPattern,
                PadOpVectorizationWithTransferWritePattern,
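For reference, a minimal sketch of how a downstream pass might pick up the new entry point alongside the existing pad-op patterns. Only populateInsertSliceVectorizationPatterns and populatePadOpVectorizationPatterns come from the diff above; the helper function, its name, and the greedy-driver plumbing are illustrative assumptions, not part of this change.

// Sketch only: hypothetical helper, assuming the standard greedy rewrite
// driver from mlir/Transforms/GreedyPatternRewriteDriver.h is acceptable here.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical: collect both pattern sets and apply them greedily to `op`.
static mlir::LogicalResult applyVectorizationPatterns(mlir::Operation *op) {
  mlir::RewritePatternSet patterns(op->getContext());
  // New in this change: rewrite tensor.insert_slice into a
  // vector.transfer_read + vector.transfer_write pair.
  mlir::linalg::populateInsertSliceVectorizationPatterns(patterns);
  // Pre-existing pad-op decomposition and folding patterns.
  mlir::linalg::populatePadOpVectorizationPatterns(patterns, /*baseBenefit=*/1);
  return mlir::applyPatternsAndFoldGreedily(op, std::move(patterns));
}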