@@ -8247,7 +8247,8 @@ static void clamp_f32(const float * x, float * dst, const float min, const float
8247
8247
dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
8248
8248
}
8249
8249
8250
- static void im2col_f32_f16(const float *x, sycl::half *dst, int offset_delta,
8250
+ template <typename T>
8251
+ static void im2col_kernel(const float *x, T *dst, int offset_delta,
8251
8252
int IW, int IH, int OW, int KW, int KH,
8252
8253
int pelements, int CHW, int s0, int s1, int p0,
8253
8254
int p1, int d0, int d1,
@@ -11019,7 +11020,8 @@ static void soft_max_f32_sycl(const float *x, const float *y, float *dst,
11019
11020
});
11020
11021
}
11021
11022
11022
- static void im2col_f32_f16_sycl(const float *x, sycl::half *dst, int IW, int IH,
11023
+ template <typename T>
11024
+ static void im2col_sycl(const float *x, T *dst, int IW, int IH,
11023
11025
int OW, int OH, int KW, int KH, int IC,
11024
11026
int offset_delta, int s0, int s1, int p0,
11025
11027
int p1, int d0, int d1,
@@ -11036,7 +11038,7 @@ static void im2col_f32_f16_sycl(const float *x, sycl::half *dst, int IW, int IH,
11036
11038
sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE),
11037
11039
sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE)),
11038
11040
[=](sycl::nd_item<3> item_ct1) {
11039
- im2col_f32_f16 (x, dst, offset_delta, IW, IH, OW, KW, KH,
11041
+ im2col_kernel (x, dst, offset_delta, IW, IH, OW, KW, KH,
11040
11042
parallel_elements, (IC * KH * KW), s0, s1, p0,
11041
11043
p1, d0, d1, item_ct1);
11042
11044
});
@@ -12424,7 +12426,7 @@ inline void ggml_sycl_op_im2col(const ggml_tensor *src0,
12424
12426
12425
12427
GGML_ASSERT(src0->type == GGML_TYPE_F16);
12426
12428
GGML_ASSERT(src1->type == GGML_TYPE_F32);
12427
- GGML_ASSERT( dst->type == GGML_TYPE_F16);
12429
+ GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32 );
12428
12430
12429
12431
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
12430
12432
const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
@@ -12447,8 +12449,11 @@ inline void ggml_sycl_op_im2col(const ggml_tensor *src0,
12447
12449
12448
12450
const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
12449
12451
12450
- im2col_f32_f16_sycl(src1_dd, (sycl::half *)dst_dd, IW, IH, OW, OH, KW, KH,
12451
- IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
12452
+ if (dst->type == GGML_TYPE_F16) {
12453
+ im2col_sycl(src1_dd, (sycl::half *)dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
12454
+ } else {
12455
+ im2col_sycl(src1_dd, (float *)dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
12456
+ }
12452
12457
12453
12458
(void) src0;
12454
12459
(void) src0_dd;
0 commit comments