@@ -767,8 +767,7 @@ mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
767
767
768
768
// 3. Build the offsets, sizes and steps for the tile and distributed loops.
769
769
SmallVector<OpFoldResult> lbs, ubs, steps;
770
- for (auto [index, tileSize, loopRange] :
771
- llvm::enumerate (tileSizeVector, loopRanges)) {
770
+ for (auto [tileSize, loopRange] : llvm::zip (tileSizeVector, loopRanges)) {
772
771
if (isConstantIntValue (tileSize, 0 ))
773
772
continue ;
774
773
lbs.push_back (loopRange.offset );
@@ -781,7 +780,7 @@ mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
781
780
if (failed (tensor::getOrCreateDestinations (rewriter, loc, op, dest)))
782
781
return op->emitOpError (" failed to get destination tensors" );
783
782
784
- // 5. Build the device mapping attribute;
783
+ // 5. Build the device mapping attribute.
785
784
std::optional<ArrayAttr> mappingAttr;
786
785
if (!options.mappingVector .empty ()) {
787
786
mappingAttr = rewriter.getArrayAttr (ArrayRef (options.mappingVector ));
@@ -796,13 +795,10 @@ mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
796
795
// 7. Get the tile offset and sizes.
797
796
rewriter.setInsertionPoint (forallOp.getTerminator ());
798
797
SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
799
- tiledOffsets.reserve (loopRanges.size ());
800
- tiledSizes.reserve (loopRanges.size ());
801
798
ValueRange ivs = forallOp.getInductionVars ();
802
799
{
803
800
int materializedLoopNum = 0 ;
804
- for (auto [index, tileSize, loopRange] :
805
- llvm::enumerate (tileSizeVector, loopRanges)) {
801
+ for (auto [tileSize, loopRange] : llvm::zip (tileSizeVector, loopRanges)) {
806
802
if (isConstantIntValue (tileSize, 0 )) {
807
803
tiledOffsets.push_back (loopRange.offset );
808
804
tiledSizes.push_back (loopRange.size );
@@ -816,15 +812,15 @@ mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
816
812
}
817
813
818
814
// 8. Tile the operation. Clone the operation to allow fix up of destination
819
- // operands
815
+ // operands.
820
816
ArrayRef<BlockArgument> destBbArgs = forallOp.getOutputBlockArguments ();
821
817
Operation *clonedOp =
822
818
cloneOpAndUpdateDestinationArgs (rewriter, op, destBbArgs);
823
819
FailureOr<TilingResult> tilingResult =
824
820
cast<TilingInterface>(clonedOp).getTiledImplementation (
825
821
rewriter, tiledOffsets, tiledSizes);
826
822
if (failed (tilingResult))
827
- return clonedOp->emitError (" Failed to tile op: " );
823
+ return clonedOp->emitError (" failed to tile op: " );
828
824
rewriter.eraseOp (clonedOp);
829
825
830
826
// 9. Parallel insert back into the result tensor.
@@ -836,24 +832,25 @@ mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
836
832
SmallVector<OpFoldResult> resultOffsets, resultSizes;
837
833
if (failed (op.getResultTilePosition (rewriter, index, tiledOffsets,
838
834
tiledSizes, resultOffsets,
839
- resultSizes)))
835
+ resultSizes))) {
840
836
return op->emitOpError (" output offsets couldn't be calculated" );
837
+ }
838
+
841
839
SmallVector<OpFoldResult> strides (resultSizes.size (),
842
840
rewriter.getIndexAttr (1 ));
843
-
844
- // 5.b. Parallel insertions are inserted at the end of the combining
841
+ // 9.b. Parallel insertions are inserted at the end of the combining
845
842
// terminator.
846
843
rewriter.setInsertionPointToEnd (forallOp.getTerminator ().getBody ());
847
844
rewriter.create <tensor::ParallelInsertSliceOp>(
848
845
loc, tiledValue, destBBArg, resultOffsets, resultSizes, strides);
849
846
}
850
847
851
- // 10. Return the tiling result;
848
+ // 10. Return the tiling result.
852
849
return scf::SCFTilingResult{
853
850
tilingResult->tiledOps ,
854
851
{forallOp.getOperation ()},
855
- llvm::to_vector ( llvm::map_range (forallOp.getResults (),
856
- [](auto val) -> Value { return val; }) )};
852
+ llvm::map_to_vector (forallOp.getResults (),
853
+ [](auto val) -> Value { return val; })};
857
854
}
858
855
859
856
// ===----------------------------------------------------------------------===//
0 commit comments