Skip to content

Commit a305dba

Browse files
authored
Fix im2col with 32fp (#5286)
1 parent 1912211 commit a305dba

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

ggml-sycl.cpp

Lines changed: 11 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -8247,7 +8247,8 @@ static void clamp_f32(const float * x, float * dst, const float min, const float
82478247
dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
82488248
}
82498249

8250-
static void im2col_f32_f16(const float *x, sycl::half *dst, int offset_delta,
8250+
template <typename T>
8251+
static void im2col_kernel(const float *x, T *dst, int offset_delta,
82518252
int IW, int IH, int OW, int KW, int KH,
82528253
int pelements, int CHW, int s0, int s1, int p0,
82538254
int p1, int d0, int d1,
@@ -11019,7 +11020,8 @@ static void soft_max_f32_sycl(const float *x, const float *y, float *dst,
1101911020
});
1102011021
}
1102111022

11022-
static void im2col_f32_f16_sycl(const float *x, sycl::half *dst, int IW, int IH,
11023+
template <typename T>
11024+
static void im2col_sycl(const float *x, T *dst, int IW, int IH,
1102311025
int OW, int OH, int KW, int KH, int IC,
1102411026
int offset_delta, int s0, int s1, int p0,
1102511027
int p1, int d0, int d1,
@@ -11036,7 +11038,7 @@ static void im2col_f32_f16_sycl(const float *x, sycl::half *dst, int IW, int IH,
1103611038
sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE),
1103711039
sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE)),
1103811040
[=](sycl::nd_item<3> item_ct1) {
11039-
im2col_f32_f16(x, dst, offset_delta, IW, IH, OW, KW, KH,
11041+
im2col_kernel(x, dst, offset_delta, IW, IH, OW, KW, KH,
1104011042
parallel_elements, (IC * KH * KW), s0, s1, p0,
1104111043
p1, d0, d1, item_ct1);
1104211044
});
@@ -12424,7 +12426,7 @@ inline void ggml_sycl_op_im2col(const ggml_tensor *src0,
1242412426

1242512427
GGML_ASSERT(src0->type == GGML_TYPE_F16);
1242612428
GGML_ASSERT(src1->type == GGML_TYPE_F32);
12427-
GGML_ASSERT( dst->type == GGML_TYPE_F16);
12429+
GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
1242812430

1242912431
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
1243012432
const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
@@ -12447,8 +12449,11 @@ inline void ggml_sycl_op_im2col(const ggml_tensor *src0,
1244712449

1244812450
const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
1244912451

12450-
im2col_f32_f16_sycl(src1_dd, (sycl::half *)dst_dd, IW, IH, OW, OH, KW, KH,
12451-
IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
12452+
if (dst->type == GGML_TYPE_F16) {
12453+
im2col_sycl(src1_dd, (sycl::half *)dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
12454+
} else {
12455+
im2col_sycl(src1_dd, (float *)dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
12456+
}
1245212457

1245312458
(void) src0;
1245412459
(void) src0_dd;

0 commit comments

Comments (0)