@@ -435,188 +435,6 @@ static void calculateTileOffsetsAndSizes(
   }
 }
 
-/// Returns a vector of bools representing whether, for each axis, `op` can be
-/// tiled without incurring a race condition, i.e., whether it is thread-safe
-/// to tile that axis. This is checked by iterating over `numThreads` and
-/// ensuring that the corresponding iterator type is "parallel". If it is not,
-/// that dimension is unsafe to tile.
-SmallVector<bool> safeToTileToForall(mlir::MLIRContext *ctx, LinalgOp linalgOp,
-                                     ArrayRef<OpFoldResult> numThreads) {
-  auto iterators = linalgOp.getIteratorTypesArray();
-  SmallVector<bool> safeToTile(numThreads.size(), true);
-
-  for (unsigned i = 0, e = numThreads.size(); i != e; i++) {
-    if (auto attr = llvm::dyn_cast_if_present<Attribute>(numThreads[i])) {
-      if (cast<IntegerAttr>(attr).getValue().getSExtValue() > 1) {
-        safeToTile[i] = iterators[i] == utils::IteratorType::parallel;
-      }
-    } else {
-      safeToTile[i] = iterators[i] == utils::IteratorType::parallel;
-    }
-  }
-  return safeToTile;
-}
-
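The removed check reduces to a simple rule: an axis may be tiled across threads only if its iterator type is parallel, while statically-known thread counts of 0 or 1 are trivially safe. A minimal self-contained C++ sketch of that rule, with plain enums and `std::optional` standing in for MLIR's `IteratorType` and `OpFoldResult` (a hypothetical simplification, not the MLIR API):

```cpp
#include <cstdint>
#include <cstdio>
#include <optional>
#include <vector>

enum class Iterator { Parallel, Reduction };

// std::nullopt models a dynamic (non-constant) thread count.
std::vector<bool>
safeToTile(const std::vector<Iterator> &iterators,
           const std::vector<std::optional<int64_t>> &numThreads) {
  std::vector<bool> safe(numThreads.size(), true);
  for (size_t i = 0; i < numThreads.size(); ++i) {
    // A static thread count of 0 or 1 cannot race; a larger or dynamic
    // count is only safe on a parallel dimension.
    if (!numThreads[i] || *numThreads[i] > 1)
      safe[i] = iterators[i] == Iterator::Parallel;
  }
  return safe;
}

int main() {
  // A matmul-like op: two parallel dims (i, j) and one reduction dim (k).
  auto safe = safeToTile(
      {Iterator::Parallel, Iterator::Parallel, Iterator::Reduction},
      {4, 4, 2});
  for (size_t i = 0; i < safe.size(); ++i)
    std::printf("axis #%zu: %s\n", i, safe[i] ? "safe" : "unsafe");
  // -> axis #0: safe, axis #1: safe, axis #2: unsafe
}
```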
-/// Rewrite a TilingInterface `op` to a tiled `scf.forall`. The tiling is
-/// specified by the number of tiles/threads `numThreads` and the optional
-/// nominal tile size `nominalTileSizes`. If `nominalTileSizes` is not
-/// specified, it is derived from `numThreads` as
-/// `ceilDiv(dimSize[i], numThreads[i])`. If non-empty, the `mapping` is added
-/// as an attribute to the resulting `scf.forall`. A tile size of zero
-/// indicates that the dimension is not tiled, and can be thought of as tiling
-/// by the full size of data.
-/// It is the user's responsibility to ensure that `numThreads` is a valid
-/// tiling specification (i.e., one that only tiles parallel dimensions, e.g.
-/// in the Linalg case). If a dimension is not parallelizable, a warning is
-/// issued to notify the user that the generated code is not safe to
-/// parallelize. If `omitTileOffsetBoundsCheck` is true, the function assumes
-/// that `tileSize[i] * (numThreads[i] - 1) <= dimSize[i]` holds.
-static FailureOr<ForallTilingResult> tileToForallOpImpl(
-    RewriterBase &b, TilingInterface op, ArrayRef<OpFoldResult> numThreads,
-    std::optional<ArrayRef<OpFoldResult>> nominalTileSizes,
-    std::optional<ArrayAttr> mapping, bool omitTileOffsetBoundsCheck) {
-  Location loc = op->getLoc();
-  OpBuilder::InsertionGuard g(b);
-
-  SmallVector<Range> loopRanges = op.getIterationDomain(b);
-  if (loopRanges.empty())
-    return op->emitOpError("expected non-empty loop ranges");
-  auto hasStrideOne = [](Range r) { return !isConstantIntValue(r.stride, 1); };
-  if (llvm::any_of(loopRanges, hasStrideOne))
-    return op->emitOpError("only stride-1 supported atm");
-
-  // Gather destination tensors.
-  SmallVector<Value> dest;
-  if (failed(tensor::getOrCreateDestinations(b, loc, op, dest)))
-    return op->emitOpError("failed to get destination tensors");
-
-  SmallVector<OpFoldResult> nonZeroNumThreads =
-      llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
-        return !isConstantIntValue(ofr, 0);
-      }));
-  SmallVector<Value> materializedNonZeroNumThreads =
-      llvm::to_vector(llvm::map_range(nonZeroNumThreads, [&](OpFoldResult ofr) {
-        return getValueOrCreateConstantIndexOp(b, loc, ofr);
-      }));
-
-  LinalgOp linalgOp = dyn_cast<LinalgOp>(op.getOperation());
-  if (linalgOp) {
-    // Check if tiling is thread safe and print a warning if not.
-    SmallVector<bool> tilingSafety =
-        safeToTileToForall(b.getContext(), linalgOp, numThreads);
-    for (size_t i = 0; i < tilingSafety.size(); i++)
-      if (!tilingSafety[i])
-        op.emitWarning() << "tiling is not thread safe at axis #" << i;
-  }
-
-  // 1. Create the ForallOp. We don't use the lambda body-builder version
-  // because we require the use of RewriterBase in the body, so we manually
-  // move the insertion point to the body below.
-  scf::ForallOp forallOp = b.create<scf::ForallOp>(
-      loc, getAsOpFoldResult(materializedNonZeroNumThreads), dest, mapping);
-
-  // 2. Fill out the ForallOp body.
-  SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
-  calculateTileOffsetsAndSizes(b, loc, forallOp, numThreads, loopRanges,
-                               omitTileOffsetBoundsCheck, nominalTileSizes,
-                               tiledOffsets, tiledSizes);
-
-  // 3. Clone the tileable op and update its destination operands to use the
-  // output bbArgs of the ForallOp.
-  ArrayRef<BlockArgument> destBbArgs = forallOp.getRegionIterArgs();
-  Operation *tiledOp = nullptr;
-  SmallVector<Value> tiledValues;
-  {
-    // 3.a. RAII guard, inserting within forallOp, before terminator.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(forallOp.getTerminator());
-    Operation *clonedOp = b.clone(*op.getOperation());
-    auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp);
-    if (destinationStyleOp) {
-      for (OpOperand &outOperand : destinationStyleOp.getDpsInitsMutable()) {
-        // Swap tensor inits with the corresponding block argument of the
-        // scf.forall op. Memref inits remain as is.
-        if (isa<TensorType>(outOperand.get().getType())) {
-          auto *it = llvm::find(dest, outOperand.get());
-          assert(it != dest.end() && "could not find destination tensor");
-          unsigned destNum = std::distance(dest.begin(), it);
-          outOperand.set(destBbArgs[destNum]);
-        }
-      }
-    }
-
-    // 4. Tile the cloned op and delete the clone.
-    FailureOr<TilingResult> tilingResult =
-        cast<TilingInterface>(clonedOp).getTiledImplementation(b, tiledOffsets,
-                                                               tiledSizes);
-    if (failed(tilingResult))
-      return clonedOp->emitError("Failed to tile op: ");
-    if (tilingResult->tiledOps.size() != 1) {
-      return clonedOp->emitError("expected a single produced tiled op, got ")
-             << tilingResult->tiledOps.size();
-    }
-
-    b.eraseOp(clonedOp);
-    tiledOp = tilingResult->tiledOps.front();
-    tiledValues = tilingResult->tiledValues;
-  }
-
-  // 5. Parallel insert back into the result tensor.
-  for (auto it : llvm::zip(llvm::seq(unsigned(0), unsigned(dest.size())),
-                           tiledValues, destBbArgs)) {
-    // 5.a. Partial subset information is inserted just before the terminator.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(forallOp.getTerminator());
-
-    SmallVector<OpFoldResult> resultOffsets, resultSizes;
-    if (failed(op.getResultTilePosition(b, std::get<0>(it), tiledOffsets,
-                                        tiledSizes, resultOffsets,
-                                        resultSizes)))
-      return op->emitOpError("output offsets couldn't be calculated");
-    SmallVector<OpFoldResult> strides(resultSizes.size(), b.getIndexAttr(1));
-
-    // 5.b. Parallel insertions are inserted at the end of the combining
-    // terminator.
-    b.setInsertionPointToEnd(forallOp.getTerminator().getBody());
-    b.create<tensor::ParallelInsertSliceOp>(loc, std::get<1>(it),
-                                            std::get<2>(it), resultOffsets,
-                                            resultSizes, strides);
-  }
-  return ForallTilingResult{forallOp, tiledOp};
-}
-
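The per-thread offset/size computation that `calculateTileOffsetsAndSizes` materializes as affine ops can be illustrated with static integers. A minimal sketch, assuming a constant `dimSize` and `numThreads` (the real code builds `affine.apply`/`affine.min` over `OpFoldResult`s):

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  const int64_t dimSize = 10, numThreads = 4;
  // Nominal tile size derived as in the doc comment: ceilDiv(10, 4) == 3.
  const int64_t tileSize = ceilDiv(dimSize, numThreads);
  for (int64_t id = 0; id < numThreads; ++id) {
    const int64_t offset = id * tileSize;
    // The clamp to >= 0 is what omitTileOffsetBoundsCheck elides when the
    // caller guarantees tileSize * (numThreads - 1) <= dimSize.
    const int64_t size =
        std::max<int64_t>(0, std::min(tileSize, dimSize - offset));
    std::printf("thread %lld: offset=%lld size=%lld\n", (long long)id,
                (long long)offset, (long long)size);
  }
  // -> sizes 3, 3, 3, 1: the last thread receives the partial tile.
}
```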
-FailureOr<ForallTilingResult>
-linalg::tileToForallOp(RewriterBase &b, TilingInterface op,
-                       ArrayRef<OpFoldResult> numThreads,
-                       std::optional<ArrayAttr> mapping) {
-  return tileToForallOpImpl(b, op, numThreads,
-                            /*nominalTileSizes=*/std::nullopt, mapping,
-                            /*omitTileOffsetBoundsCheck=*/false);
-}
-
-FailureOr<ForallTilingResult>
-linalg::tileToForallOpUsingTileSizes(RewriterBase &b, TilingInterface op,
-                                     ArrayRef<OpFoldResult> tileSizes,
-                                     std::optional<ArrayAttr> mapping) {
-  SmallVector<Range> loopRanges = op.getIterationDomain(b);
-  unsigned nLoops = loopRanges.size();
-  SmallVector<OpFoldResult> numThreads;
-  numThreads.reserve(nLoops);
-  AffineExpr s0, s1;
-  bindSymbols(b.getContext(), s0, s1);
-  AffineExpr divExpr = s0.ceilDiv(s1);
-  for (const auto &it : llvm::zip(tileSizes, loopRanges)) {
-    OpFoldResult numTiles = std::get<0>(it);
-    if (!isConstantIntValue(numTiles, 0))
-      numTiles = makeComposedFoldedAffineApply(
-          b, op.getLoc(), divExpr, {std::get<1>(it).size, std::get<0>(it)});
-    numThreads.push_back(numTiles);
-  }
-  return tileToForallOpImpl(b, op, numThreads,
-                            /*nominalTileSizes=*/tileSizes, mapping,
-                            /*omitTileOffsetBoundsCheck=*/true);
-}
-
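The loop above derives `numThreads[i] = ceilDiv(dimSize[i], tileSizes[i])`, passing a tile size of 0 through to mean "untiled". A self-contained sketch of that arithmetic with plain integers standing in for `OpFoldResult`s; it also shows why this wrapper may pass `omitTileOffsetBoundsCheck=true`: when `numThreads == ceilDiv(dimSize, tileSize)`, the bound `tileSize * (numThreads - 1) <= dimSize` always holds.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Plain-integer stand-in for the folded affine ceildiv expression above.
std::vector<int64_t> deriveNumThreads(const std::vector<int64_t> &tileSizes,
                                      const std::vector<int64_t> &dimSizes) {
  std::vector<int64_t> numThreads;
  for (size_t i = 0; i < tileSizes.size(); ++i) {
    // Tile size 0 means "do not tile this dim" and is passed through.
    const int64_t t = tileSizes[i];
    numThreads.push_back(t == 0 ? 0 : (dimSizes[i] + t - 1) / t);
  }
  return numThreads;
}

int main() {
  // tileSizes {32, 0} over a 100x64 domain -> numThreads {4, 0}: four tiles
  // of 32 cover dim 0 (the last one partial), dim 1 stays untiled.
  // Note 32 * (4 - 1) == 96 <= 100, so the tile-offset bounds check can be
  // safely omitted.
  for (int64_t n : deriveNumThreads({32, 0}, {100, 64}))
    std::printf("%lld\n", (long long)n);
}
```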
 template <typename LoopTy>
 static FailureOr<TiledLinalgOp>
 tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,