@@ -122,6 +122,24 @@ static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
122
122
b, loc, minMap, SmallVector<OpFoldResult>{iv, tileSize, size});
123
123
}
124
124
125
+ // / Clones the operation and updates the destination if the operation
126
+ // / implements the `DestinationStyleOpInterface`.
127
+ static Operation *cloneOpAndUpdateDestinationArgs (RewriterBase &rewriter,
128
+ Operation *op,
129
+ ValueRange newDestArgs) {
130
+ Operation *clonedOp = rewriter.clone (*op);
131
+ if (auto destinationStyleOp =
132
+ dyn_cast<DestinationStyleOpInterface>(clonedOp)) {
133
+ // Note that this is assuming that
134
+ auto [start, end] = destinationStyleOp.getDpsInitsPositionRange ();
135
+ assert ((end - start == newDestArgs.size ()) &&
136
+ " expected as many new destination args as number of inits of the "
137
+ " operation" );
138
+ clonedOp->setOperands (start, end - start, newDestArgs);
139
+ }
140
+ return clonedOp;
141
+ }
142
+
125
143
// / Generate an empty loop nest that represents the tiled loop nest shell.
126
144
// / - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
127
145
// / - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
@@ -728,6 +746,121 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
728
746
getAsOperations (forLoops), replacements};
729
747
}
730
748
749
+ // ===----------------------------------------------------------------------===//
750
+ // tileUsingSCFForallOp implementation.
751
+ // ===----------------------------------------------------------------------===//
752
+
753
+ FailureOr<scf::SCFTilingResult>
754
+ mlir::scf::tileUsingSCFForallOp (RewriterBase &rewriter, TilingInterface op,
755
+ const scf::SCFTilingOptions &options) {
756
+ Location loc = op->getLoc ();
757
+ OpBuilder::InsertionGuard g (rewriter);
758
+
759
+ // 1. Get the range of loops that are represented by the operation.
760
+ SmallVector<Range> loopRanges = op.getIterationDomain (rewriter);
761
+ if (loopRanges.empty ())
762
+ return op->emitOpError (" expected non-empty loop ranges" );
763
+ auto hasStrideOne = [](Range r) { return !isConstantIntValue (r.stride , 1 ); };
764
+ if (llvm::any_of (loopRanges, hasStrideOne))
765
+ return op->emitOpError (" only stride-1 supported atm" );
766
+
767
+ // 2. Get the tile sizes. If tile size is 0, it is not tiled and distributed.
768
+ // To make it easier, pad the tile sizes to loopRanges.size with value 0.
769
+ SmallVector<OpFoldResult> tileSizeVector =
770
+ options.tileSizeComputationFunction (rewriter, op);
771
+ tileSizeVector.resize (loopRanges.size (), rewriter.getIndexAttr (0 ));
772
+
773
+ // 3. Build the offsets, sizes and steps for the tile and distributed loops.
774
+ SmallVector<OpFoldResult> lbs, ubs, steps;
775
+ for (auto [index, tileSize, loopRange] :
776
+ llvm::enumerate (tileSizeVector, loopRanges)) {
777
+ if (isConstantIntValue (tileSize, 0 ))
778
+ continue ;
779
+ lbs.push_back (loopRange.offset );
780
+ ubs.push_back (loopRange.size );
781
+ steps.push_back (tileSize);
782
+ }
783
+
784
+ // 4. Gather destination tensors.
785
+ SmallVector<Value> dest;
786
+ if (failed (tensor::getOrCreateDestinations (rewriter, loc, op, dest)))
787
+ return op->emitOpError (" failed to get destination tensors" );
788
+
789
+ // 5. Build the device mapping attribute;
790
+ std::optional<ArrayAttr> mappingAttr;
791
+ if (!options.mappingVector .empty ()) {
792
+ mappingAttr = rewriter.getArrayAttr (ArrayRef (options.mappingVector ));
793
+ }
794
+
795
+ // 6. Create the ForallOp. We don't use the lambda body-builder
796
+ // version because we require the use of RewriterBase in the body, so we
797
+ // manually move the insertion point to the body below.
798
+ auto forallOp =
799
+ rewriter.create <scf::ForallOp>(loc, lbs, ubs, steps, dest, mappingAttr);
800
+
801
+ // 7. Get the tile offset and sizes.
802
+ rewriter.setInsertionPoint (forallOp.getTerminator ());
803
+ SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
804
+ tiledOffsets.reserve (loopRanges.size ());
805
+ tiledSizes.reserve (loopRanges.size ());
806
+ ValueRange ivs = forallOp.getInductionVars ();
807
+ {
808
+ int materializedLoopNum = 0 ;
809
+ for (auto [index, tileSize, loopRange] :
810
+ llvm::enumerate (tileSizeVector, loopRanges)) {
811
+ if (isConstantIntValue (tileSize, 0 )) {
812
+ tiledOffsets.push_back (loopRange.offset );
813
+ tiledSizes.push_back (loopRange.size );
814
+ continue ;
815
+ }
816
+ Value iv = ivs[materializedLoopNum++];
817
+ tiledOffsets.push_back (iv);
818
+ tiledSizes.push_back (
819
+ getBoundedTileSize (rewriter, loc, loopRange, iv, tileSize));
820
+ }
821
+ }
822
+
823
+ // 8. Tile the operation. Clone the operation to allow fix up of destination
824
+ // operands
825
+ ArrayRef<BlockArgument> destBbArgs = forallOp.getOutputBlockArguments ();
826
+ Operation *clonedOp =
827
+ cloneOpAndUpdateDestinationArgs (rewriter, op, destBbArgs);
828
+ FailureOr<TilingResult> tilingResult =
829
+ cast<TilingInterface>(clonedOp).getTiledImplementation (
830
+ rewriter, tiledOffsets, tiledSizes);
831
+ if (failed (tilingResult))
832
+ return clonedOp->emitError (" Failed to tile op: " );
833
+ rewriter.eraseOp (clonedOp);
834
+
835
+ // 9. Parallel insert back into the result tensor.
836
+ for (auto [index, tiledValue, destBBArg] :
837
+ llvm::enumerate (tilingResult->tiledValues , destBbArgs)) {
838
+ // 9.a. Partial subset information is inserted just before the terminator.
839
+ rewriter.setInsertionPoint (forallOp.getTerminator ());
840
+
841
+ SmallVector<OpFoldResult> resultOffsets, resultSizes;
842
+ if (failed (op.getResultTilePosition (rewriter, index, tiledOffsets,
843
+ tiledSizes, resultOffsets,
844
+ resultSizes)))
845
+ return op->emitOpError (" output offsets couldn't be calculated" );
846
+ SmallVector<OpFoldResult> strides (resultSizes.size (),
847
+ rewriter.getIndexAttr (1 ));
848
+
849
+ // 5.b. Parallel insertions are inserted at the end of the combining
850
+ // terminator.
851
+ rewriter.setInsertionPointToEnd (forallOp.getTerminator ().getBody ());
852
+ rewriter.create <tensor::ParallelInsertSliceOp>(
853
+ loc, tiledValue, destBBArg, resultOffsets, resultSizes, strides);
854
+ }
855
+
856
+ // 10. Return the tiling result;
857
+ return scf::SCFTilingResult{
858
+ tilingResult->tiledOps ,
859
+ {forallOp.getOperation ()},
860
+ llvm::to_vector (llvm::map_range (forallOp.getResults (),
861
+ [](auto val) -> Value { return val; }))};
862
+ }
863
+
731
864
// ===----------------------------------------------------------------------===//
732
865
// lowerToLoopsUsingSCFForOp implementation.
733
866
// ===----------------------------------------------------------------------===//
0 commit comments