Skip to content

Commit b19bd06

Browse files
SYCL : support non-contiguous tensors in binary ops (add, sub, etc) (#12399)
* sycl : support non-contiguous tensors in binary ops
* sycl : silence unused variable warning

---------

Co-authored-by: Stanisław Szymczyk <[email protected]>
1 parent 92a3913 commit b19bd06

File tree

1 file changed

+61
-26
lines changed

1 file changed

+61
-26
lines changed

ggml/src/ggml-sycl/common.hpp

Lines changed: 61 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -474,6 +474,7 @@ static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
474474
int ne0, int ne1, int ne2, int ne3,
475475
int ne10, int ne11, int ne12, int ne13,
476476
/*int s0, */ int s1, int s2, int s3,
477+
/*int s00,*/ int s01, int s02, int s03,
477478
/*int s10,*/ int s11, int s12, int s13,
478479
const sycl::nd_item<3> &item_ct1) {
479480
const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
@@ -495,9 +496,9 @@ static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
495496
const int i12 = i2 % ne12;
496497
const int i13 = i3 % ne13;
497498

498-
const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
499+
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
499500
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
500-
const size_t i_dst = i_src0;
501+
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
501502

502503
const src0_t * src0_row = src0 + i_src0;
503504
const src1_t * src1_row = src1 + i_src1;
@@ -515,6 +516,7 @@ static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t
515516
int ne0, int ne1, int ne2, int ne3,
516517
int ne10, int ne11, int ne12, int ne13,
517518
/*int s0, */ int s1, int s2, int s3,
519+
/*int s00,*/ int s01, int s02, int s03,
518520
/*int s10,*/ int s11, int s12, int s13,
519521
const sycl::nd_item<3> &item_ct1) {
520522

@@ -534,9 +536,9 @@ static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t
534536
const int i12 = i2 % ne12;
535537
const int i13 = i3 % ne13;
536538

537-
const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
539+
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
538540
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
539-
const size_t i_dst = i_src0;
541+
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
540542

541543
const src0_t * src0_row = src0 + i_src0;
542544
const src1_t * src1_row = src1 + i_src1;
@@ -566,9 +568,11 @@ struct bin_bcast_sycl {
566568
int nr[4] = { nr0, nr1, nr2, nr3 };
567569

568570
// collapse dimensions until first broadcast dimension
569-
int64_t cne0[] = {ne0, ne1, ne2, ne3};
571+
int64_t cne[] = {ne0, ne1, ne2, ne3};
572+
int64_t cne0[] = {ne00, ne01, ne02, ne03};
570573
int64_t cne1[] = {ne10, ne11, ne12, ne13};
571-
size_t cnb0[] = {nb0, nb1, nb2, nb3};
574+
size_t cnb[] = {nb0, nb1, nb2, nb3};
575+
size_t cnb0[] = {nb00, nb01, nb02, nb03};
572576
size_t cnb1[] = {nb10, nb11, nb12, nb13};
573577
auto collapse = [](int64_t cne[]) {
574578
cne[0] *= cne[1];
@@ -583,32 +587,41 @@ struct bin_bcast_sycl {
583587
cnb[3] *= cne[3];
584588
};
585589

586-
for (int i = 0; i < 4; i++) {
587-
if (nr[i] != 1) {
588-
break;
589-
}
590-
if (i > 0) {
591-
collapse_nb(cnb0, cne0);
592-
collapse_nb(cnb1, cne1);
593-
collapse(cne0);
594-
collapse(cne1);
590+
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
591+
for (int i = 0; i < 4; i++) {
592+
if (nr[i] != 1) {
593+
break;
594+
}
595+
if (i > 0) {
596+
collapse_nb(cnb, cne);
597+
collapse_nb(cnb0, cne0);
598+
collapse_nb(cnb1, cne1);
599+
collapse(cne);
600+
collapse(cne0);
601+
collapse(cne1);
602+
}
595603
}
596604
}
597605
{
598-
int64_t ne0 = cne0[0];
599-
int64_t ne1 = cne0[1];
600-
int64_t ne2 = cne0[2];
601-
int64_t ne3 = cne0[3];
606+
int64_t ne0 = cne[0];
607+
int64_t ne1 = cne[1];
608+
int64_t ne2 = cne[2];
609+
int64_t ne3 = cne[3];
602610

603611
int64_t ne10 = cne1[0];
604612
int64_t ne11 = cne1[1];
605613
int64_t ne12 = cne1[2];
606614
int64_t ne13 = cne1[3];
607615

608-
size_t nb0 = cnb0[0];
609-
size_t nb1 = cnb0[1];
610-
size_t nb2 = cnb0[2];
611-
size_t nb3 = cnb0[3];
616+
size_t nb0 = cnb[0];
617+
size_t nb1 = cnb[1];
618+
size_t nb2 = cnb[2];
619+
size_t nb3 = cnb[3];
620+
621+
size_t nb00 = cnb0[0];
622+
size_t nb01 = cnb0[1];
623+
size_t nb02 = cnb0[2];
624+
size_t nb03 = cnb0[3];
612625

613626
size_t nb10 = cnb1[0];
614627
size_t nb11 = cnb1[1];
@@ -625,6 +638,28 @@ struct bin_bcast_sycl {
625638
size_t s12 = nb12 / sizeof(src1_t);
626639
size_t s13 = nb13 / sizeof(src1_t);
627640

641+
size_t s00 = nb00 / sizeof(src0_t);
642+
size_t s01 = nb01 / sizeof(src0_t);
643+
size_t s02 = nb02 / sizeof(src0_t);
644+
size_t s03 = nb03 / sizeof(src0_t);
645+
646+
GGML_UNUSED(s00);
647+
648+
GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
649+
GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
650+
GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
651+
GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
652+
653+
GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
654+
GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
655+
GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
656+
GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
657+
658+
GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
659+
GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
660+
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
661+
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
662+
628663
GGML_ASSERT(s0 == 1);
629664
GGML_ASSERT(s10 == 1);
630665

@@ -661,8 +696,8 @@ struct bin_bcast_sycl {
661696
[=](sycl::nd_item<3> item_ct1) {
662697
k_bin_bcast_unravel<bin_op>(
663698
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
664-
ne10, ne11, ne12, ne13, s1, s2, s3, s11, s12,
665-
s13, item_ct1);
699+
ne10, ne11, ne12, ne13, s1, s2, s3, s01, s02,
700+
s03, s11, s12, s13, item_ct1);
666701
});
667702
}
668703
} else {
@@ -680,7 +715,7 @@ struct bin_bcast_sycl {
680715
[=](sycl::nd_item<3> item_ct1) {
681716
k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
682717
ne2, ne3, ne10, ne11, ne12, ne13,
683-
s1, s2, s3, s11, s12, s13,
718+
s1, s2, s3, s01, s02, s03, s11, s12, s13,
684719
item_ct1);
685720
});
686721
}

0 commit comments

Comments
 (0)