@@ -1007,17 +1007,18 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
1007
1007
case GGML_OP_ADD:
1008
1008
case GGML_OP_SCALE:
1009
1009
case GGML_OP_MUL:
1010
- return true ;
1010
+ return op-> src [ 0 ]-> type == GGML_TYPE_F32 ;
1011
1011
case GGML_OP_UNARY:
1012
1012
switch (ggml_get_unary_op (op)) {
1013
1013
case GGML_UNARY_OP_GELU:
1014
1014
case GGML_UNARY_OP_SILU:
1015
1015
case GGML_UNARY_OP_RELU:
1016
- return ggml_is_contiguous (op->src [0 ]);
1016
+ return ggml_is_contiguous (op->src [0 ]) && op-> src [ 0 ]-> type == GGML_TYPE_F32 ;
1017
1017
default :
1018
1018
return false ;
1019
1019
}
1020
1020
case GGML_OP_CLAMP:
1021
+ return op->src [0 ]->type == GGML_TYPE_F32;
1021
1022
case GGML_OP_SOFT_MAX:
1022
1023
case GGML_OP_NORM:
1023
1024
case GGML_OP_RMS_NORM:
@@ -2573,26 +2574,33 @@ static void ggml_cl_norm(ggml_backend_t backend, const ggml_tensor * src0, const
2573
2574
memcpy (&eps, dst->op_params , sizeof (float ));
2574
2575
2575
2576
const int ne00 = src0 ? src0->ne [0 ] : 0 ;
2576
- const cl_ulong nb01 = src0 ? src0->nb [1 ] : 0 ;
2577
+ const int ne01 = src0 ? src0->ne [1 ] : 0 ;
2578
+ const int ne02 = src0 ? src0->ne [2 ] : 0 ;
2579
+ const int ne03 = src0 ? src0->ne [3 ] : 0 ;
2577
2580
2578
- GGML_ASSERT (ggml_is_contiguous_1 (src0));
2581
+ const cl_ulong nb01 = src0 ? src0->nb [1 ] : 0 ;
2582
+ const cl_ulong nb02 = src0 ? src0->nb [2 ] : 0 ;
2583
+ const cl_ulong nb03 = src0 ? src0->nb [3 ] : 0 ;
2579
2584
2580
2585
const int nth = MIN (64 , ne00);
2581
2586
2582
2587
cl_kernel kernel = backend_ctx->kernel_norm ;
2583
2588
2584
- CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
2585
- CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &offset0));
2586
- CL_CHECK (clSetKernelArg (kernel, 2 , sizeof (cl_mem), &extrad->data_device ));
2587
- CL_CHECK (clSetKernelArg (kernel, 3 , sizeof (cl_ulong), &offsetd));
2588
- CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (int ), &ne00));
2589
- CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (cl_ulong), &nb01));
2590
- CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (float ), &eps));
2591
- CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (float )*nth, NULL ));
2592
-
2593
- const int64_t nrows = ggml_nrows (src0);
2589
+ CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
2590
+ CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &offset0));
2591
+ CL_CHECK (clSetKernelArg (kernel, 2 , sizeof (cl_mem), &extrad->data_device ));
2592
+ CL_CHECK (clSetKernelArg (kernel, 3 , sizeof (cl_ulong), &offsetd));
2593
+ CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (int ), &ne00));
2594
+ CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (int ), &ne01));
2595
+ CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &ne02));
2596
+ CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (int ), &ne03));
2597
+ CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (cl_ulong), &nb01));
2598
+ CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (cl_ulong), &nb02));
2599
+ CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (cl_ulong), &nb03));
2600
+ CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (float ), &eps));
2601
+ CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (float )*nth, NULL ));
2594
2602
2595
- size_t global_work_size[] = {(size_t )nrows *nth, 1 , 1 };
2603
+ size_t global_work_size[] = {(size_t )ne01 *nth, ( size_t )ne02, ( size_t )ne03 };
2596
2604
size_t local_work_size[] = {(size_t )nth, 1 , 1 };
2597
2605
2598
2606
#ifdef GGML_OPENCL_PROFILING
@@ -2630,16 +2638,19 @@ static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, c
2630
2638
memcpy (&eps, dst->op_params , sizeof (float ));
2631
2639
2632
2640
const int ne00 = src0 ? src0->ne [0 ] : 0 ;
2641
+ const int ne01 = src0 ? src0->ne [1 ] : 0 ;
2642
+ const int ne02 = src0 ? src0->ne [2 ] : 0 ;
2643
+ const int ne03 = src0 ? src0->ne [3 ] : 0 ;
2644
+
2633
2645
const cl_ulong nb01 = src0 ? src0->nb [1 ] : 0 ;
2646
+ const cl_ulong nb02 = src0 ? src0->nb [2 ] : 0 ;
2647
+ const cl_ulong nb03 = src0 ? src0->nb [3 ] : 0 ;
2634
2648
2635
2649
GGML_ASSERT (ne00 % 4 == 0 );
2636
- GGML_ASSERT (ggml_is_contiguous_1 (src0));
2637
2650
2638
2651
const int nth = MIN (64 , ne00);
2639
2652
2640
- const int64_t nrows = ggml_nrows (src0);
2641
-
2642
- size_t global_work_size[] = {(size_t )nrows*nth, 1 , 1 };
2653
+ size_t global_work_size[] = {(size_t )ne01*nth, (size_t )ne02, (size_t )ne03};
2643
2654
size_t local_work_size[] = {(size_t )nth, 1 , 1 };
2644
2655
2645
2656
cl_kernel kernel = backend_ctx->kernel_rms_norm ;
@@ -2654,15 +2665,20 @@ static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, c
2654
2665
sizeof (local_work_size), local_work_size,
2655
2666
sizeof (size_t ), &sgs, NULL ));
2656
2667
2657
- CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
2658
- CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &offset0));
2659
- CL_CHECK (clSetKernelArg (kernel, 2 , sizeof (cl_mem), &extrad->data_device ));
2660
- CL_CHECK (clSetKernelArg (kernel, 3 , sizeof (cl_ulong), &offsetd));
2661
- CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (int ), &ne00));
2662
- CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (cl_ulong), &nb01));
2663
- CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (float ), &eps));
2668
+ CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
2669
+ CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &offset0));
2670
+ CL_CHECK (clSetKernelArg (kernel, 2 , sizeof (cl_mem), &extrad->data_device ));
2671
+ CL_CHECK (clSetKernelArg (kernel, 3 , sizeof (cl_ulong), &offsetd));
2672
+ CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (int ), &ne00));
2673
+ CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (int ), &ne01));
2674
+ CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &ne02));
2675
+ CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (int ), &ne03));
2676
+ CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (cl_ulong), &nb01));
2677
+ CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (cl_ulong), &nb02));
2678
+ CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (cl_ulong), &nb03));
2679
+ CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (float ), &eps));
2664
2680
// This is local memory - the size depends on subgroup size.
2665
- CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (float )*nth/sgs, NULL ));
2681
+ CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (float )*nth/sgs, NULL ));
2666
2682
2667
2683
#ifdef GGML_OPENCL_PROFILING
2668
2684
cl_event evt;
0 commit comments