File tree Expand file tree Collapse file tree 1 file changed +28
-1
lines changed Expand file tree Collapse file tree 1 file changed +28
-1
lines changed Original file line number Diff line number Diff line change @@ -3810,11 +3810,38 @@ static void ggml_backend_sycl_graph_compute_impl(ggml_backend_sycl_context * syc
3810
3810
}
3811
3811
}
3812
3812
3813
+ #ifdef GGML_SYCL_GRAPH
3814
+ static bool check_node_graph_compatibility (ggml_cgraph * cgraph) {
3815
+ for (int i = 0 ; i < cgraph->n_nodes ; i++) {
3816
+ ggml_tensor * node = cgraph->nodes [i];
3817
+ switch (node->op ) {
3818
+ default :
3819
+ break ;
3820
+ case GGML_OP_CONCAT:
3821
+ // ggml_sycl_op_concat() does a blocking host wait after memcpy operations,
3822
+ // but wait() can't be called on the events returned by a queue recording
3823
+ // to a graph.
3824
+ [[fallthrough]];
3825
+ case GGML_OP_MUL_MAT_ID:
3826
+ // ggml_sycl_mul_mat_id() does a blocking host wait on the sycl queue after
3827
+ // submitting a memcpy operation, but wait() can't be called on a queue that
3828
+ // is recording to a graph.
3829
+ # ifndef NDEBUG
3830
+ GGML_LOG_DEBUG (" %s: disabling SYCL graphs due to unsupported node type\n " , __func__);
3831
+ # endif
3832
+ return false ;
3833
+ }
3834
+ }
3835
+ return true ;
3836
+ }
3837
+ #endif
3838
+
3813
3839
static ggml_status ggml_backend_sycl_graph_compute (ggml_backend_t backend, ggml_cgraph * cgraph) {
3814
3840
auto * sycl_ctx = static_cast <ggml_backend_sycl_context *>(backend->context );
3815
3841
3816
3842
#ifdef GGML_SYCL_GRAPH
3817
- if (!g_ggml_sycl_disable_graph) {
3843
+ bool use_sycl_graph = !g_ggml_sycl_disable_graph && check_node_graph_compatibility (cgraph);
3844
+ if (use_sycl_graph) {
3818
3845
const bool graph_support = dpct::get_device (sycl_ctx->device ).has (sycl::aspect::ext_oneapi_limited_graph);
3819
3846
if (!graph_support) {
3820
3847
GGML_SYCL_DEBUG (" [SYCL-GRAPH] can not use graphs on device:%d\n " , sycl_ctx->device );
You can’t perform that action at this time.
0 commit comments