@@ -910,61 +910,96 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
910
910
unsigned numTargetOuts = target.getNumResults ();
911
911
unsigned numSourceOuts = source.getNumResults ();
912
912
913
- OperandRange targetOuts = target.getOutputs ();
914
- OperandRange sourceOuts = source.getOutputs ();
915
-
916
913
// Create fused shared_outs.
917
914
SmallVector<Value> fusedOuts;
918
- fusedOuts.reserve (numTargetOuts + numSourceOuts);
919
- fusedOuts.append (targetOuts.begin (), targetOuts.end ());
920
- fusedOuts.append (sourceOuts.begin (), sourceOuts.end ());
915
+ llvm::append_range (fusedOuts, target.getOutputs ());
916
+ llvm::append_range (fusedOuts, source.getOutputs ());
921
917
922
- // Create a new scf:: forall op after the source loop.
918
+ // Create a new scf.forall op after the source loop.
923
919
rewriter.setInsertionPointAfter (source);
924
920
scf::ForallOp fusedLoop = rewriter.create <scf::ForallOp>(
925
921
source.getLoc (), source.getMixedLowerBound (), source.getMixedUpperBound (),
926
922
source.getMixedStep (), fusedOuts, source.getMapping ());
927
923
928
924
// Map control operands.
929
- IRMapping fusedMapping ;
930
- fusedMapping .map (target.getInductionVars (), fusedLoop.getInductionVars ());
931
- fusedMapping .map (source.getInductionVars (), fusedLoop.getInductionVars ());
925
+ IRMapping mapping ;
926
+ mapping .map (target.getInductionVars (), fusedLoop.getInductionVars ());
927
+ mapping .map (source.getInductionVars (), fusedLoop.getInductionVars ());
932
928
933
929
// Map shared outs.
934
- fusedMapping.map (target.getRegionIterArgs (),
935
- fusedLoop.getRegionIterArgs ().slice (0 , numTargetOuts));
936
- fusedMapping.map (
937
- source.getRegionIterArgs (),
938
- fusedLoop.getRegionIterArgs ().slice (numTargetOuts, numSourceOuts));
930
+ mapping.map (target.getRegionIterArgs (),
931
+ fusedLoop.getRegionIterArgs ().take_front (numTargetOuts));
932
+ mapping.map (source.getRegionIterArgs (),
933
+ fusedLoop.getRegionIterArgs ().take_back (numSourceOuts));
939
934
940
935
// Append everything except the terminator into the fused operation.
941
936
rewriter.setInsertionPointToStart (fusedLoop.getBody ());
942
937
for (Operation &op : target.getBody ()->without_terminator ())
943
- rewriter.clone (op, fusedMapping );
938
+ rewriter.clone (op, mapping );
944
939
for (Operation &op : source.getBody ()->without_terminator ())
945
- rewriter.clone (op, fusedMapping );
940
+ rewriter.clone (op, mapping );
946
941
947
942
// Fuse the old terminator in_parallel ops into the new one.
948
943
scf::InParallelOp targetTerm = target.getTerminator ();
949
944
scf::InParallelOp sourceTerm = source.getTerminator ();
950
945
scf::InParallelOp fusedTerm = fusedLoop.getTerminator ();
951
-
952
946
rewriter.setInsertionPointToStart (fusedTerm.getBody ());
953
947
for (Operation &op : targetTerm.getYieldingOps ())
954
- rewriter.clone (op, fusedMapping );
948
+ rewriter.clone (op, mapping );
955
949
for (Operation &op : sourceTerm.getYieldingOps ())
956
- rewriter.clone (op, fusedMapping);
957
-
958
- // Replace all uses of the old loops with the fused loop.
959
- rewriter.replaceAllUsesWith (target.getResults (),
960
- fusedLoop.getResults ().slice (0 , numTargetOuts));
961
- rewriter.replaceAllUsesWith (
962
- source.getResults (),
963
- fusedLoop.getResults ().slice (numTargetOuts, numSourceOuts));
964
-
965
- // Erase the old loops.
966
- rewriter.eraseOp (target);
967
- rewriter.eraseOp (source);
950
+ rewriter.clone (op, mapping);
951
+
952
+ // Replace old loops by substituting their uses by results of the fused loop.
953
+ rewriter.replaceOp (target, fusedLoop.getResults ().take_front (numTargetOuts));
954
+ rewriter.replaceOp (source, fusedLoop.getResults ().take_back (numSourceOuts));
955
+
956
+ return fusedLoop;
957
+ }
958
+
959
+ scf::ForOp mlir::fuseIndependentSiblingForLoops (scf::ForOp target,
960
+ scf::ForOp source,
961
+ RewriterBase &rewriter) {
962
+ unsigned numTargetOuts = target.getNumResults ();
963
+ unsigned numSourceOuts = source.getNumResults ();
964
+
965
+ // Create fused init_args, with target's init_args before source's init_args.
966
+ SmallVector<Value> fusedInitArgs;
967
+ llvm::append_range (fusedInitArgs, target.getInitArgs ());
968
+ llvm::append_range (fusedInitArgs, source.getInitArgs ());
969
+
970
+ // Create a new scf.for op after the source loop.
971
+ rewriter.setInsertionPointAfter (source);
972
+ scf::ForOp fusedLoop = rewriter.create <scf::ForOp>(
973
+ source.getLoc (), source.getLowerBound (), source.getUpperBound (),
974
+ source.getStep (), fusedInitArgs);
975
+
976
+ // Map original induction variables and operands to those of the fused loop.
977
+ IRMapping mapping;
978
+ mapping.map (target.getInductionVar (), fusedLoop.getInductionVar ());
979
+ mapping.map (target.getRegionIterArgs (),
980
+ fusedLoop.getRegionIterArgs ().take_front (numTargetOuts));
981
+ mapping.map (source.getInductionVar (), fusedLoop.getInductionVar ());
982
+ mapping.map (source.getRegionIterArgs (),
983
+ fusedLoop.getRegionIterArgs ().take_back (numSourceOuts));
984
+
985
+ // Merge target's body into the new (fused) for loop and then source's body.
986
+ rewriter.setInsertionPointToStart (fusedLoop.getBody ());
987
+ for (Operation &op : target.getBody ()->without_terminator ())
988
+ rewriter.clone (op, mapping);
989
+ for (Operation &op : source.getBody ()->without_terminator ())
990
+ rewriter.clone (op, mapping);
991
+
992
+ // Build fused yield results by appropriately mapping original yield operands.
993
+ SmallVector<Value> yieldResults;
994
+ for (Value operand : target.getBody ()->getTerminator ()->getOperands ())
995
+ yieldResults.push_back (mapping.lookupOrDefault (operand));
996
+ for (Value operand : source.getBody ()->getTerminator ()->getOperands ())
997
+ yieldResults.push_back (mapping.lookupOrDefault (operand));
998
+ rewriter.create <scf::YieldOp>(source.getLoc (), yieldResults);
999
+
1000
+ // Replace old loops by substituting their uses by results of the fused loop.
1001
+ rewriter.replaceOp (target, fusedLoop.getResults ().take_front (numTargetOuts));
1002
+ rewriter.replaceOp (source, fusedLoop.getResults ().take_back (numSourceOuts));
968
1003
969
1004
return fusedLoop;
970
1005
}
0 commit comments