@@ -105,6 +105,42 @@ template <typename argT1, typename argT2, typename resT> struct AddFunctor
105
105
tmp);
106
106
}
107
107
}
108
+
109
+ template <int vec_sz>
110
+ sycl::vec<resT, vec_sz> operator ()(const sycl::vec<argT1, vec_sz> &in1,
111
+ const argT2 &in2) const
112
+ {
113
+ auto tmp = in1 + in2;
114
+ if constexpr (std::is_same_v<resT,
115
+ typename decltype (tmp)::element_type>)
116
+ {
117
+ return tmp;
118
+ }
119
+ else {
120
+ using dpctl::tensor::type_utils::vec_cast;
121
+
122
+ return vec_cast<resT, typename decltype (tmp)::element_type, vec_sz>(
123
+ tmp);
124
+ }
125
+ }
126
+
127
+ template <int vec_sz>
128
+ sycl::vec<resT, vec_sz>
129
+ operator ()(const argT1 &in1, const sycl::vec<argT2, vec_sz> &in2) const
130
+ {
131
+ auto tmp = in1 + in2;
132
+ if constexpr (std::is_same_v<resT,
133
+ typename decltype (tmp)::element_type>)
134
+ {
135
+ return tmp;
136
+ }
137
+ else {
138
+ using dpctl::tensor::type_utils::vec_cast;
139
+
140
+ return vec_cast<resT, typename decltype (tmp)::element_type, vec_sz>(
141
+ tmp);
142
+ }
143
+ }
108
144
};
109
145
110
146
template <typename argT1,
@@ -393,6 +429,126 @@ struct AddContigRowContigMatrixBroadcastFactory
393
429
}
394
430
};
395
431
432
+ template <typename argT1,
433
+ typename argT2,
434
+ typename resT,
435
+ unsigned int vec_sz = 4 ,
436
+ unsigned int n_vecs = 2 ,
437
+ bool enable_sg_loadstore = true >
438
+ using AddScalarContigArrayFunctor =
439
+ elementwise_common::BinaryScalarContigArrayFunctor<
440
+ argT1,
441
+ argT2,
442
+ resT,
443
+ AddFunctor<argT1, argT2, resT>,
444
+ vec_sz,
445
+ n_vecs,
446
+ enable_sg_loadstore>;
447
+
448
+ template <typename argT1,
449
+ typename argT2,
450
+ typename resT,
451
+ unsigned int vec_sz = 4 ,
452
+ unsigned int n_vecs = 2 ,
453
+ bool enable_sg_loadstore = true >
454
+ using AddContigArrayScalarFunctor =
455
+ elementwise_common::BinaryContigArrayScalarFunctor<
456
+ argT1,
457
+ argT2,
458
+ resT,
459
+ AddFunctor<argT1, argT2, resT>,
460
+ vec_sz,
461
+ n_vecs,
462
+ enable_sg_loadstore>;
463
+
464
+ template <typename argT1,
465
+ typename argT2,
466
+ typename resT,
467
+ unsigned int vec_sz,
468
+ unsigned int n_vecs>
469
+ class add_scalar_contig_array_kernel ;
470
+
471
+ template <typename argTy1, typename argTy2>
472
+ sycl::event
473
+ add_scalar_contig_array_impl (sycl::queue &exec_q,
474
+ size_t nelems,
475
+ const char *arg1_p,
476
+ ssize_t arg1_offset,
477
+ const char *arg2_p,
478
+ ssize_t arg2_offset,
479
+ char *res_p,
480
+ ssize_t res_offset,
481
+ const std::vector<sycl::event> &depends = {})
482
+ {
483
+ return elementwise_common::binary_scalar_contig_array_impl<
484
+ argTy1, argTy2, AddOutputType, AddScalarContigArrayFunctor,
485
+ add_scalar_contig_array_kernel>(exec_q, nelems, arg1_p, arg1_offset,
486
+ arg2_p, arg2_offset, res_p, res_offset,
487
+ depends);
488
+ }
489
+
490
+ template <typename argT1,
491
+ typename argT2,
492
+ typename resT,
493
+ unsigned int vec_sz,
494
+ unsigned int n_vecs>
495
+ class add_contig_array_scalar_kernel ;
496
+
497
+ template <typename argTy1, typename argTy2>
498
+ sycl::event
499
+ add_contig_array_scalar_impl (sycl::queue &exec_q,
500
+ size_t nelems,
501
+ const char *arg1_p,
502
+ ssize_t arg1_offset,
503
+ const char *arg2_p,
504
+ ssize_t arg2_offset,
505
+ char *res_p,
506
+ ssize_t res_offset,
507
+ const std::vector<sycl::event> &depends = {})
508
+ {
509
+ return elementwise_common::binary_contig_array_scalar_impl<
510
+ argTy1, argTy2, AddOutputType, AddContigArrayScalarFunctor,
511
+ add_contig_array_scalar_kernel>(exec_q, nelems, arg1_p, arg1_offset,
512
+ arg2_p, arg2_offset, res_p, res_offset,
513
+ depends);
514
+ }
515
+
516
+ template <typename fnT, typename T1, typename T2>
517
+ struct AddScalarContigArrayFactory
518
+ {
519
+ fnT get ()
520
+ {
521
+ if constexpr (std::is_same_v<typename AddOutputType<T1, T2>::value_type,
522
+ void >)
523
+ {
524
+ fnT fn = nullptr ;
525
+ return fn;
526
+ }
527
+ else {
528
+ fnT fn = add_scalar_contig_array_impl<T1, T2>;
529
+ return fn;
530
+ }
531
+ }
532
+ };
533
+
534
+ template <typename fnT, typename T1, typename T2>
535
+ struct AddContigArrayScalarFactory
536
+ {
537
+ fnT get ()
538
+ {
539
+ if constexpr (std::is_same_v<typename AddOutputType<T1, T2>::value_type,
540
+ void >)
541
+ {
542
+ fnT fn = nullptr ;
543
+ return fn;
544
+ }
545
+ else {
546
+ fnT fn = add_contig_array_scalar_impl<T1, T2>;
547
+ return fn;
548
+ }
549
+ }
550
+ };
551
+
396
552
template <typename argT, typename resT> struct AddInplaceFunctor
397
553
{
398
554
@@ -409,6 +565,12 @@ template <typename argT, typename resT> struct AddInplaceFunctor
409
565
{
410
566
res += in;
411
567
}
568
+
569
+ template <int vec_sz>
570
+ void operator ()(sycl::vec<resT, vec_sz> &res, const argT &in)
571
+ {
572
+ res += in;
573
+ }
412
574
};
413
575
414
576
template <typename argT,
@@ -606,6 +768,58 @@ struct AddInplaceRowMatrixBroadcastFactory
606
768
}
607
769
};
608
770
771
+ template <typename argT,
772
+ typename resT,
773
+ unsigned int vec_sz = 4 ,
774
+ unsigned int n_vecs = 2 ,
775
+ bool enable_sg_loadstore = true >
776
+ using AddInplaceScalarContigFunctor =
777
+ elementwise_common::BinaryInplaceScalarContigFunctor<
778
+ argT,
779
+ resT,
780
+ AddInplaceFunctor<argT, resT>,
781
+ vec_sz,
782
+ n_vecs,
783
+ enable_sg_loadstore>;
784
+
785
+ template <typename argT,
786
+ typename resT,
787
+ unsigned int vec_sz,
788
+ unsigned int n_vecs>
789
+ class add_inplace_scalar_contig_kernel ;
790
+
791
+ template <typename argTy, typename resTy>
792
+ sycl::event
793
+ add_inplace_scalar_contig_impl (sycl::queue &exec_q,
794
+ size_t nelems,
795
+ const char *arg_p,
796
+ ssize_t arg_offset,
797
+ char *res_p,
798
+ ssize_t res_offset,
799
+ const std::vector<sycl::event> &depends = {})
800
+ {
801
+ return elementwise_common::binary_inplace_scalar_contig_impl<
802
+ argTy, resTy, AddInplaceScalarContigFunctor,
803
+ add_inplace_scalar_contig_kernel>(exec_q, nelems, arg_p, arg_offset,
804
+ res_p, res_offset, depends);
805
+ }
806
+
807
+ template <typename fnT, typename T1, typename T2>
808
+ struct AddInplaceScalarContigFactory
809
+ {
810
+ fnT get ()
811
+ {
812
+ if constexpr (!AddInplaceTypePairSupport<T1, T2>::is_defined) {
813
+ fnT fn = nullptr ;
814
+ return fn;
815
+ }
816
+ else {
817
+ fnT fn = add_inplace_scalar_contig_impl<T1, T2>;
818
+ return fn;
819
+ }
820
+ }
821
+ };
822
+
609
823
} // namespace add
610
824
} // namespace kernels
611
825
} // namespace tensor
0 commit comments