 #if defined(GGML_USE_HIPBLAS)
 #define __CUDA_ARCH__ 1300

+#ifndef __has_builtin
+    #define __has_builtin(x) 0
+#endif
+
 typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
 static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
     const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
     const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
+#if __has_builtin(__builtin_elementwise_sub_sat)
     const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
     return reinterpret_cast<const int&>(c);
+#else
+    int8x4_t c;
+    int16_t tmp;
+#pragma unroll
+    for (int i = 0; i < 4; i++) {
+        tmp = va[i] - vb[i];
+        if (tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max();
+        if (tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
+        c[i] = tmp;
+    }
+    return reinterpret_cast<int&>(c);
+#endif // __has_builtin(__builtin_elementwise_sub_sat)
 }

 static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
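For reference, the `#else` branch added here reproduces the semantics of CUDA's `__vsubss4`: subtract the four packed `int8` lanes and clamp each result to `[-128, 127]`. Below is a minimal host-side sketch of that clamping, not part of the patch; the function name `vsubss4_scalar` and the self-check are invented for illustration.

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <limits>

// Scalar model of a per-lane saturating int8 subtract on a packed 32-bit value.
static int32_t vsubss4_scalar(int32_t a, int32_t b) {
    uint32_t out = 0;
    for (int i = 0; i < 4; ++i) {
        const int8_t la = (int8_t) (a >> (8*i));
        const int8_t lb = (int8_t) (b >> (8*i));
        int16_t tmp = (int16_t) (la - lb);                // widen so 127 - (-128) cannot overflow
        tmp = std::min<int16_t>(std::max<int16_t>(tmp, std::numeric_limits<int8_t>::min()),
                                std::numeric_limits<int8_t>::max());
        out |= (uint32_t) (uint8_t) tmp << (8*i);         // pack the clamped lane back in place
    }
    return (int32_t) out;
}

int main() {
    // 127 - (-1) = 128 saturates to 127 (0x7f) in every lane: expect 7f7f7f7f.
    std::printf("%08x\n", (unsigned) vsubss4_scalar(0x7f7f7f7f, (int32_t) 0xffffffffu));
}
```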
@@ -447,58 +464,91 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
     dst[i] = x[i] / (1.0f + expf(-x[i]));
 }

+static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
+        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
+    }
+    return a;
+}
+
+template <int block_size>
 static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;

     const float eps = 1e-5f;

-    float mean = 0.0f;
-    float var = 0.0f;
+    float2 mean_var = make_float2(0.f, 0.f);

-    for (int col = tid; col < ncols; col += WARP_SIZE) {
+    for (int col = tid; col < ncols; col += block_size) {
         const float xi = x[row*ncols + col];
-        mean += xi;
-        var += xi * xi;
+        mean_var.x += xi;
+        mean_var.y += xi * xi;
     }

     // sum up partial sums
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        mean += __shfl_xor_sync(0xffffffff, mean, mask, 32);
-        var += __shfl_xor_sync(0xffffffff, var, mask, 32);
+    mean_var = warp_reduce_sum(mean_var);
+    if (block_size > WARP_SIZE) {
+        __shared__ float2 s_sum[32];
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = mean_var;
+        }
+        __syncthreads();
+        mean_var = s_sum[lane_id];
+        mean_var = warp_reduce_sum(mean_var);
     }

-    mean /= ncols;
-    var = var / ncols - mean * mean;
-    const float inv_var = rsqrtf(var + eps);
+    const float mean = mean_var.x / ncols;
+    const float var = mean_var.y / ncols - mean * mean;
+    const float inv_std = rsqrtf(var + eps);

-    for (int col = tid; col < ncols; col += WARP_SIZE) {
-        dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_var;
+    for (int col = tid; col < ncols; col += block_size) {
+        dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_std;
     }
 }

+static __device__ __forceinline__ float warp_reduce_sum(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+    }
+    return x;
+}
+
+template <int block_size>
 static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;

     float tmp = 0.0f; // partial sum for thread in warp

-    for (int col = tid; col < ncols; col += WARP_SIZE) {
+    for (int col = tid; col < ncols; col += block_size) {
         const float xi = x[row*ncols + col];
         tmp += xi * xi;
     }

     // sum up partial sums
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        __shared__ float s_sum[32];
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = tmp;
+        }
+        __syncthreads();
+        tmp = s_sum[lane_id];
+        tmp = warp_reduce_sum(tmp);
     }

     const float mean = tmp / ncols;
     const float scale = rsqrtf(mean + eps);

-    for (int col = tid; col < ncols; col += WARP_SIZE) {
+    for (int col = tid; col < ncols; col += block_size) {
         dst[row*ncols + col] = scale * x[row*ncols + col];
     }
 }
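Both kernels now share the same two-stage reduction: a `__shfl_xor_sync` butterfly inside each warp, then, only when `block_size > WARP_SIZE`, a round trip through shared memory where lane 0 of every warp publishes its partial sum and a second warp-level pass combines the 32 per-warp values. A stand-alone sketch of that pattern, not taken from the patch (the kernel name `block_sum_f32` and the launch in `main` are invented; like the patch it assumes `block_size` is either `WARP_SIZE` or 1024, so all 32 `s_sum` slots are written before they are read):

```cpp
#include <cstdio>

#define WARP_SIZE 32

static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, mask, 32);  // butterfly: all 32 lanes end up with the warp sum
    }
    return x;
}

template <int block_size>
static __global__ void block_sum_f32(const float * x, float * out, const int n) {
    float tmp = 0.0f;
    for (int i = threadIdx.x; i < n; i += block_size) {
        tmp += x[i];                                    // per-thread partial sum
    }
    tmp = warp_reduce_sum(tmp);                         // stage 1: reduce within each warp
    if (block_size > WARP_SIZE) {
        __shared__ float s_sum[32];                     // one slot per warp
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;                       // each warp publishes its partial sum
        }
        __syncthreads();
        tmp = warp_reduce_sum(s_sum[lane_id]);          // stage 2: reduce the 32 per-warp sums
    }
    if (threadIdx.x == 0) {
        *out = tmp;
    }
}

int main() {
    const int n = 4096;
    float * x; float * out;
    cudaMallocManaged(&x, n*sizeof(float));
    cudaMallocManaged(&out, sizeof(float));
    for (int i = 0; i < n; ++i) { x[i] = 1.0f; }
    block_sum_f32<1024><<<1, 1024>>>(x, out, n);
    cudaDeviceSynchronize();
    std::printf("sum = %.1f\n", *out);                  // expect 4096.0
    cudaFree(x); cudaFree(out);
}
```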
@@ -4186,14 +4236,24 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {

 static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
-    const dim3 block_dims(WARP_SIZE, 1, 1);
-    norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+    if (ncols < 1024) {
+        const dim3 block_dims(WARP_SIZE, 1, 1);
+        norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+    } else {
+        const dim3 block_dims(1024, 1, 1);
+        norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+    }
 }

 static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
-    const dim3 block_dims(WARP_SIZE, 1, 1);
-    rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+    if (ncols < 1024) {
+        const dim3 block_dims(WARP_SIZE, 1, 1);
+        rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+    } else {
+        const dim3 block_dims(1024, 1, 1);
+        rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+    }
 }

 static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, const int ky, const int kx_padded, cudaStream_t stream) {
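Design-wise, rows shorter than 1024 elements stay on a single warp, which needs no shared memory and no `__syncthreads()`; only longer rows pay for the full 1024-thread block where the cross-warp stage above is actually needed. A hypothetical call just to illustrate the resulting launch shapes (the device pointers `d_x`/`d_dst` and the sizes are invented):

```cpp
// One block per row; the template argument always matches blockDim.x.
norm_f32_cuda(d_x, d_dst, /*ncols=*/ 512, /*nrows=*/8, stream);  // norm_f32<WARP_SIZE><<<8,   32>>>
norm_f32_cuda(d_x, d_dst, /*ncols=*/4096, /*nrows=*/8, stream);  // norm_f32<1024>     <<<8, 1024>>>
```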
@@ -5721,7 +5781,6 @@ inline void ggml_cuda_op_alibi(
     (void) src1;
     (void) src0_ddq_i;
     (void) src1_ddf_i;
-    (void) i02;
     (void) i1;
 }
