@@ -1982,7 +1982,7 @@ inline void ggml_sycl_op_mul_mat_sycl(
1982
1982
1983
1983
const int64_t ne00 = src0->ne [0 ];
1984
1984
const int64_t ne10 = src1->ne [0 ];
1985
-
1985
+ GGML_ASSERT (ne00 == ne10);
1986
1986
1987
1987
const int64_t row_diff = row_high - row_low;
1988
1988
@@ -2727,10 +2727,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
2727
2727
GGML_ASSERT (!ggml_is_transposed (src1));
2728
2728
GGML_ASSERT (!ggml_backend_buffer_is_sycl_split (src0->buffer ));
2729
2729
GGML_ASSERT (src0->type == GGML_TYPE_F16);
2730
+ GGML_ASSERT (dst->type == GGML_TYPE_F32);
2730
2731
2731
2732
GGML_TENSOR_BINARY_OP_LOCALS
2732
2733
2733
-
2734
2734
SYCL_CHECK (ggml_sycl_set_device (ctx.device ));
2735
2735
queue_ptr main_stream = ctx.stream ();;
2736
2736
@@ -2751,39 +2751,45 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
2751
2751
sycl::half *src1_f16 = src1->type == GGML_TYPE_F16 ? (sycl::half *)src1_ddf
2752
2752
: src1_f16_alloc.get ();
2753
2753
2754
- char * dst_t ;
2755
-
2756
- dpct::library_data_t cu_compute_type = dpct::library_data_t ::real_float;
2757
- dpct::library_data_t cu_data_type = dpct::library_data_t ::real_float;
2758
-
2759
- // dst strides
2760
- size_t nbd2 = dst->nb [2 ];
2761
- size_t nbd3 = dst->nb [3 ];
2754
+ const dpct::library_data_t cu_compute_type = dpct::library_data_t ::real_float;
2755
+ const dpct::library_data_t cu_data_type = dpct::library_data_t ::real_float;
2762
2756
2763
2757
const float alpha_f32 = 1 .0f ;
2764
2758
const float beta_f32 = 0 .0f ;
2765
2759
2766
2760
const void * alpha = &alpha_f32;
2767
2761
const void * beta = &beta_f32;
2768
2762
2769
- dst_t = (char *) dst_ddf;
2763
+ char * dst_t = (char *) dst_ddf;
2770
2764
2771
2765
GGML_ASSERT (ne12 % ne02 == 0 );
2772
2766
GGML_ASSERT (ne13 % ne03 == 0 );
2767
+ GGML_ASSERT (ne01 == static_cast <int64_t >(nb1/nb0));
2768
+ GGML_ASSERT (ne10 == ne00);
2773
2769
2774
2770
// broadcast factors
2775
- const int64_t r2 = ne12/ne02;
2776
- const int64_t r3 = ne13/ne03;
2771
+ const auto r2 = ne12/ne02;
2772
+ const auto r3 = ne13/ne03;
2773
+ const auto ne23 = ne12*ne13;
2777
2774
2778
2775
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2 (src0) && ggml_is_contiguous_2 (src1)) {
2779
2776
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
2777
+ #ifdef GGML_SYCL_DNNL
2778
+ // TODO: use strided dnnl::memory::desc ctor in row_gemm to relax below assertions
2779
+ GGML_ASSERT (nb11/nb10 == ne10);
2780
+ GGML_ASSERT (nb01/nb00 == ne00);
2781
+
2782
+ DnnlGemmWrapper::row_gemm (ctx, false , true , ne11, ne01, ne10, src1_f16,
2783
+ DnnlGemmWrapper::to_dt<sycl::half>(), src0_as_f16, DnnlGemmWrapper::to_dt<sycl::half>(),
2784
+ dst_t , DnnlGemmWrapper::to_dt<float >(), main_stream, ne23);
2785
+ #else
2780
2786
SYCL_CHECK (CHECK_TRY_ERROR (dpct::gemm_batch (
2781
2787
*main_stream, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
2782
2788
(const char *) src0_as_f16, dpct::library_data_t ::real_half, nb01 / nb00, nb02 / nb00,
2783
2789
(const char *) src1_f16, dpct::library_data_t ::real_half, nb11 / nb10, nb12 / nb10, beta, (char *) dst_t ,
2784
- cu_data_type, ne01, nb2 / nb0, ne12 * ne13, cu_compute_type)));
2790
+ cu_data_type, ne01, nb2 / nb0, ne23, cu_compute_type)));
2791
+ #endif
2785
2792
} else {
2786
- const int ne23 = ne12*ne13;
2787
2793
2788
2794
ggml_sycl_pool_alloc<const void *> ptrs_src (ctx.pool (), 2 *ne23);
2789
2795
ggml_sycl_pool_alloc< void *> ptrs_dst (ctx.pool (), 1 *ne23);
@@ -2811,7 +2817,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
2811
2817
dst_t , ptrs_src_get,
2812
2818
ptrs_dst_get, ne12, ne13, ne23,
2813
2819
nb02, nb03, nb12_scaled, nb13_scaled,
2814
- nbd2, nbd3 , r2, r3, item_ct1);
2820
+ nb2, nb3 , r2, r3, item_ct1);
2815
2821
});
2816
2822
});
2817
2823
}
@@ -3651,7 +3657,8 @@ static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_
3651
3657
return GGML_STATUS_SUCCESS;
3652
3658
}
3653
3659
3654
- sycl_ex::command_graph model_sycl_graph (*(sycl_ctx->stream ()));
3660
+ sycl_ex::command_graph model_sycl_graph (*(sycl_ctx->stream ()), {sycl_ex::property::graph::assume_buffer_outlives_graph{}});
3661
+
3655
3662
model_sycl_graph.begin_recording (*(sycl_ctx->stream ()));
3656
3663
ggml_backend_sycl_graph_compute_impl (sycl_ctx, cgraph);
3657
3664
model_sycl_graph.end_recording ();
0 commit comments