|
9 | 9 | #include "detail/config.hpp"
|
10 | 10 | #include <detail/context_impl.hpp>
|
11 | 11 | #include <detail/event_impl.hpp>
|
| 12 | +#include <sstream> |
12 | 13 | #include <sycl/feature_test.hpp>
|
13 | 14 | #if SYCL_EXT_CODEPLAY_KERNEL_FUSION
|
14 | 15 | #include <detail/jit_compiler.hpp>
|
@@ -949,66 +950,75 @@ Scheduler::GraphBuildResult Scheduler::GraphBuilder::addCG(
|
949 | 950 | if (!NewCmd)
|
950 | 951 | throw runtime_error("Out of host memory", PI_ERROR_OUT_OF_HOST_MEMORY);
|
951 | 952 |
|
952 |
| - // Host tasks cannot participate in fusion. They take the regular route. If |
953 |
| - // they create any requirement or event dependency on any of the kernels in |
954 |
| - // the fusion list, this will lead to cancellation of the fusion in the |
955 |
| - // GraphProcessor. |
| 953 | + // Only device kernel command groups can participate in fusion. Otherwise, |
| 954 | + // command groups take the regular route. If they create any requirement or |
| 955 | + // event dependency on any of the kernels in the fusion list, this will lead |
| 956 | + // to cancellation of the fusion in the GraphProcessor. |
956 | 957 | auto QUniqueID = std::hash<sycl::detail::queue_impl *>()(Queue.get());
|
957 |
| - if (isInFusionMode(QUniqueID) && !NewCmd->isHostTask()) { |
958 |
| - auto *FusionCmd = findFusionList(QUniqueID)->second.get(); |
959 |
| - |
960 |
| - bool dependsOnFusion = false; |
961 |
| - for (auto Ev = Events.begin(); Ev != Events.end();) { |
962 |
| - auto *EvDepCmd = static_cast<Command *>((*Ev)->getCommand()); |
963 |
| - if (!EvDepCmd) { |
964 |
| - continue; |
965 |
| - } |
966 |
| - // Handle event dependencies on any commands part of another active |
967 |
| - // fusion. |
968 |
| - if (EvDepCmd->getQueue() != Queue && isPartOfActiveFusion(EvDepCmd)) { |
969 |
| - printFusionWarning("Aborting fusion because of event dependency from a " |
970 |
| - "different fusion"); |
971 |
| - cancelFusion(EvDepCmd->getQueue(), ToEnqueue); |
972 |
| - } |
973 |
| - // Check if this command depends on the placeholder command for the fusion |
974 |
| - // itself participates in. |
975 |
| - if (EvDepCmd == FusionCmd) { |
976 |
| - Ev = Events.erase(Ev); |
977 |
| - dependsOnFusion = true; |
978 |
| - } else { |
979 |
| - ++Ev; |
| 958 | + if (isInFusionMode(QUniqueID)) { |
| 959 | + if (NewCmd->isFusable()) { |
| 960 | + auto *FusionCmd = findFusionList(QUniqueID)->second.get(); |
| 961 | + |
| 962 | + bool dependsOnFusion = false; |
| 963 | + for (auto Ev = Events.begin(); Ev != Events.end();) { |
| 964 | + auto *EvDepCmd = static_cast<Command *>((*Ev)->getCommand()); |
| 965 | + if (!EvDepCmd) { |
| 966 | + continue; |
| 967 | + } |
| 968 | + // Handle event dependencies on any commands part of another active |
| 969 | + // fusion. |
| 970 | + if (EvDepCmd->getQueue() != Queue && isPartOfActiveFusion(EvDepCmd)) { |
| 971 | + printFusionWarning( |
| 972 | + "Aborting fusion because of event dependency from a " |
| 973 | + "different fusion"); |
| 974 | + cancelFusion(EvDepCmd->getQueue(), ToEnqueue); |
| 975 | + } |
| 976 | + // Check if this command depends on the placeholder command for the |
| 977 | + // fusion it itself participates in. |
| 978 | + if (EvDepCmd == FusionCmd) { |
| 979 | + Ev = Events.erase(Ev); |
| 980 | + dependsOnFusion = true; |
| 981 | + } else { |
| 982 | + ++Ev; |
| 983 | + } |
980 | 984 | }
|
981 |
| - } |
982 | 985 |
|
983 |
| - // If this command has an explicit event dependency on the placeholder |
984 |
| - // command for this fusion (because it used depends_on on the event returned |
985 |
| - // by submitting another kernel to this fusion earlier), add a dependency on |
986 |
| - // all the commands in the fusion list so far. |
987 |
| - if (dependsOnFusion) { |
988 |
| - for (auto *Cmd : FusionCmd->getFusionList()) { |
989 |
| - Events.push_back(Cmd->getEvent()); |
| 986 | + // If this command has an explicit event dependency on the placeholder |
| 987 | + // command for this fusion (because it used depends_on on the event |
| 988 | + // returned by submitting another kernel to this fusion earlier), add a |
| 989 | + // dependency on all the commands in the fusion list so far. |
| 990 | + if (dependsOnFusion) { |
| 991 | + for (auto *Cmd : FusionCmd->getFusionList()) { |
| 992 | + Events.push_back(Cmd->getEvent()); |
| 993 | + } |
990 | 994 | }
|
991 |
| - } |
992 | 995 |
|
993 |
| - // Add the kernel to the graph, but delay the enqueue of any auxiliary |
994 |
| - // commands (e.g., allocations) resulting from that process by adding them |
995 |
| - // to the list of auxiliary commands of the fusion command. |
996 |
| - createGraphForCommand(NewCmd.get(), NewCmd->getCG(), |
997 |
| - isInteropHostTask(NewCmd.get()), Reqs, Events, Queue, |
998 |
| - FusionCmd->auxiliaryCommands()); |
999 |
| - |
1000 |
| - // Set the fusion command, so we recognize when another command depends on a |
1001 |
| - // kernel in the fusion list. |
1002 |
| - FusionCmd->addToFusionList(NewCmd.get()); |
1003 |
| - NewCmd->MFusionCmd = FusionCmd; |
1004 |
| - std::vector<Command *> ToCleanUp; |
1005 |
| - // Add an event dependency from the fusion placeholder command to the new |
1006 |
| - // kernel. |
1007 |
| - auto ConnectionCmd = FusionCmd->addDep(NewCmd->getEvent(), ToCleanUp); |
1008 |
| - if (ConnectionCmd) { |
1009 |
| - FusionCmd->auxiliaryCommands().push_back(ConnectionCmd); |
| 996 | + // Add the kernel to the graph, but delay the enqueue of any auxiliary |
| 997 | + // commands (e.g., allocations) resulting from that process by adding them |
| 998 | + // to the list of auxiliary commands of the fusion command. |
| 999 | + createGraphForCommand(NewCmd.get(), NewCmd->getCG(), |
| 1000 | + isInteropHostTask(NewCmd.get()), Reqs, Events, |
| 1001 | + Queue, FusionCmd->auxiliaryCommands()); |
| 1002 | + |
| 1003 | + // Set the fusion command, so we recognize when another command depends on |
| 1004 | + // a kernel in the fusion list. |
| 1005 | + FusionCmd->addToFusionList(NewCmd.get()); |
| 1006 | + NewCmd->MFusionCmd = FusionCmd; |
| 1007 | + std::vector<Command *> ToCleanUp; |
| 1008 | + // Add an event dependency from the fusion placeholder command to the new |
| 1009 | + // kernel. |
| 1010 | + auto ConnectionCmd = FusionCmd->addDep(NewCmd->getEvent(), ToCleanUp); |
| 1011 | + if (ConnectionCmd) { |
| 1012 | + FusionCmd->auxiliaryCommands().push_back(ConnectionCmd); |
| 1013 | + } |
| 1014 | + return {NewCmd.release(), FusionCmd->getEvent(), false}; |
| 1015 | + } else { |
| 1016 | + std::string s; |
| 1017 | + std::stringstream ss(s); |
| 1018 | + ss << "Not fusing '" << NewCmd->getTypeString() |
| 1019 | + << "' command group. Can only fuse device kernel command groups."; |
| 1020 | + printFusionWarning(ss.str()); |
1010 | 1021 | }
|
1011 |
| - return {NewCmd.release(), FusionCmd->getEvent(), false}; |
1012 | 1022 | }
|
1013 | 1023 | createGraphForCommand(NewCmd.get(), NewCmd->getCG(),
|
1014 | 1024 | isInteropHostTask(NewCmd.get()), Reqs, Events, Queue,
|
|
0 commit comments