@@ -2281,115 +2281,6 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
2281
2281
// ----------------------------------------------------------------------------//
2282
2282
// Misc. vectorization patterns.
2283
2283
// ----------------------------------------------------------------------------//
2284
-
2285
- // / Helper function that retrieves the value of an IntegerAttr.
2286
- static int64_t getIntFromAttr (Attribute attr) {
2287
- return cast<IntegerAttr>(attr).getInt ();
2288
- }
2289
-
2290
- // / Given an ArrayRef of OpFoldResults, return a vector of Values.
2291
- // / IntegerAttrs are converted to ConstantIndexOps. Other attribute types are
2292
- // / not supported.
2293
- static SmallVector<Value> ofrToIndexValues (RewriterBase &rewriter, Location loc,
2294
- ArrayRef<OpFoldResult> ofrs) {
2295
- SmallVector<Value> result;
2296
- for (auto o : ofrs) {
2297
- if (auto val = llvm::dyn_cast_if_present<Value>(o)) {
2298
- result.push_back (val);
2299
- } else {
2300
- result.push_back (rewriter.create <arith::ConstantIndexOp>(
2301
- loc, getIntFromAttr (o.template get <Attribute>())));
2302
- }
2303
- }
2304
- return result;
2305
- }
2306
-
2307
- // / Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and
2308
- // / InsertSliceOp. For now, only constant padding values are supported.
2309
- // / If there is enough static type information, TransferReadOps and
2310
- // / TransferWriteOps may be generated instead of InsertSliceOps.
2311
- struct GenericPadOpVectorizationPattern : public GeneralizePadOpPattern {
2312
- GenericPadOpVectorizationPattern (MLIRContext *context,
2313
- PatternBenefit benefit = 1 )
2314
- : GeneralizePadOpPattern(context, tryVectorizeCopy, benefit) {}
2315
- // / Vectorize the copying of a tensor::PadOp's source. This is possible if
2316
- // / each dimension size is statically know in the source type or the result
2317
- // / type (or both).
2318
- static LogicalResult tryVectorizeCopy (RewriterBase &rewriter,
2319
- tensor::PadOp padOp, Value dest) {
2320
- auto sourceType = padOp.getSourceType ();
2321
- auto resultType = padOp.getResultType ();
2322
- if (!VectorType::isValidElementType (sourceType.getElementType ()))
2323
- return failure ();
2324
-
2325
- // Copy cannot be vectorized if pad value is non-constant and source shape
2326
- // is dynamic. In case of a dynamic source shape, padding must be appended
2327
- // by TransferReadOp, but TransferReadOp supports only constant padding.
2328
- auto padValue = padOp.getConstantPaddingValue ();
2329
- if (!padValue) {
2330
- if (!sourceType.hasStaticShape ())
2331
- return failure ();
2332
- // Create dummy padding value.
2333
- auto elemType = sourceType.getElementType ();
2334
- padValue = rewriter.create <arith::ConstantOp>(
2335
- padOp.getLoc (), elemType, rewriter.getZeroAttr (elemType));
2336
- }
2337
-
2338
- SmallVector<int64_t > vecShape;
2339
- SmallVector<bool > readInBounds;
2340
- SmallVector<bool > writeInBounds;
2341
- for (unsigned i = 0 ; i < sourceType.getRank (); ++i) {
2342
- if (!sourceType.isDynamicDim (i)) {
2343
- vecShape.push_back (sourceType.getDimSize (i));
2344
- // Source shape is statically known: Neither read nor write are
2345
- // out-of- bounds.
2346
- readInBounds.push_back (true );
2347
- writeInBounds.push_back (true );
2348
- } else if (!resultType.isDynamicDim (i)) {
2349
- // Source shape is not statically known, but result shape is.
2350
- // Vectorize with size of result shape. This may be larger than the
2351
- // source size.
2352
- vecShape.push_back (resultType.getDimSize (i));
2353
- // Read may be out-of-bounds because the result size could be larger
2354
- // than the source size.
2355
- readInBounds.push_back (false );
2356
- // Write is out-of-bounds if low padding > 0.
2357
- writeInBounds.push_back (
2358
- getConstantIntValue (padOp.getMixedLowPad ()[i]) ==
2359
- static_cast <int64_t >(0 ));
2360
- } else {
2361
- // Neither source nor result dim of padOp is static. Cannot vectorize
2362
- // the copy.
2363
- return failure ();
2364
- }
2365
- }
2366
- auto vecType = VectorType::get (vecShape, sourceType.getElementType ());
2367
-
2368
- // Generate TransferReadOp.
2369
- SmallVector<Value> readIndices (
2370
- vecType.getRank (),
2371
- rewriter.create <arith::ConstantIndexOp>(padOp.getLoc (), 0 ));
2372
- auto read = rewriter.create <vector::TransferReadOp>(
2373
- padOp.getLoc (), vecType, padOp.getSource (), readIndices, padValue,
2374
- ArrayRef<bool >{readInBounds});
2375
-
2376
- // If `dest` is a FillOp and the TransferWriteOp would overwrite the
2377
- // entire tensor, write directly to the FillOp's operand.
2378
- if (llvm::equal (vecShape, resultType.getShape ()) &&
2379
- llvm::all_of (writeInBounds, [](bool b) { return b; }))
2380
- if (auto fill = dest.getDefiningOp <FillOp>())
2381
- dest = fill.output ();
2382
-
2383
- // Generate TransferWriteOp.
2384
- auto writeIndices =
2385
- ofrToIndexValues (rewriter, padOp.getLoc (), padOp.getMixedLowPad ());
2386
- rewriter.replaceOpWithNewOp <vector::TransferWriteOp>(
2387
- padOp, read, dest, writeIndices, ArrayRef<bool >{writeInBounds});
2388
-
2389
- return success ();
2390
- }
2391
- };
2392
-
2393
2284
// / Base pattern for rewriting tensor::PadOps whose result is consumed by a
2394
2285
// / given operation type OpTy.
2395
2286
template <typename OpTy>
@@ -2623,6 +2514,163 @@ struct PadOpVectorizationWithTransferWritePattern
2623
2514
}
2624
2515
};
2625
2516
2517
+ // / Returns the effective Pad value for the input op, provided it's a scalar.
2518
+ // /
2519
+ // / Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
2520
+ // / this Op performs padding, retrieve the padding value provided that it's
2521
+ // / a scalar and static/fixed for all the padded values. Returns an empty value
2522
+ // / otherwise.
2523
+ static Value getStaticPadVal (Operation *op) {
2524
+ if (!op)
2525
+ return {};
2526
+
2527
+ // 1. vector.broadcast (f32 -> vector <...xf32>) - return the value that's
2528
+ // being broadcast, provided that it's a scalar.
2529
+ if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
2530
+ auto source = bcast.getSource ();
2531
+ if (llvm::dyn_cast<VectorType>(source.getType ()))
2532
+ return {};
2533
+
2534
+ return source;
2535
+ }
2536
+
2537
+ // 2. linalg.fill - use the scalar input value that used to fill the output
2538
+ // tensor.
2539
+ if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
2540
+ return fill.getInputs ()[0 ];
2541
+ }
2542
+
2543
+ // 3. tensor.generateOp - can't guarantee the value is fixed without
2544
+ // analysing, bail out.
2545
+ if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
2546
+ return {};
2547
+ }
2548
+
2549
+ // 4. vector.transfer_write - inspect the input vector that's written from. If
2550
+ // if contains a single value that has been broadcast (e.g. via
2551
+ // vector.broadcast), extract it, fail otherwise.
2552
+ if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
2553
+ return getStaticPadVal (xferWrite.getVector ().getDefiningOp ());
2554
+
2555
+ // 5. tensor.insert_slice - inspect the destination tensor. If it's larger
2556
+ // than the input tensor, then, provided it's constant, we'll extract the
2557
+ // value that was used to generate it (via e.g. linalg.fill), fail otherwise.
2558
+ // TODO: Clarify the semantics when the input tensor is larger than the
2559
+ // destination.
2560
+ if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
2561
+ return getStaticPadVal (slice.getDest ().getDefiningOp ());
2562
+
2563
+ return {};
2564
+ }
2565
+
2566
+ // / Rewrite tensor.insert.slice as a vector.transfer_read +
2567
+ // / vector.transfer_write pair. The vector size is inferred from the static
2568
+ // / dims in the input and output tensors. If a dim is dynamic in both the input
2569
+ // / and output tensors, bails out.
2570
+ // /
2571
+ // / Before:
2572
+ // / !t_in_type = tensor<1x2x3xf32>
2573
+ // / !t_out_type = tensor<9x8x7x1x2x3xf32>
2574
+ // / !v_type = vector<1x2x3xf32>
2575
+ // / %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
2576
+ // / into !t_out_type
2577
+ // / After:
2578
+ // / %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
2579
+ // / %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
2580
+ // /
2581
+ // / TODO: Support masking
2582
+ struct InsertSliceVectorizePattern
2583
+ : public OpRewritePattern<tensor::InsertSliceOp> {
2584
+ using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
2585
+
2586
+ LogicalResult matchAndRewrite (tensor::InsertSliceOp sliceOp,
2587
+ PatternRewriter &rewriter) const final {
2588
+ auto sourceType = sliceOp.getSource ().getType ();
2589
+ if (!VectorType::isValidElementType (sourceType.getElementType ()))
2590
+ return failure ();
2591
+
2592
+ auto resultType = sliceOp.getResultType ();
2593
+
2594
+ // 1. Get the pad value.
2595
+ // TransferReadOp requires a scalar padding value. Note that:
2596
+ // * for in-bounds access, the value is actually irrelevant.
2597
+ // There are 2 cases in which xfer.read accesses are known to be in-bounds:
2598
+ // 1. The source shape is static (output vector sizes would be based on
2599
+ // the source shape and hence all memory accesses would be in-bounds),
2600
+ // 2. Masking is used (output vector sizes would be user-provided, in which
2601
+ // case it is assumed that all memory accesses are in-bounds). This
2602
+ // remains a TODO.
2603
+ //
2604
+ // When the value is not known and not needed, use 0. Otherwise, bail out.
2605
+ Value padValue = getStaticPadVal (sliceOp);
2606
+ bool isOutOfBoundsRead = !sourceType.hasStaticShape ();
2607
+
2608
+ if (!padValue && isOutOfBoundsRead) {
2609
+ LDBG (" Failed to get a pad value for out-of-bounds read access\n " );
2610
+ return failure ();
2611
+ }
2612
+
2613
+ if (!padValue) {
2614
+ auto elemType = sourceType.getElementType ();
2615
+ padValue = rewriter.create <arith::ConstantOp>(
2616
+ sliceOp.getLoc (), elemType, rewriter.getZeroAttr (elemType));
2617
+ }
2618
+
2619
+ // 2. Get the vector shape and in-bounds attributes
2620
+ SmallVector<int64_t > vecShape;
2621
+ SmallVector<bool > readInBounds;
2622
+ SmallVector<bool > writeInBounds;
2623
+ size_t rankDiff = resultType.getRank () - sourceType.getRank ();
2624
+ for (unsigned i = 0 ; i < sourceType.getRank (); ++i) {
2625
+ if (!sourceType.isDynamicDim (i)) {
2626
+ vecShape.push_back (sourceType.getDimSize (i));
2627
+ // Source shape is statically known: Neither read nor write are
2628
+ // out-of-bounds.
2629
+ readInBounds.push_back (true );
2630
+ writeInBounds.push_back (true );
2631
+ } else if (!resultType.isDynamicDim (i)) {
2632
+ // Source shape is not statically known, but result shape is.
2633
+ // Vectorize with size of result shape. This may be larger than the
2634
+ // source size.
2635
+ // FIXME: Using rankDiff implies that the source tensor is inserted at
2636
+ // the end of the destination tensor. However, that's not required.
2637
+ vecShape.push_back (resultType.getDimSize (rankDiff + i));
2638
+ // Read may be out-of-bounds because the result size could be larger
2639
+ // than the source size.
2640
+ readInBounds.push_back (false );
2641
+ // Write will in-bounds provided that the corresponding write idx is 0.
2642
+ // To keep this logic simple, conservatively mark as out-of-bounds.
2643
+ writeInBounds.push_back (false );
2644
+ } else {
2645
+ // Neither source nor result dim of padOp is static. Cannot vectorize
2646
+ // the copy.
2647
+ // TODO: Add support for masking
2648
+ return failure ();
2649
+ }
2650
+ }
2651
+ auto vecType = VectorType::get (vecShape, sourceType.getElementType ());
2652
+
2653
+ // 3. Generate TransferReadOp.
2654
+ SmallVector<Value> readIndices (
2655
+ vecType.getRank (),
2656
+ rewriter.create <arith::ConstantIndexOp>(sliceOp.getLoc (), 0 ));
2657
+ auto read = rewriter.create <vector::TransferReadOp>(
2658
+ sliceOp.getLoc (), vecType, sliceOp.getSource (), readIndices, padValue,
2659
+ ArrayRef<bool >{readInBounds});
2660
+
2661
+ // 4. Generate TransferWriteOp.
2662
+ auto writeIndices = getValueOrCreateConstantIndexOp (
2663
+ rewriter, sliceOp.getLoc (), sliceOp.getMixedOffsets ());
2664
+
2665
+ // 5. Finalize
2666
+ rewriter.replaceOpWithNewOp <vector::TransferWriteOp>(
2667
+ sliceOp, read, sliceOp.getDest (), writeIndices,
2668
+ ArrayRef<bool >{writeInBounds});
2669
+
2670
+ return success ();
2671
+ }
2672
+ };
2673
+
2626
2674
// / Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
2627
2675
// / ```
2628
2676
// / %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
@@ -2699,8 +2747,8 @@ struct PadOpVectorizationWithInsertSlicePattern
2699
2747
// Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
2700
2748
// specified offsets. Write is fully in-bounds because an InsertSliceOp's
2701
2749
// source must fit into the destination at the specified offsets.
2702
- auto writeIndices =
2703
- ofrToIndexValues ( rewriter, padOp.getLoc (), insertOp.getMixedOffsets ());
2750
+ auto writeIndices = getValueOrCreateConstantIndexOp (
2751
+ rewriter, padOp.getLoc (), insertOp.getMixedOffsets ());
2704
2752
SmallVector<bool > inBounds (vecRank, true );
2705
2753
rewriter.replaceOpWithNewOp <vector::TransferWriteOp>(
2706
2754
insertOp, read, insertOp.getDest (), writeIndices,
@@ -2710,13 +2758,18 @@ struct PadOpVectorizationWithInsertSlicePattern
2710
2758
}
2711
2759
};
2712
2760
2761
+ void mlir::linalg::populateInsertSliceVectorizationPatterns (
2762
+ RewritePatternSet &patterns) {
2763
+ patterns.add <InsertSliceVectorizePattern>(patterns.getContext ());
2764
+ }
2765
+
2713
2766
void mlir::linalg::populatePadOpVectorizationPatterns (
2714
2767
RewritePatternSet &patterns, PatternBenefit baseBenefit) {
2715
2768
// TODO: The following pattern implements "decomposition" and
2716
2769
// optional "vectorization". Separate "decomposition" into a separate
2717
2770
// pre-processing pattern group.
2718
- patterns.add <GenericPadOpVectorizationPattern >(patterns.getContext (),
2719
- baseBenefit);
2771
+ patterns.add <GeneralizePadOpPattern >(patterns.getContext (), baseBenefit);
2772
+
2720
2773
// Try these specialized patterns first before resorting to the generic one.
2721
2774
patterns.add <PadOpVectorizationWithTransferReadPattern,
2722
2775
PadOpVectorizationWithTransferWritePattern,
0 commit comments