Skip to content

Commit 4f8d19f

Browse files
authored
[SYCL] Fix SYCL im2col and convert Overflow with Large Dims (ggml-org#9052)
* sycl: fix im2col overflow and sync with cuda
  Signed-off-by: zhentaoyu <[email protected]>
* sycl: fix convert overflow
  Signed-off-by: zhentaoyu <[email protected]>
* sycl: fix convert and dequantize
  Signed-off-by: zhentaoyu <[email protected]>
* sycl: fix ib in dmmv
  Signed-off-by: zhentaoyu <[email protected]>
* sycl:refine convert
  Signed-off-by: zhentaoyu <[email protected]>
* sycl: move downsample global_range into common
  Signed-off-by: zhentaoyu <[email protected]>
* test: add im2col and convert test cases
  Signed-off-by: zhentaoyu <[email protected]>
* test: make new cases only in sycl
  Signed-off-by: zhentaoyu <[email protected]>
* test: comment new test_cases for only local testing
  Signed-off-by: zhentaoyu <[email protected]>
---------
Signed-off-by: zhentaoyu <[email protected]>
1 parent 90db814 commit 4f8d19f

File tree

11 files changed

+333
-262
lines changed

11 files changed

+333
-262
lines changed

ggml/src/ggml-sycl.cpp

Lines changed: 0 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -893,43 +893,6 @@ static void clamp_f32(const float * x, float * dst, const float min, const float
893893
dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
894894
}
895895

896-
template <typename T>
897-
static void im2col_kernel(const float *x, T *dst, int offset_delta,
898-
int IW, int IH, int OW, int KW, int KH,
899-
int pelements, int CHW, int s0, int s1, int p0,
900-
int p1, int d0, int d1,
901-
const sycl::nd_item<3> &item_ct1) {
902-
const int i = item_ct1.get_local_id(2) +
903-
item_ct1.get_group(2) * item_ct1.get_local_range(2);
904-
if (i >= pelements) {
905-
return;
906-
}
907-
908-
const int ksize = OW * (KH > 1 ? KW : 1);
909-
const int kx = i / ksize;
910-
const int kd = kx * ksize;
911-
const int ky = (i - kd) / OW;
912-
const int ix = i % OW;
913-
914-
const int64_t iiw = ix * s0 + kx * d0 - p0;
915-
const int64_t iih = item_ct1.get_group(1) * s1 + ky * d1 - p1;
916-
917-
const int64_t offset_dst =
918-
(item_ct1.get_group(1) * OW + ix) * CHW +
919-
(item_ct1.get_group(0) * (KW * KH) + ky * KW + kx);
920-
921-
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
922-
dst[offset_dst] =
923-
sycl::vec<float, 1>(0.0f)
924-
.convert<sycl::half, sycl::rounding_mode::automatic>()[0];
925-
} else {
926-
const int64_t offset_src = item_ct1.get_group(0) * offset_delta;
927-
dst[offset_dst] =
928-
sycl::vec<float, 1>(x[offset_src + iih * IW + iiw])
929-
.convert<sycl::half, sycl::rounding_mode::automatic>()[0];
930-
}
931-
}
932-
933896
template <typename Ti, typename To>
934897
static void pool2d_nchw_kernel(
935898
const int ih, const int iw, const int oh, const int ow,
@@ -1742,32 +1705,6 @@ static void diag_mask_inf_f32_sycl(const float *x, float *dst,
17421705
});
17431706
}
17441707

1745-
template <typename T>
1746-
static void im2col_sycl(const float *x, T *dst, int IW, int IH,
1747-
int OW, int OH, int KW, int KH, int IC,
1748-
int offset_delta, int s0, int s1, int p0,
1749-
int p1, int d0, int d1,
1750-
queue_ptr stream) {
1751-
const int parallel_elements = OW * KW * KH;
1752-
const int num_blocks = (parallel_elements + SYCL_IM2COL_BLOCK_SIZE - 1) / SYCL_IM2COL_BLOCK_SIZE;
1753-
sycl::range<3> block_nums(IC, OH, num_blocks);
1754-
{
1755-
dpct::has_capability_or_fail(stream->get_device(),
1756-
{sycl::aspect::fp16});
1757-
1758-
stream->parallel_for(
1759-
sycl::nd_range<3>(block_nums *
1760-
sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE),
1761-
sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE)),
1762-
[=](sycl::nd_item<3> item_ct1) {
1763-
im2col_kernel(x, dst, offset_delta, IW, IH, OW, KW, KH,
1764-
parallel_elements, (IC * KH * KW), s0, s1, p0,
1765-
p1, d0, d1, item_ct1);
1766-
});
1767-
}
1768-
}
1769-
1770-
17711708
static bool g_sycl_loaded = false;
17721709

