
Commit bcadd61

Merge branch 'master' into sycl-gemm-dispatch
2 parents 9ef4671 + 852aafb

File tree: 7 files changed, +197 −65 lines

CMakePresets.json

Lines changed: 6 additions & 2 deletions
@@ -1,4 +1,4 @@
-{
+{
     "version": 4,
     "configurePresets": [
         {
@@ -40,6 +40,10 @@
 
         { "name": "arm64-windows-msvc-debug"          , "inherits": [ "base", "arm64-windows-msvc", "debug"   ] },
        { "name": "arm64-windows-msvc-release"        , "inherits": [ "base", "arm64-windows-msvc", "release" ] },
-        { "name": "arm64-windows-msvc+static-release" , "inherits": [ "base", "arm64-windows-msvc", "release", "static" ] }
+        { "name": "arm64-windows-msvc+static-release" , "inherits": [ "base", "arm64-windows-msvc", "release", "static" ] },
+
+        { "name": "x64-windows-msvc-debug"            , "inherits": [ "base", "debug"   ] },
+        { "name": "x64-windows-msvc-release"          , "inherits": [ "base", "release" ] },
+        { "name": "x64-windows-msvc+static-release"   , "inherits": [ "base", "release", "static" ] }
     ]
 }

Makefile

Lines changed: 3 additions & 0 deletions
@@ -441,6 +441,9 @@ endif # JETSON_EOL_MODULE_DETECT
 ifdef LLAMA_DEBUG
 	MK_NVCCFLAGS += -lineinfo
 endif # LLAMA_DEBUG
+ifdef LLAMA_CUDA_DEBUG
+	MK_NVCCFLAGS += --device-debug
+endif # LLAMA_CUDA_DEBUG
 ifdef LLAMA_CUDA_NVCC
 	NVCC = $(CCACHE) $(LLAMA_CUDA_NVCC)
 else

ggml-cuda.cu

Lines changed: 24 additions & 9 deletions
@@ -119,6 +119,20 @@ int ggml_cuda_get_device() {
     return id;
 }
 
+static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
+    ggml_cuda_set_device(device);
+#if defined(GGML_USE_HIPBLAS) && defined(GGML_HIP_UMA)
+    auto res = hipMallocManaged(ptr, size);
+    if (res == hipSuccess) {
+        // if error we "need" to know why...
+        CUDA_CHECK(hipMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device));
+    }
+    return res;
+#else
+    return cudaMalloc(ptr, size);
+#endif
+}
+
 static ggml_cuda_device_info ggml_cuda_init() {
 #ifdef __HIP_PLATFORM_AMD__
     // Workaround for a rocBLAS bug when using multiple graphics cards:
@@ -271,7 +285,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
             size_t look_ahead_size = (size_t) (1.05 * size);
             look_ahead_size = 256 * ((look_ahead_size + 255)/256);
             ggml_cuda_set_device(device);
-            CUDA_CHECK(cudaMalloc((void **) &ptr, look_ahead_size));
+            CUDA_CHECK(ggml_cuda_device_malloc(&ptr, look_ahead_size, device));
             *actual_size = look_ahead_size;
             pool_size += look_ahead_size;
 #ifdef DEBUG_CUDA_MALLOC
@@ -537,7 +551,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffe
     size = std::max(size, (size_t)1); // cudaMalloc returns null for size 0
 
     void * dev_ptr;
-    cudaError_t err = cudaMalloc(&dev_ptr, size);
+    cudaError_t err = ggml_cuda_device_malloc(&dev_ptr, size, buft_ctx->device);
     if (err != cudaSuccess) {
         // clear the error
         cudaGetLastError();
@@ -798,7 +812,7 @@ GGML_CALL static void ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_bu
     // currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first
     ggml_cuda_set_device(id);
     char * buf;
-    CUDA_CHECK(cudaMalloc(&buf, size));
+    CUDA_CHECK(ggml_cuda_device_malloc((void**)&buf, size, id));
 
     // set padding to 0 to avoid possible NaN values
     if (size > original_size) {
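
The new ggml_cuda_device_malloc helper above replaces the GGML_HIP_UMA macro override of cudaMalloc (removed from ggml-cuda/common.cuh below): on the unified-memory path it allocates managed memory and then relaxes coherence to coarse-grained for the owning device. The standalone HIP sketch below illustrates that allocation pattern; it is an illustrative example, not part of the commit, and the 1 MiB size and device id 0 are arbitrary.

// Illustrative sketch (not part of the commit) of the GGML_HIP_UMA allocation path:
// unified (managed) memory usable by both CPU and GPU, advised to coarse-grained
// coherence, which typically performs better for bulk tensor data on APUs.
#include <hip/hip_runtime.h>
#include <cstdio>

int main() {
    void * ptr  = nullptr;
    size_t size = 1u << 20; // 1 MiB, arbitrary for the example

    hipError_t err = hipMallocManaged(&ptr, size);
    if (err == hipSuccess) {
        // same advice the new helper applies after a successful allocation
        err = hipMemAdvise(ptr, size, hipMemAdviseSetCoarseGrain, /*device=*/0);
    }
    if (err != hipSuccess) {
        fprintf(stderr, "allocation failed: %s\n", hipGetErrorString(err));
        return 1;
    }
    hipFree(ptr);
    return 0;
}
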
@@ -2510,9 +2524,9 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
 
     bool use_cuda_graph = true;
     bool cuda_graph_update_required = false;
-    // pointer to CUDA cpy kernel, which is required to identify
+    // vector of pointers to CUDA cpy kernels, which are required to identify
     // kernel parameters which need updated in the graph for each token
-    void * ggml_cuda_cpy_fn_ptr = nullptr;
+    std::vector<void *> ggml_cuda_cpy_fn_ptrs;
 
     if (cuda_ctx->cuda_graph->graph == nullptr) {
         if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) {
@@ -2588,9 +2602,10 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
             if (node->op == GGML_OP_CPY) {
                 // store the copy op parameter which changes with each token.
                 cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
-                if (ggml_cuda_cpy_fn_ptr == nullptr) {
-                    // store a pointer to the copy op CUDA kernel to identify it later
-                    ggml_cuda_cpy_fn_ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
+                // store a pointer to each copy op CUDA kernel to identify it later
+                void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
+                if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
+                    ggml_cuda_cpy_fn_ptrs.push_back(ptr);
                 }
             }
 
@@ -2720,7 +2735,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
         if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured
             int k = 0;
             for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
-                if (cuda_ctx->cuda_graph->params[i].func == ggml_cuda_cpy_fn_ptr) {
+                if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
                     char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
                     cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
                     CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
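
Previously the graph-update path matched nodes against a single stored copy-kernel pointer; since one graph can now contain several CPY kernel variants, the code records every distinct kernel pointer during capture and tests membership when re-patching kernel parameters. A minimal C++ sketch of that collect-once / match-later pattern (illustrative only; the names below are placeholders, not ggml API):

// Illustrative sketch of the pattern used above (placeholder names, not ggml API).
#include <algorithm>
#include <vector>

// during graph capture: remember each distinct copy-kernel pointer once
static void remember_kernel(std::vector<void *> & ptrs, void * p) {
    if (std::find(ptrs.begin(), ptrs.end(), p) == ptrs.end()) {
        ptrs.push_back(p);
    }
}

// during graph update: does this graph node run one of the remembered kernels?
static bool is_remembered_kernel(const std::vector<void *> & ptrs, void * func) {
    return std::count(ptrs.begin(), ptrs.end(), func) > 0;
}
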

ggml-cuda/common.cuh

Lines changed: 0 additions & 5 deletions
@@ -79,13 +79,8 @@
 #define cudaHostRegisterReadOnly hipHostRegisterReadOnly
 #define cudaHostUnregister hipHostUnregister
 #define cudaLaunchHostFunc hipLaunchHostFunc
-#ifdef GGML_HIP_UMA
-#define cudaMalloc hipMallocManaged
-#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size)
-#else
 #define cudaMalloc hipMalloc
 #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
-#endif
 #define cudaMemcpy hipMemcpy
 #define cudaMemcpyAsync hipMemcpyAsync
 #define cudaMemcpyPeerAsync hipMemcpyPeerAsync

ggml-metal.m

Lines changed: 58 additions & 10 deletions
@@ -35,6 +35,10 @@
     GGML_METAL_KERNEL_TYPE_MUL_ROW,
     GGML_METAL_KERNEL_TYPE_DIV,
     GGML_METAL_KERNEL_TYPE_DIV_ROW,
+    GGML_METAL_KERNEL_TYPE_REPEAT_F32,
+    GGML_METAL_KERNEL_TYPE_REPEAT_F16,
+    GGML_METAL_KERNEL_TYPE_REPEAT_I32,
+    GGML_METAL_KERNEL_TYPE_REPEAT_I16,
     GGML_METAL_KERNEL_TYPE_SCALE,
     GGML_METAL_KERNEL_TYPE_SCALE_4,
     GGML_METAL_KERNEL_TYPE_CLAMP,
@@ -184,9 +188,9 @@
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
+  //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
+  //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
     GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
@@ -485,6 +489,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
@@ -634,9 +642,9 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
+      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
+      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
@@ -746,6 +754,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
         case GGML_OP_ACC:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
+        case GGML_OP_REPEAT:
         case GGML_OP_SCALE:
         case GGML_OP_CLAMP:
         case GGML_OP_SQR:
@@ -770,6 +779,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
         case GGML_OP_LEAKY_RELU:
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
+            if (op->src[0]->ne[0] == 256) {
+                return false;
+            }
             return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
@@ -976,8 +988,6 @@ static enum ggml_status ggml_metal_graph_compute(
         switch (dst->op) {
             case GGML_OP_CONCAT:
                 {
-                    const int64_t nb = ne00;
-
                     id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
 
                     [encoder setComputePipelineState:pipeline];
@@ -1008,7 +1018,6 @@ static enum ggml_status ggml_metal_graph_compute(
                    [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
                    [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
                    [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
-                    [encoder setBytes:&nb  length:sizeof(nb)  atIndex:27];
 
                    const int nth = MIN(1024, ne0);
 
@@ -1018,11 +1027,14 @@ static enum ggml_status ggml_metal_graph_compute(
            case GGML_OP_MUL:
            case GGML_OP_DIV:
                {
+                    GGML_ASSERT(src0t == GGML_TYPE_F32);
+                    GGML_ASSERT(src1t == GGML_TYPE_F32);
+
                    const size_t offs = 0;
 
                    bool bcast_row = false;
 
-                    int64_t nb = ne00;
+                    int64_t nb = ne00; // used by the "row" kernels
 
                    id<MTLComputePipelineState> pipeline = nil;
 
@@ -1091,6 +1103,42 @@ static enum ggml_status ggml_metal_graph_compute(
                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                    }
                } break;
+            case GGML_OP_REPEAT:
+                {
+                    id<MTLComputePipelineState> pipeline;
+
+                    switch (src0t) {
+                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break;
+                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break;
+                        case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break;
+                        case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break;
+                        default: GGML_ASSERT(false);
+                    }
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+                    [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+                    [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                    [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+                    [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+                    [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+                    [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+                    [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+                    [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:10];
+                    [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:11];
+                    [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:12];
+                    [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:13];
+                    [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:14];
+                    [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:15];
+                    [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:16];
+                    [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:17];
+
+                    const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                } break;
            case GGML_OP_ACC:
                {
                    GGML_ASSERT(src0t == GGML_TYPE_F32);
@@ -2573,7 +2621,7 @@ static enum ggml_status ggml_metal_graph_compute(
                        case  96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
                        case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
                        case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
-                        case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
+                      //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
                        default:
                            {
                                GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
@@ -2586,7 +2634,7 @@ static enum ggml_status ggml_metal_graph_compute(
 
                    switch (ne00) {
                        case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
-                        case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
+                      //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
                        default:
                            {
                                GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
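
With the repeat kernels registered and GGML_OP_REPEAT accepted by ggml_metal_supports_op, graphs that contain broadcast/repeat nodes can now run on the Metal backend instead of falling back to the CPU. Below is a minimal sketch of creating such a node through the public ggml API (illustrative only, not part of the commit; sizes are arbitrary and backend/graph execution is omitted):

// Sketch: build a GGML_OP_REPEAT node that tiles `a` to the shape of `b`.
#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 1); // 4 x 1
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 8); // 4 x 8
    struct ggml_tensor * r = ggml_repeat(ctx, a, b); // r takes b's shape, tiled from a
    (void) r;

    ggml_free(ctx);
    return 0;
}
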

ggml-metal.metal

Lines changed: 49 additions & 2 deletions
@@ -168,6 +168,53 @@ kernel void kernel_div(
     }
 }
 
+template<typename T>
+kernel void kernel_repeat(
+        device const char * src0,
+        device       char * dst,
+        constant  int64_t & ne00,
+        constant  int64_t & ne01,
+        constant  int64_t & ne02,
+        constant  int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant  int64_t & ne0,
+        constant  int64_t & ne1,
+        constant  int64_t & ne2,
+        constant  int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    const int64_t i03 = i3 % ne03;
+    const int64_t i02 = i2 % ne02;
+    const int64_t i01 = i1 % ne01;
+
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    device       char * dst_ptr  = dst  +  i3*nb3  +  i2*nb2  +  i1*nb1;
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        const int i00 = i0 % ne00;
+        *((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00));
+    }
+}
+
+typedef decltype(kernel_repeat<float>) kernel_repeat_t;
+
+template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
+template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
+template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
+template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
+
 // assumption: src1 is a row
 // broadcast src1 into src0
 kernel void kernel_add_row(
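
The kernel above indexes the destination over its full (ne0, ne1, ne2, ne3) extent and wraps each coordinate back into the source with a modulo, so source dimensions of size 1 (or any divisor of the destination extent) are tiled across the output. The host-side C++ loop below is a reference sketch of the same index math (illustrative, not part of the commit), written for a float tensor with the usual ggml byte strides nb*:

// Reference of kernel_repeat's indexing on the host (illustrative sketch):
// every destination element (i0,i1,i2,i3) reads source element (i0 % ne00, i1 % ne01, ...).
#include <cstdint>

static void repeat_ref(const char * src0, char * dst,
                       int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
                       uint64_t nb00, uint64_t nb01, uint64_t nb02, uint64_t nb03,
                       int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3,
                       uint64_t nb0, uint64_t nb1, uint64_t nb2, uint64_t nb3) {
    for (int64_t i3 = 0; i3 < ne3; ++i3) {
        for (int64_t i2 = 0; i2 < ne2; ++i2) {
            for (int64_t i1 = 0; i1 < ne1; ++i1) {
                const char * src_row = src0 + (i3 % ne03)*nb03 + (i2 % ne02)*nb02 + (i1 % ne01)*nb01;
                char       * dst_row = dst  +  i3*nb3          +  i2*nb2          +  i1*nb1;
                for (int64_t i0 = 0; i0 < ne0; ++i0) {
                    // same byte-strided copy as the Metal kernel, specialized here to float
                    *(float *)(dst_row + i0*nb0) = *(const float *)(src_row + (i0 % ne00)*nb00);
                }
            }
        }
    }
}
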
@@ -2418,7 +2465,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f
 template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>;
 template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>;
 template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>;
-template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
+//template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
 
 template<int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
 kernel void kernel_flash_attn_ext_vec_f16(
@@ -2696,7 +2743,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
 }
 
 template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
-template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
+//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
 
 kernel void kernel_cpy_f16_f16(
         device const half * src0,