Skip to content

Commit e79bfca

Browse files
author
Aidan
committed
Update SYCL upscale operation
1 parent 213e90e commit e79bfca

File tree

1 file changed

+35
-27
lines changed

1 file changed

+35
-27
lines changed

ggml-sycl.cpp

Lines changed: 35 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3847,21 +3847,28 @@ static void concat_f32(const float *x,const float *y, float *dst, const int ne
38473847
}
38483848
}
38493849

3850-
static void upscale_f32(const float *x, float *dst, const int ne00, const int nb02, const int scale_factor,
3851-
const sycl::nd_item<3> &item_ct1) {
3852-
int ne0 = ne00 * scale_factor;
3853-
int nidx = item_ct1.get_local_id(2) +
3854-
item_ct1.get_group(2) * item_ct1.get_local_range(2);
3855-
if (nidx >= ne0) {
3850+
static void upscale_f32(const float *x, float *dst,
3851+
const int nb00, const int nb01, const int nb02, const int nb03,
3852+
const int ne10, const int ne11, const int ne12, const int ne13,
3853+
const float sf0, const float sf1, const float sf2, const float sf3,
3854+
const sycl::nd_item<1> &item_ct1) {
3855+
int index = item_ct1.get_local_id(0) +
3856+
item_ct1.get_group(0) * item_ct1.get_local_range(0);
3857+
if (index >= ne10 * ne11 * ne12 * ne13) {
38563858
return;
38573859
}
38583860
// operation
3859-
int i00 = nidx / scale_factor;
3860-
int i01 = item_ct1.get_group(1) / scale_factor;
3861-
int offset_src = i00 + i01 * ne00 + item_ct1.get_group(0) * nb02;
3862-
int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
3863-
item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
3864-
dst[offset_dst] = x[offset_src];
3861+
int i10 = index % ne10;
3862+
int i11 = (index / ne10) % ne11;
3863+
int i12 = (index / (ne10 * ne11)) % ne12;
3864+
int i13 = (index / (ne10 * ne11 * ne12)) % ne13;
3865+
3866+
int i00 = i10 / sf0;
3867+
int i01 = i11 / sf1;
3868+
int i02 = i12 / sf2;
3869+
int i03 = i13 / sf3;
3870+
3871+
dst[index] = *(float *)((char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
38653872
}
38663873

38673874
static void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02,
@@ -10085,18 +10092,18 @@ static void concat_f32_sycl(const float *x, const float *y, float *dst,
1008510092
});
1008610093
}
1008710094

10088-
static void upscale_f32_sycl(const float *x, float *dst, const int ne00,
10089-
const int ne01, const int ne02,
10090-
const int scale_factor, dpct::queue_ptr stream) {
10091-
int ne0 = (ne00 * scale_factor);
10092-
int num_blocks = (ne0 + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE;
10093-
sycl::range<3> gridDim(ne02, (ne01 * scale_factor), num_blocks);
10095+
static void upscale_f32_sycl(const float *x, float *dst,
10096+
const int nb00, const int nb01, const int nb02, const int nb03,
10097+
const int ne10, const int ne11, const int ne12, const int ne13,
10098+
const float sf0, const float sf1, const float sf2, const float sf3,
10099+
dpct::queue_ptr stream) {
10100+
int dst_size = ne10 * ne11 * ne12 * ne13;
10101+
int num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE;
10102+
sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE);
1009410103
stream->parallel_for(
10095-
sycl::nd_range<3>(gridDim *
10096-
sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE),
10097-
sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)),
10098-
[=](sycl::nd_item<3> item_ct1) {
10099-
upscale_f32(x, dst, ne00, ne00 * ne01, scale_factor, item_ct1);
10104+
sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)),
10105+
[=](sycl::nd_item<1> item_ct1) {
10106+
upscale_f32(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1);
1010010107
});
1010110108
}
1010210109

@@ -13985,15 +13992,16 @@ inline void ggml_sycl_op_upscale(const ggml_tensor *src0,
1398513992

1398613993
GGML_ASSERT(src0->type == GGML_TYPE_F32);
1398713994
GGML_ASSERT(dst->type == GGML_TYPE_F32);
13988-
GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
1398913995

1399013996
#pragma message("TODO: generalize upscale operator")
1399113997
#pragma message(" https://github.com/ggerganov/ggml/pull/814")
13992-
GGML_ASSERT(false && "TODO: generalize upscale operator");
1399313998

13994-
const int scale_factor = dst->op_params[0];
13999+
const float sf0 = (float)dst->ne[0]/src0->ne[0];
14000+
const float sf1 = (float)dst->ne[1]/src0->ne[1];
14001+
const float sf2 = (float)dst->ne[2]/src0->ne[2];
14002+
const float sf3 = (float)dst->ne[3]/src0->ne[3];
1399514003

13996-
upscale_f32_sycl(src0_dd, dst_dd, src0->ne[0], src0->ne[1], src0->ne[2], scale_factor, main_stream);
14004+
upscale_f32_sycl(src0_dd, dst_dd, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, main_stream);
1399714005

1399814006
(void) src1;
1399914007
(void) dst;

0 commit comments

Comments
 (0)