@@ -802,15 +802,6 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
802
802
if (op->src [0 ]->ne [0 ] == 256 ) {
803
803
return false ;
804
804
}
805
- {
806
- float logit_softcap;
807
-
808
- memcpy (&logit_softcap, ((const float *) op->op_params ) + 2 , sizeof (logit_softcap));
809
-
810
- if (logit_softcap != 0 .0f ) {
811
- return false ;
812
- }
813
- }
814
805
return ctx->support_simdgroup_mm ; // TODO: over-restricted for vec-kernels
815
806
case GGML_OP_MUL_MAT:
816
807
case GGML_OP_MUL_MAT_ID:
@@ -2633,9 +2624,14 @@ static enum ggml_status ggml_metal_graph_compute(
2633
2624
2634
2625
float scale;
2635
2626
float max_bias;
2627
+ float logit_softcap;
2628
+ memcpy (&scale, ((int32_t *) dst->op_params ) + 0 , sizeof (scale));
2629
+ memcpy (&max_bias, ((int32_t *) dst->op_params ) + 1 , sizeof (max_bias));
2630
+ memcpy (&logit_softcap, ((int32_t *) dst->op_params ) + 2 , sizeof (logit_softcap));
2636
2631
2637
- memcpy (&scale, ((int32_t *) dst->op_params ) + 0 , sizeof (scale));
2638
- memcpy (&max_bias, ((int32_t *) dst->op_params ) + 1 , sizeof (max_bias));
2632
+ if (logit_softcap != 0 .0f ) {
2633
+ scale /= logit_softcap;
2634
+ }
2639
2635
2640
2636
const uint32_t n_head = src0->ne [2 ];
2641
2637
const uint32_t n_head_log2 = 1u << (uint32_t ) floorf (log2f ((float ) n_head));
@@ -2686,30 +2682,31 @@ static enum ggml_status ggml_metal_graph_compute(
2686
2682
} else {
2687
2683
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 3 ];
2688
2684
}
2689
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 4 ];
2690
- [encoder setBytes: &ne01 length: sizeof ( int64_t ) atIndex: 5 ];
2691
- [encoder setBytes: &ne02 length: sizeof ( int64_t ) atIndex: 6 ];
2692
- [encoder setBytes: &ne03 length: sizeof ( int64_t ) atIndex: 7 ];
2693
- [encoder setBytes: &nb01 length: sizeof (uint64_t ) atIndex: 8 ];
2694
- [encoder setBytes: &nb02 length: sizeof (uint64_t ) atIndex: 9 ];
2695
- [encoder setBytes: &nb03 length: sizeof (uint64_t ) atIndex: 10 ];
2696
- [encoder setBytes: &ne11 length: sizeof ( int64_t ) atIndex: 11 ];
2697
- [encoder setBytes: &ne12 length: sizeof ( int64_t ) atIndex: 12 ];
2698
- [encoder setBytes: &ne13 length: sizeof ( int64_t ) atIndex: 13 ];
2699
- [encoder setBytes: &nb11 length: sizeof (uint64_t ) atIndex: 14 ];
2700
- [encoder setBytes: &nb12 length: sizeof (uint64_t ) atIndex: 15 ];
2701
- [encoder setBytes: &nb13 length: sizeof (uint64_t ) atIndex: 16 ];
2702
- [encoder setBytes: &nb21 length: sizeof (uint64_t ) atIndex: 17 ];
2703
- [encoder setBytes: &nb22 length: sizeof (uint64_t ) atIndex: 18 ];
2704
- [encoder setBytes: &nb23 length: sizeof (uint64_t ) atIndex: 19 ];
2705
- [encoder setBytes: &nb31 length: sizeof (uint64_t ) atIndex: 20 ];
2706
- [encoder setBytes: &ne1 length: sizeof ( int64_t ) atIndex: 21 ];
2707
- [encoder setBytes: &ne2 length: sizeof ( int64_t ) atIndex: 22 ];
2708
- [encoder setBytes: &scale length: sizeof ( float ) atIndex: 23 ];
2709
- [encoder setBytes: &max_bias length: sizeof ( float ) atIndex: 24 ];
2710
- [encoder setBytes: &m0 length: sizeof (m0) atIndex: 25 ];
2711
- [encoder setBytes: &m1 length: sizeof (m1) atIndex: 26 ];
2712
- [encoder setBytes: &n_head_log2 length: sizeof (n_head_log2) atIndex: 27 ];
2685
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 4 ];
2686
+ [encoder setBytes: &ne01 length: sizeof ( int64_t ) atIndex: 5 ];
2687
+ [encoder setBytes: &ne02 length: sizeof ( int64_t ) atIndex: 6 ];
2688
+ [encoder setBytes: &ne03 length: sizeof ( int64_t ) atIndex: 7 ];
2689
+ [encoder setBytes: &nb01 length: sizeof (uint64_t ) atIndex: 8 ];
2690
+ [encoder setBytes: &nb02 length: sizeof (uint64_t ) atIndex: 9 ];
2691
+ [encoder setBytes: &nb03 length: sizeof (uint64_t ) atIndex: 10 ];
2692
+ [encoder setBytes: &ne11 length: sizeof ( int64_t ) atIndex: 11 ];
2693
+ [encoder setBytes: &ne12 length: sizeof ( int64_t ) atIndex: 12 ];
2694
+ [encoder setBytes: &ne13 length: sizeof ( int64_t ) atIndex: 13 ];
2695
+ [encoder setBytes: &nb11 length: sizeof (uint64_t ) atIndex: 14 ];
2696
+ [encoder setBytes: &nb12 length: sizeof (uint64_t ) atIndex: 15 ];
2697
+ [encoder setBytes: &nb13 length: sizeof (uint64_t ) atIndex: 16 ];
2698
+ [encoder setBytes: &nb21 length: sizeof (uint64_t ) atIndex: 17 ];
2699
+ [encoder setBytes: &nb22 length: sizeof (uint64_t ) atIndex: 18 ];
2700
+ [encoder setBytes: &nb23 length: sizeof (uint64_t ) atIndex: 19 ];
2701
+ [encoder setBytes: &nb31 length: sizeof (uint64_t ) atIndex: 20 ];
2702
+ [encoder setBytes: &ne1 length: sizeof ( int64_t ) atIndex: 21 ];
2703
+ [encoder setBytes: &ne2 length: sizeof ( int64_t ) atIndex: 22 ];
2704
+ [encoder setBytes: &scale length: sizeof ( float ) atIndex: 23 ];
2705
+ [encoder setBytes: &max_bias length: sizeof ( float ) atIndex: 24 ];
2706
+ [encoder setBytes: &m0 length: sizeof (m0) atIndex: 25 ];
2707
+ [encoder setBytes: &m1 length: sizeof (m1) atIndex: 26 ];
2708
+ [encoder setBytes: &n_head_log2 length: sizeof (n_head_log2) atIndex: 27 ];
2709
+ [encoder setBytes: &logit_softcap length: sizeof (logit_softcap) atIndex: 28 ];
2713
2710
2714
2711
if (!use_vec_kernel) {
2715
2712
// half8x8 kernel
0 commit comments