@@ -2099,19 +2099,22 @@ void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const v
     int32_t * C_cur = TileC0;
     int32_t * C_pre = TileC1;
 
-#define Tile4(base) base
-#define Tile5(base) base + TILE_M * TILE_N
-#define Tile6(base) base + 2 * TILE_M * TILE_N
-#define Tile7(base) base + 3 * TILE_M * TILE_N
+    auto Tile4 = [&](int32_t * base) { return base; };
+    auto Tile5 = [&](int32_t * base) { return base + TILE_M * TILE_N; };
+    auto Tile6 = [&](int32_t * base) { return base + 2 * TILE_M * TILE_N; };
+    auto Tile7 = [&](int32_t * base) { return base + 3 * TILE_M * TILE_N; };
 
     if (M == 2 * TILE_M) {
         // i = 0
+        const char * B_blk0 = B + PACKED_INDEX(0, 0, KB, TILE_SIZE);
+        const char * B_blk1 = B + PACKED_INDEX(1, 0, KB, TILE_SIZE);
         if (need_unpack) {
-            unpack_B<TB>(Tile0, B + PACKED_INDEX(0, 0, KB, TILE_SIZE));
+            unpack_B<TB>(Tile0, B_blk0);
             _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK);
         } else {
-            _tile_loadd(TMM0, B + PACKED_INDEX(0, 0, KB, TILE_SIZE), TILE_N * VNNI_BLK);
+            _tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK);
         }
+
         _tile_zero(TMM4);
         _tile_loadd(TMM2, A[0].qs, lda);
         _tile_dpbssd(TMM4, TMM2, TMM0);
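Aside: the #define-to-lambda change above is more than cosmetic. The old function-style macros expand without parentheses around the result, so an operator applied to the "call" rebinds to the last token of the expansion, while the lambdas return an ordinary int32_t * that behaves like any other expression. A standalone C++ sketch of the hazard, not part of the patch; the TILE_M/TILE_N values are made up for illustration:

#include <cassert>
#include <cstdint>

constexpr int TILE_M = 16;
constexpr int TILE_N = 16;

// Old style: the replacement text is not parenthesized.
#define Tile6_macro(base) base + 2 * TILE_M * TILE_N

int main() {
    static int32_t scratch[4 * TILE_M * TILE_N] = {};

    // New style from the diff: a scoped, type-checked helper.
    auto Tile6 = [&](int32_t * base) { return base + 2 * TILE_M * TILE_N; };

    Tile6(scratch)[0] = 42;                          // fine: plain pointer arithmetic
    assert(Tile6(scratch) == Tile6_macro(scratch));  // same address when used bare

    // The macro cannot be subscripted the same way:
    //   Tile6_macro(scratch)[0]  expands to  scratch + 2 * TILE_M * TILE_N[0]
    // and fails to compile, because [0] binds to TILE_N, not to the whole sum.
    return 0;
}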
@@ -2123,11 +2126,12 @@ void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const v
         _tile_stored(TMM5, Tile5(C_pre), TILE_N * sizeof(int32_t));
 
         if (need_unpack) {
-            unpack_B<TB>(Tile1, B + PACKED_INDEX(1, 0, KB, TILE_SIZE));
+            unpack_B<TB>(Tile1, B_blk1);
             _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK);
         } else {
-            _tile_loadd(TMM1, B + PACKED_INDEX(1, 0, KB, TILE_SIZE), TILE_N * VNNI_BLK);
+            _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK);
         }
+
         _tile_zero(TMM6);
         _tile_dpbssd(TMM6, TMM2, TMM1);
         _tile_stored(TMM6, Tile6(C_pre), TILE_N * sizeof(int32_t));
@@ -2139,12 +2143,14 @@ void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const v
         for (int i = 1; i < KB; ++i) {
             // index of previous iter
             const int ii = i - 1;
+            const char * B_blk0 = B + PACKED_INDEX(0, i, KB, TILE_SIZE);
+            const char * B_blk1 = B + PACKED_INDEX(1, i, KB, TILE_SIZE);
             GGML_DISPATCH_BOOL(ii > 0, is_acc, [&] {
                 if (need_unpack) {
-                    unpack_B<TB>(Tile0, B + PACKED_INDEX(0, i, KB, TILE_SIZE));
+                    unpack_B<TB>(Tile0, B_blk0);
                     _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK);
                 } else {
-                    _tile_loadd(TMM0, B + PACKED_INDEX(0, i, KB, TILE_SIZE), TILE_N * VNNI_BLK);
+                    _tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK);
                 }
                 _tile_zero(TMM4);
                 _tile_loadd(TMM2, A[i].qs, lda);
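Note: the B_blk0/B_blk1 pointers introduced here (and in the other hunks) simply hoist the repeated PACKED_INDEX address computation out of the need_unpack branches, so the unpack path and the direct tile-load path read from the same, computed-once block pointer. A self-contained C++ sketch of that pattern; the (n * KB + k) * tile_size layout in packed_index below is only an assumption for illustration, not the real PACKED_INDEX definition:

#include <cstddef>
#include <cstdio>
#include <vector>

// Assumed block layout, for illustration only (not the real PACKED_INDEX).
static inline size_t packed_index(int n, int k, int KB, size_t tile_size) {
    return (static_cast<size_t>(n) * KB + k) * tile_size;
}

int main() {
    const int KB = 4;
    const size_t tile_size = 64;
    std::vector<char> B(2 * KB * tile_size, 0);

    for (int i = 0; i < KB; ++i) {
        // Computed once per iteration, then reused by every consumer below,
        // mirroring the B_blk0/B_blk1 pointers in the patch.
        const char * B_blk0 = B.data() + packed_index(0, i, KB, tile_size);
        const char * B_blk1 = B.data() + packed_index(1, i, KB, tile_size);
        std::printf("i=%d  blk0 offset=%zu  blk1 offset=%zu\n",
                    i, static_cast<size_t>(B_blk0 - B.data()),
                    static_cast<size_t>(B_blk1 - B.data()));
    }
    return 0;
}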
@@ -2161,10 +2167,10 @@ void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const v
                 _tile_stored(TMM5, Tile5(C_cur), TILE_N * sizeof(int32_t));
 
                 if (need_unpack) {
-                    unpack_B<TB>(Tile1, B + PACKED_INDEX(1, i, KB, TILE_SIZE));
+                    unpack_B<TB>(Tile1, B_blk1);
                     _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK);
                 } else {
-                    _tile_loadd(TMM1, B + PACKED_INDEX(1, i, KB, TILE_SIZE), TILE_N * VNNI_BLK);
+                    _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK);
                 }
                 _tile_zero(TMM6);
                 acc_C<TA, TB, is_acc>::apply(C + TILE_N, ldc, Tile6(C_pre), &A[ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M);
@@ -2198,18 +2204,20 @@ void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const v
             _tile_zero(TMM7);
         }
 
+        const char * B_blk0 = B + PACKED_INDEX(0, i, KB, TILE_SIZE);
+        const char * B_blk1 = B + PACKED_INDEX(1, i, KB, TILE_SIZE);
         if (need_unpack) {
-            unpack_B<TB>(Tile0, B + PACKED_INDEX(0, i, KB, TILE_SIZE));
+            unpack_B<TB>(Tile0, B_blk0);
             _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK);
         } else {
-            _tile_loadd(TMM0, B + PACKED_INDEX(0, i, KB, TILE_SIZE), TILE_N * VNNI_BLK);
+            _tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK);
         }
 
         if (need_unpack) {
-            unpack_B<TB>(Tile1, B + PACKED_INDEX(1, i, KB, TILE_SIZE));
+            unpack_B<TB>(Tile1, B_blk1);
             _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK);
         } else {
-            _tile_loadd(TMM1, B + PACKED_INDEX(1, i, KB, TILE_SIZE), TILE_N * VNNI_BLK);
+            _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK);
         }
 
         if (m0 == TILE_M) {
@@ -2364,6 +2372,8 @@ bool ggml_compute_forward_mul_mat_use_amx(struct ggml_tensor * dst) {
     const enum ggml_type type = src0->type;
     const int64_t ne0 = dst->ne[0];
 
+    bool is_training = src0->grad || src1->grad;
+
     // amx kernels enables for Q4_0, Q4_1, Q8_0, F16
     // Q4_K, Q5_K, Q6_K, IQ4_XS enabled for QK_K = 256
     bool has_amx_kernels = (type == GGML_TYPE_Q4_0) ||
@@ -2386,6 +2396,7 @@ bool ggml_compute_forward_mul_mat_use_amx(struct ggml_tensor * dst) {
     return dst->op != GGML_OP_MUL_MAT_ID &&
         is_contiguous_2d(src0) &&
         is_contiguous_2d(src1) &&
+        !is_training &&
         src1->type == GGML_TYPE_F32 &&
         has_amx_kernels &&
         // out features is 32x
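The new is_training flag gates ggml_compute_forward_mul_mat_use_amx on inference-only use: if either src0 or src1 carries a gradient, the graph presumably needs a backward pass and the AMX fast path is skipped. A minimal stand-in sketch of that check; tensor_stub is a hypothetical type used here only so the example compiles on its own, the real field is ggml_tensor::grad as used in the diff:

// Hypothetical stand-in type; only the grad pointer matters for the check.
struct tensor_stub {
    tensor_stub * grad = nullptr;   // non-null when gradients are required
};

static bool can_use_inference_only_path(const tensor_stub * src0, const tensor_stub * src1) {
    const bool is_training = src0->grad || src1->grad;
    return !is_training;            // fast path covers the forward matmul only
}

int main() {
    tensor_stub a, b, g;
    b.grad = &g;                                        // b participates in training
    return can_use_inference_only_path(&a, &b) ? 0 : 1; // exits 1: must fall back
}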