@@ -1463,24 +1463,6 @@ static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
     }
 }
 
-static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past,
-                              const sycl::nd_item<3> &item_ct1) {
-    const int col = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                    item_ct1.get_local_id(1);
-    const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                    item_ct1.get_local_id(2);
-
-    if (col >= ncols) {
-        return;
-    }
-
-    const int i = row*ncols + col;
-    // dst[i] = col > (n_past + row % rows_per_channel) ? -INFINITY : x[i];
-    // dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
-    dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
-}
-
 static void scale_f32(const float * x, float * dst, const float scale, const int k,
                       const sycl::nd_item<3> &item_ct1) {
     const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
@@ -1666,21 +1648,6 @@ static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
                          });
 }
 
-static void diag_mask_inf_f32_sycl(const float *x, float *dst,
-                                   const int ncols_x, const int nrows_x,
-                                   const int rows_per_channel, const int n_past,
-                                   queue_ptr stream) {
-    const sycl::range<3> block_dims(1, SYCL_DIAG_MASK_INF_BLOCK_SIZE, 1);
-    const int block_num_x = (ncols_x + SYCL_DIAG_MASK_INF_BLOCK_SIZE - 1) / SYCL_DIAG_MASK_INF_BLOCK_SIZE;
-    const sycl::range<3> block_nums(1, block_num_x, nrows_x);
-    stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                         [=](sycl::nd_item<3> item_ct1) {
-                             diag_mask_inf_f32(x, dst, ncols_x,
-                                               rows_per_channel, n_past,
-                                               item_ct1);
-                         });
-}
-
 static dpct::err0 ggml_sycl_cpy_tensor_2d(void *dst,
                                           const struct ggml_tensor *src,
                                           int64_t i3, int64_t i2,
@@ -1962,24 +1929,6 @@ inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor *
     sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
 }
 
-inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-    GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(dst->buffer));
-
-    const int64_t ne00 = dst->src[0]->ne[0];
-    const int64_t ne01 = dst->src[0]->ne[1];
-    const int nrows0 = ggml_nrows(dst->src[0]);
-
-    const int n_past = ((int32_t *) dst->op_params)[0];
-    dpct::queue_ptr main_stream = ctx.stream();
-    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
-    float * dst_dd = static_cast<float *>(dst->data);
-
-    diag_mask_inf_f32_sycl(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream);
-}
-
 inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
@@ -2957,10 +2906,6 @@ static void ggml_sycl_dup(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     ggml_sycl_cpy(ctx, dst->src[0], dst);
 }
 
-static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    ggml_sycl_op_diag_mask_inf(ctx, dst);
-}
-
 static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_ASSERT(ggml_is_contiguous(dst->src[0])); // TODO: this restriction is temporary until non-cont support is implemented
     ggml_sycl_op_rope(ctx, dst);
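
For reference, the removed SYCL path implemented the `GGML_OP_DIAG_MASK_INF` operator: every element whose column index exceeds `n_past + row % rows_per_channel` has `FLT_MAX` subtracted from it, so a subsequent softmax effectively assigns those positions zero probability. Below is a minimal host-side sketch of that masking logic, reconstructed from the deleted kernel; the reference function name and the small driver in `main` are illustrative only and not part of the ggml API.

```cpp
// CPU sketch of the masking the removed diag_mask_inf_f32 kernel performed
// (assumed semantics, reconstructed from the deleted SYCL code).
#include <cfloat>
#include <cstdio>
#include <vector>

// Push columns beyond the causal boundary toward -FLT_MAX (effectively -inf).
static void diag_mask_inf_ref(const float * x, float * dst, int ncols, int nrows,
                              int rows_per_channel, int n_past) {
    for (int row = 0; row < nrows; ++row) {
        for (int col = 0; col < ncols; ++col) {
            const int i = row * ncols + col;
            // The comparison yields 0 or 1, so unmasked entries pass through
            // unchanged and masked entries become x[i] - FLT_MAX.
            dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
        }
    }
}

int main() {
    const int ncols = 4, nrows = 2, rows_per_channel = 2, n_past = 0;
    std::vector<float> x(ncols * nrows, 1.0f), dst(ncols * nrows);
    diag_mask_inf_ref(x.data(), dst.data(), ncols, nrows, rows_per_channel, n_past);
    for (int r = 0; r < nrows; ++r) {
        for (int c = 0; c < ncols; ++c) {
            printf("%g ", dst[r * ncols + c]);
        }
        printf("\n");
    }
    return 0;
}
```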