
Commit 8336595

ggml : sync latest llama.cpp (view_src + alloc improvements) (ggml-org#1247)
* ggml : sync latest llama.cpp (view_src + alloc improvements)
* ggml : fix build
1 parent 30191cd

File tree: 6 files changed (+855, -519 lines)


ggml-cuda.cu

Lines changed: 83 additions & 24 deletions
@@ -81,12 +81,29 @@
 #if defined(GGML_USE_HIPBLAS)
 #define __CUDA_ARCH__ 1300
 
+#ifndef __has_builtin
+#define __has_builtin(x) 0
+#endif
+
 typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
 static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
     const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
     const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
+#if __has_builtin(__builtin_elementwise_sub_sat)
     const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
     return reinterpret_cast<const int&>(c);
+#else
+    int8x4_t c;
+    int16_t tmp;
+#pragma unroll
+    for (int i = 0; i < 4; i++) {
+        tmp = va[i] - vb[i];
+        if(tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max();
+        if(tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
+        c[i] = tmp;
+    }
+    return reinterpret_cast<int&>(c);
+#endif // __has_builtin(__builtin_elementwise_sub_sat)
 }
 
 static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
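
The new __has_builtin guard matters on HIP: ROCm's clang may not provide __builtin_elementwise_sub_sat, so the fallback performs the same saturating per-byte subtraction by hand. Below is a small host-side sketch, not part of the commit, that mirrors the fallback logic; the helper name vsubss4_scalar and the test values are illustrative only, and it compiles with any C++ compiler so the clamping behaviour can be checked off-device.

#include <cstdint>
#include <cstdio>
#include <cstring>
#include <limits>

// Hypothetical helper mirroring the fallback path of __vsubss4: subtract the
// four int8 lanes packed into an int32 and clamp each result to [-128, 127].
static int32_t vsubss4_scalar(int32_t a, int32_t b) {
    int8_t va[4], vb[4], vc[4];
    std::memcpy(va, &a, 4);
    std::memcpy(vb, &b, 4);
    for (int i = 0; i < 4; i++) {
        int16_t tmp = (int16_t) (va[i] - vb[i]);
        if (tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max();
        if (tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
        vc[i] = (int8_t) tmp;
    }
    int32_t c;
    std::memcpy(&c, vc, 4);
    return c;
}

int main() {
    const int32_t a = 0x7F7F7F7F; // four lanes of 127
    const int32_t b = -1;         // four lanes of  -1
    // 127 - (-1) = 128 overflows int8, so every lane must clamp to 127.
    printf("0x%08X\n", (unsigned) vsubss4_scalar(a, b)); // expected: 0x7F7F7F7F
    return 0;
}
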
@@ -447,58 +464,91 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
     dst[i] = x[i] / (1.0f + expf(-x[i]));
 }
 
+static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
+        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
+    }
+    return a;
+}
+
+template <int block_size>
 static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
 
     const float eps = 1e-5f;
 
-    float mean = 0.0f;
-    float var = 0.0f;
+    float2 mean_var = make_float2(0.f, 0.f);
 
-    for (int col = tid; col < ncols; col += WARP_SIZE) {
+    for (int col = tid; col < ncols; col += block_size) {
         const float xi = x[row*ncols + col];
-        mean += xi;
-        var += xi * xi;
+        mean_var.x += xi;
+        mean_var.y += xi * xi;
     }
 
     // sum up partial sums
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        mean += __shfl_xor_sync(0xffffffff, mean, mask, 32);
-        var += __shfl_xor_sync(0xffffffff, var, mask, 32);
+    mean_var = warp_reduce_sum(mean_var);
+    if (block_size > WARP_SIZE) {
+        __shared__ float2 s_sum[32];
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = mean_var;
+        }
+        __syncthreads();
+        mean_var = s_sum[lane_id];
+        mean_var = warp_reduce_sum(mean_var);
     }
 
-    mean /= ncols;
-    var = var / ncols - mean * mean;
-    const float inv_var = rsqrtf(var + eps);
+    const float mean = mean_var.x / ncols;
+    const float var = mean_var.y / ncols - mean * mean;
+    const float inv_std = rsqrtf(var + eps);
 
-    for (int col = tid; col < ncols; col += WARP_SIZE) {
-        dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_var;
+    for (int col = tid; col < ncols; col += block_size) {
+        dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_std;
     }
 }
 
+static __device__ __forceinline__ float warp_reduce_sum(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+    }
+    return x;
+}
+
+template <int block_size>
 static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
 
     float tmp = 0.0f; // partial sum for thread in warp
 
-    for (int col = tid; col < ncols; col += WARP_SIZE) {
+    for (int col = tid; col < ncols; col += block_size) {
         const float xi = x[row*ncols + col];
         tmp += xi * xi;
     }
 
     // sum up partial sums
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        __shared__ float s_sum[32];
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = tmp;
+        }
+        __syncthreads();
+        tmp = s_sum[lane_id];
+        tmp = warp_reduce_sum(tmp);
     }
 
     const float mean = tmp / ncols;
     const float scale = rsqrtf(mean + eps);
 
-    for (int col = tid; col < ncols; col += WARP_SIZE) {
+    for (int col = tid; col < ncols; col += block_size) {
         dst[row*ncols + col] = scale * x[row*ncols + col];
     }
 }
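
The templated kernels use a two-level reduction: warp_reduce_sum combines the 32 lanes of each warp with XOR shuffles, and when block_size > WARP_SIZE each warp's lane 0 stores its partial in shared memory, the block synchronizes, and one more warp-level pass reduces those per-warp partials. The standalone sketch below is not from the commit; block_reduce_sum, sum_rows_f32, and the main driver are hypothetical names, and WARP_SIZE is assumed to be 32 as in ggml-cuda.cu. It applies the same pattern to a plain row sum so it can be compiled and run on its own.

#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

#define WARP_SIZE 32

// Same shuffle-based warp reduction as the commit's float overload.
static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
    }
    return x;
}

// Hypothetical block-wide sum built from the same two stages as the new
// kernels: per-warp shuffle reduction, then a shared-memory pass over the
// per-warp partials. The guarded read keeps this sketch correct for any
// block_size that is a multiple of WARP_SIZE.
template <int block_size>
static __device__ float block_reduce_sum(float x) {
    x = warp_reduce_sum(x);
    if (block_size > WARP_SIZE) {
        __shared__ float s_sum[32];
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = x;
        }
        __syncthreads();
        x = lane_id < block_size / WARP_SIZE ? s_sum[lane_id] : 0.0f;
        x = warp_reduce_sum(x);
    }
    return x;
}

// Hypothetical kernel: one block per row, block-wide sum of that row.
template <int block_size>
static __global__ void sum_rows_f32(const float * x, float * dst, const int ncols) {
    const int row = blockIdx.x;
    float tmp = 0.0f;
    for (int col = threadIdx.x; col < ncols; col += block_size) {
        tmp += x[row*ncols + col];
    }
    tmp = block_reduce_sum<block_size>(tmp);
    if (threadIdx.x == 0) {
        dst[row] = tmp;
    }
}

int main() {
    const int nrows = 2, ncols = 4096;
    std::vector<float> h_x(nrows*ncols, 1.0f);
    float * d_x; float * d_dst;
    cudaMalloc(&d_x, nrows*ncols*sizeof(float));
    cudaMalloc(&d_dst, nrows*sizeof(float));
    cudaMemcpy(d_x, h_x.data(), nrows*ncols*sizeof(float), cudaMemcpyHostToDevice);
    sum_rows_f32<1024><<<nrows, 1024>>>(d_x, d_dst, ncols);
    float h_dst[2];
    cudaMemcpy(h_dst, d_dst, sizeof(h_dst), cudaMemcpyDeviceToHost);
    printf("%.0f %.0f\n", h_dst[0], h_dst[1]); // expected: 4096 4096
    cudaFree(d_x);
    cudaFree(d_dst);
    return 0;
}

Note that the commit's kernels read s_sum[lane_id] unconditionally; that is safe because their launchers (next hunk) use either WARP_SIZE or exactly 1024 threads per block, so all 32 shared-memory slots are written in the block_size > WARP_SIZE case.
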
@@ -4186,14 +4236,24 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
 
 static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
-    const dim3 block_dims(WARP_SIZE, 1, 1);
-    norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+    if (ncols < 1024) {
+        const dim3 block_dims(WARP_SIZE, 1, 1);
+        norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+    } else {
+        const dim3 block_dims(1024, 1, 1);
+        norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+    }
 }
 
 static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
-    const dim3 block_dims(WARP_SIZE, 1, 1);
-    rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+    if (ncols < 1024) {
+        const dim3 block_dims(WARP_SIZE, 1, 1);
+        rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+    } else {
+        const dim3 block_dims(1024, 1, 1);
+        rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+    }
 }
 
 static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, const int ky, const int kx_padded, cudaStream_t stream) {
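
The launchers pick the block size at runtime but also pass it as a template argument, so the block_size > WARP_SIZE branch folds away at compile time: the WARP_SIZE instantiation never touches shared memory, while the 1024-thread instantiation used for wide rows (ncols >= 1024) includes the cross-warp stage. A minimal sketch of that dispatch idiom follows; the helper launch_norm_f32 and the wrapper norm_f32_cuda_sketch are hypothetical, not part of the commit.

// Hypothetical factoring of the dispatch used by norm_f32_cuda above.
template <int block_size>
static void launch_norm_f32(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    const dim3 block_dims(block_size, 1, 1);
    norm_f32<block_size><<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
}

static void norm_f32_cuda_sketch(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    if (ncols < 1024) {
        launch_norm_f32<WARP_SIZE>(x, dst, ncols, nrows, stream); // one warp per row
    } else {
        launch_norm_f32<1024>(x, dst, ncols, nrows, stream);      // full block per wide row
    }
}
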
@@ -5721,7 +5781,6 @@ inline void ggml_cuda_op_alibi(
     (void) src1;
     (void) src0_ddq_i;
     (void) src1_ddf_i;
-    (void) i02;
     (void) i1;
 }
 