File tree Expand file tree Collapse file tree 1 file changed +7
-7
lines changed Expand file tree Collapse file tree 1 file changed +7
-7
lines changed Original file line number Diff line number Diff line change @@ -2631,11 +2631,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
2631
2631
const short iv3 = iq3 / rv3;
2632
2632
2633
2633
// load the queries from shared memory into local memory
2634
- half4 mq[D4];
2634
+ float4 mq[D4];
2635
2635
2636
2636
for (short ii = 0 ; ii < D4; ii += NW) {
2637
2637
short i = ii + tiisg;
2638
- mq[i] = sq4[i];
2638
+ mq[i] = (float4) sq4[i];
2639
2639
}
2640
2640
2641
2641
// pointer to the mask
@@ -2661,11 +2661,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
2661
2661
for (short ii = 0 ; ii < D4; ii += NW) {
2662
2662
const short i = ii + tiisg;
2663
2663
2664
- half4x4 mk;
2665
- mk[0 ] = pk4[i + 0 *(nb11/8 )];
2666
- mk[1 ] = pk4[i + 1 *(nb11/8 )];
2667
- mk[2 ] = pk4[i + 2 *(nb11/8 )];
2668
- mk[3 ] = pk4[i + 3 *(nb11/8 )];
2664
+ float4x4 mk;
2665
+ mk[0 ] = (float4) pk4[i + 0 *(nb11/8 )];
2666
+ mk[1 ] = (float4) pk4[i + 1 *(nb11/8 )];
2667
+ mk[2 ] = (float4) pk4[i + 2 *(nb11/8 )];
2668
+ mk[3 ] = (float4) pk4[i + 3 *(nb11/8 )];
2669
2669
2670
2670
mqk += (float4) (mq[i] * mk);
2671
2671
}
You can’t perform that action at this time.
0 commit comments