17731710
bool ggml_sycl_loaded(void) {
@@ -2636,47 +2573,6 @@ static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tens
26362573
(void) src1_dd;
26372574
}
26382575

2639-
// Host-side handler for the im2col op.
//
// Reads the six integer parameters and the 2D flag from dst->op_params,
// derives the tensor extents from src0 (kernel), src1 (input image) and dst,
// and launches im2col_sycl on `main_stream` with a destination pointer typed
// according to dst->type (F16 or F32). Only src0's shape is consulted; its
// data pointer `src0_dd` is not used.
inline void ggml_sycl_op_im2col(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
                                const ggml_tensor *src1, ggml_tensor *dst,
                                const float *src0_dd, const float *src1_dd,
                                float *dst_dd,
                                const queue_ptr &main_stream) {

    GGML_ASSERT(src0->type == GGML_TYPE_F16);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);

    // op_params layout: [s0, s1, p0, p1, d0, d1, is_2D]
    const int32_t * op_params = (const int32_t *) (dst->op_params);

    const int32_t s0 = op_params[0];
    const int32_t s1 = op_params[1];
    const int32_t p0 = op_params[2];
    const int32_t p1 = op_params[3];
    const int32_t d0 = op_params[4];
    const int32_t d1 = op_params[5];

    const bool is_2D = op_params[6] == 1;

    // The channel axis of src1 is 2 in the 2D case and 1 otherwise.
    const int channel_axis = is_2D ? 2 : 1;

    const int64_t IC = src1->ne[channel_axis];
    const int64_t IW = src1->ne[0];
    const int64_t KW = src0->ne[0];
    const int64_t OW = dst->ne[1];

    // Heights collapse to 1 when the op is not 2D.
    int64_t IH = 1;
    int64_t KH = 1;
    int64_t OH = 1;
    if (is_2D) {
        IH = src1->ne[1];
        KH = src0->ne[1];
        OH = dst->ne[2];
    }

    // nb is a byte offset and src1 holds float32, hence the division by 4.
    const size_t delta_offset = src1->nb[channel_axis] / 4;

    if (dst->type != GGML_TYPE_F16) {
        im2col_sycl(src1_dd, (float *)dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
    } else {
        im2col_sycl(src1_dd, (sycl::half *)dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
    }

    (void) src0;
    (void) src0_dd;
}
2679-
26802576
inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
26812577
const ggml_tensor *src1, ggml_tensor *dst,
26822578
const float *src0_dd, const float *src1_dd,

ggml/src/ggml-sycl/backend.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,6 @@
2525
#include "norm.hpp"
2626
#include "softmax.hpp"
2727
#include "tsembd.hpp"
28+
#include "im2col.hpp"
2829

2930
#endif // GGML_SYCL_BACKEND_HPP

ggml/src/ggml-sycl/common.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,14 @@ void ggml_sycl_host_free(void* ptr) try {
5151
<< ", line:" << __LINE__ << std::endl;
5252
std::exit(1);
5353
}
54+
55+
// Picks a work-group size that keeps `accumulate_block_num * size` within
// the range of `int`.
//
// Starting from `block_size`, the size is halved until the product no longer
// exceeds std::numeric_limits<int>::max(). The size is never reduced below 1:
// if `accumulate_block_num` alone exceeds the limit, 1 is returned (the
// product still exceeds the limit in that case, but the caller at least gets
// a usable, non-zero size instead of 0).
//
// accumulate_block_num  number of blocks the size will be multiplied by
// block_size            preferred size; returned unchanged when it already fits
//
// Returns the (possibly reduced) block size.
int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size) {
    const int64_t max_range = std::numeric_limits<int>::max();
    int64_t sycl_down_blk_size = block_size;
    // `a > max / b` is equivalent to `a * b > max` for positive b, but cannot
    // overflow int64_t. The `> 1` guard stops the halving before it hits 0.
    while (sycl_down_blk_size > 1 &&
           accumulate_block_num > max_range / sycl_down_blk_size) {
        sycl_down_blk_size /= 2;
    }
    return sycl_down_blk_size;
}

ggml/src/ggml-sycl/common.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,4 +352,6 @@ static __dpct_inline__ Tp* get_pointer(sycl::local_accessor<Tp, dim> acc) {
352352
return acc.template get_multi_ptr<sycl::access::decorated::no>().get();
353353
}
354354

355+
// Halves `block_size` until `accumulate_block_num * block_size` no longer
// exceeds std::numeric_limits<int>::max(), and returns the reduced block size.
// NOTE(review): presumably used to keep a SYCL global range within int range — confirm at call sites.
int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size);
356+
355357
#endif // GGML_SYCL_COMMON_HPP

0 commit comments

Comments
 (0)