@@ -20,6 +20,20 @@ struct mma_int_A_I16K4 {
         GGML_CUDA_ASSUME(ret < K);
         return ret;
     }
+
+    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE)
+        const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
+        asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
+            : "+r"(x[0]), "+r"(x[1])
+            : "l"(xs));
+#else
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_i(l)*stride + get_k(l)];
+        }
+#endif // defined(INT8_MMA_AVAILABLE)
+    }
 };
 
 struct mma_int_A_I16K8 {
@@ -42,6 +56,20 @@ struct mma_int_A_I16K8 {
         GGML_CUDA_ASSUME(ret < K);
         return ret;
     }
+
+    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE)
+        const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
+        asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
+            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
+            : "l"(xs));
+#else
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_i(l)*stride + get_k(l)];
+        }
+#endif // defined(INT8_MMA_AVAILABLE)
+    }
 };
 
 struct mma_int_B_J8K4 {
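
On the INT8_MMA path, the A fragments are filled with a single ldmatrix instruction: every lane of the warp derives the address of one row segment from its lane id via (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2), and the hardware then scatters the loaded tile into each lane's x[] registers. The host-side sketch below only evaluates that address expression for all 32 lanes, assuming I = 16 and K = 8 as suggested by the struct name mma_int_A_I16K8 and an illustrative stride; it is not code from the PR.

// Host-side sketch (not part of the PR): evaluate the per-lane address arithmetic used
// by the ldmatrix path above, assuming I = 16 and K = 8 from the struct name
// mma_int_A_I16K8 and an illustrative stride of 64 ints per row. Each lane passes one
// such address to ldmatrix, which redistributes the tile across the warp's x[] registers.
#include <cstdio>

int main() {
    const int I = 16, K = 8, stride = 64;     // stride in ints, example value only
    for (int lane = 0; lane < 32; ++lane) {
        const int row    = lane % I;
        const int col    = (lane / I)*(K/2);  // int column within the 16x8-int tile
        const int offset = row*stride + col;  // same expression as in the kernel code
        printf("lane %2d -> row %2d, int col %d, offset %4d\n", lane, row, col, offset);
    }
    return 0;
}
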
@@ -64,6 +92,20 @@ struct mma_int_B_J8K4 {
         GGML_CUDA_ASSUME(ret < K);
         return ret;
     }
+
+    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
+        const int * xs = xs0 + (threadIdx.x%J)*stride;
+        asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];"
+            : "+r"(x[0])
+            : "l"(xs));
+#else
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_j(l)*stride + get_k(l)];
+        }
+#endif // defined(INT8_MMA_AVAILABLE)
+    }
 };
 
 struct mma_int_B_J8K8 {
@@ -86,6 +128,20 @@ struct mma_int_B_J8K8 {
         GGML_CUDA_ASSUME(ret < K);
         return ret;
     }
+
+    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
+        const int * xs = xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K;
+        asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
+            : "+r"(x[0]), "+r"(x[1])
+            : "l"(xs));
+#else
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_j(l)*stride + get_k(l)];
+        }
+#endif // defined(INT8_MMA_AVAILABLE)
+    }
 };
 
 struct mma_int_C_I16J8 {
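
For both B fragments the ldmatrix branch is compiled out (&& false, with the note that loading as 4-byte values is faster), so the portable #else loop is what actually runs: each lane copies its ne ints using the struct's get_j(l)/get_k(l) helpers. The standalone sketch below mirrors the shape of that loop on the host, assuming J = 8, K = 8, ne = 2 from the struct name mma_int_B_J8K8 and a hypothetical row-major lane mapping that need not match the real get_j/get_k; it only checks that 32 lanes times 2 ints tile the 8x8 block exactly once.

// Host-side sketch (not from the PR) of the #else fallback that actually runs for the
// B fragments. J = 8, K = 8, ne = 2 follow from the struct name mma_int_B_J8K8; the
// (j, k) mapping below is purely illustrative and may differ from get_j()/get_k() in
// the real header. In the kernel the body would be x[l] = xs0[j*stride + k].
#include <cstdio>

int main() {
    const int J = 8, K = 8, ne = 2;
    int covered[8][8] = {};
    for (int lane = 0; lane < 32; ++lane) {
        for (int l = 0; l < ne; ++l) {
            const int j = lane / (K/2);          // hypothetical row for this lane
            const int k = (lane % (K/2))*ne + l; // hypothetical int column for this lane
            ++covered[j][k];
        }
    }
    bool ok = true;
    for (int j = 0; j < J; ++j)
        for (int k = 0; k < K; ++k)
            ok = ok && covered[j][k] == 1;
    printf("8x8 int tile covered exactly once: %s\n", ok ? "yes" : "no");
    return 0;
}
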