Skip to content

Commit d32835c

Browse files
committed
softmax: handle SYCL exceptions and add debug logs
1 parent f4014a6 commit d32835c

File tree

4 files changed

+21
-8
lines changed

4 files changed

+21
-8
lines changed

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2752,7 +2752,7 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
27522752
ggml_sycl_group_norm(ctx, dst);
27532753
break;
27542754
case GGML_OP_CONCAT:
2755-
ggml_sycl_op_concat(ctx, dst);
2755+
ggml_sycl_concat(ctx, dst);
27562756
break;
27572757
case GGML_OP_UPSCALE:
27582758
ggml_sycl_upscale(ctx, dst);
@@ -2817,7 +2817,7 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
28172817
ggml_sycl_diag_mask_inf(ctx, dst);
28182818
break;
28192819
case GGML_OP_SOFT_MAX:
2820-
ggml_sycl_op_soft_max(ctx, dst);
2820+
ggml_sycl_softmax(ctx, dst);
28212821
break;
28222822
case GGML_OP_ROPE:
28232823
ggml_sycl_rope(ctx, dst);

ggml/src/ggml-sycl/softmax.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ static void soft_max_f32_sycl(const float * x, const T * mask,
224224
}
225225
}
226226

227-
void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
227+
static void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) try {
228228

229229
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
230230
GGML_ASSERT( dst->type == GGML_TYPE_F32);
@@ -249,13 +249,26 @@ void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
249249

250250
if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F16) {
251251
const sycl::half * src1_dd = static_cast<sycl::half *>(dst->src[1]->data);
252+
GGML_SYCL_DEBUG("%s: Mask precision: F16\n", __func__);
252253
soft_max_f32_sycl<sycl::half>(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias,
253254
main_stream, ctx.device);
254255
} else if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F32) {
255256
const float * src1_dd = static_cast<const float *>(dst->src[1]->data);
257+
GGML_SYCL_DEBUG("%s: Mask precision: F32\n", __func__);
256258
soft_max_f32_sycl<float>(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
257259
} else {
258260
/* mask unavailable */
259-
soft_max_f32_sycl<float>(src0_dd, nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
261+
GGML_SYCL_DEBUG("%s: No mask supplied\n", __func__);
262+
soft_max_f32_sycl<float>(src0_dd, nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream,
263+
ctx.device);
260264
}
265+
} catch (const sycl::exception & exc) {
266+
std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
267+
std::exit(1);
268+
}
269+
270+
void ggml_sycl_softmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
271+
GGML_SYCL_DEBUG("call %s\n", __func__);
272+
ggml_sycl_op_soft_max(ctx, dst);
273+
GGML_SYCL_DEBUG("call %s done\n", __func__);
261274
}

ggml/src/ggml-sycl/softmax.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,6 @@
1515

1616
#include "common.hpp"
1717

18-
void ggml_sycl_op_soft_max(ggml_backend_sycl_context &ctx, ggml_tensor *dst);
18+
void ggml_sycl_softmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
1919

2020
#endif // GGML_SYCL_SOFTMAX_HPP

ggml/src/ggml-sycl/sum.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ static void sum_rows_f32_sycl(const float * x, float * dst, const int ncols, con
2727
inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) try {
2828
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
2929
GGML_ASSERT(dst->type == GGML_TYPE_F32);
30-
GML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
30+
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
3131

3232
const int64_t ne = ggml_nelements(dst->src[0]);
3333
dpct::queue_ptr main_stream = ctx.stream();
@@ -43,7 +43,7 @@ inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
4343
inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) try {
4444
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
4545
GGML_ASSERT(dst->type == GGML_TYPE_F32);
46-
GML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
46+
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
4747

4848
const int64_t ncols = dst->src[0]->ne[0];
4949
const int64_t nrows = ggml_nrows(dst->src[0]);
@@ -68,5 +68,5 @@ void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
6868
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
6969
GGML_SYCL_DEBUG("call %s\n", __func__);
7070
ggml_sycl_op_sum_rows(ctx, dst);
71-
GML_SYCL_DEBUG("call %s done\n", __func__);
71+
GGML_SYCL_DEBUG("call %s done\n", __func__);
7272
}

0 commit comments

Comments
 (0)