Commit 3183bcd

Fix mkldnn_matmul error on AArch64 (pytorch#114851)
Fixes pytorch#110149. Cherry-pick of pytorch#110150. This is a bug fix against the 2.1 release.
1 parent b5a89bb · commit 3183bcd

1 file changed (+8, −6)
aten/src/ATen/native/LinearAlgebra.cpp

Lines changed: 8 additions & 6 deletions
@@ -1483,12 +1483,14 @@ static void addmm_impl_cpu_(
   // it is faster to call oneDNN matrix multiplication primitive with RHS*LHS
   // that will call then into Arm® Compute Library (ACL) GEMM kernel and also
   // additionally have support for running kernel with BF16 instructions
-  bool apply_heur = apply_mkldnn_matmul_heur(b.sizes()[0], b.sizes()[1], a.sizes()[1]);
-  if (apply_heur && transpose_a && !transpose_b && result.scalar_type() == at::ScalarType::Float) {
-    mkldnn_matmul(b, a, c, beta.to<float>(), alpha.to<float>());
-    // We have dispatched to ACL GEMM for single precision float
-    // so do not need to dispatch to BLAS GEMM below
-    dispatched = true;
+  if (transpose_c) {
+    bool apply_heur = apply_mkldnn_matmul_heur(b.sizes()[0], b.sizes()[1], a.sizes()[1]);
+    if (apply_heur && transpose_a && !transpose_b && result.scalar_type() == at::ScalarType::Float) {
+      mkldnn_matmul(b, a, c, beta.to<float>(), alpha.to<float>());
+      // We have dispatched to ACL GEMM for single precision float
+      // so do not need to dispatch to BLAS GEMM below
+      dispatched = true;
+    }
   }
 #endif
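
For context (not part of the commit): the patched code sits in the CPU addmm/mm implementation, and the change only restricts the oneDNN/ACL fast path to the transpose_c case. Below is a minimal, hypothetical libtorch sketch of a float32 addmm call that exercises addmm_impl_cpu_; shapes and variable names are illustrative, and whether the mkldnn_matmul branch is actually taken depends on an AArch64 build with ACL enabled and on the matmul heuristic.

// Illustrative sketch only (not part of the commit): a single-precision
// addmm call via the public libtorch API. On CPU this dispatches to
// addmm_impl_cpu_, the function patched above; on an AArch64 build with
// MKLDNN/ACL enabled it may take the guarded mkldnn_matmul fast path.
#include <torch/torch.h>
#include <iostream>

int main() {
  torch::manual_seed(0);
  auto mat1 = torch::randn({64, 128}, torch::kFloat);  // LHS
  auto mat2 = torch::randn({128, 32}, torch::kFloat);  // RHS
  auto bias = torch::zeros({64, 32}, torch::kFloat);   // "self" accumulator

  // out = bias + mat1 @ mat2
  auto out = torch::addmm(bias, mat1, mat2);

  // Cross-check against an explicit matmul to confirm numerics.
  auto ref = bias + torch::matmul(mat1, mat2);
  std::cout << "max abs diff: "
            << (out - ref).abs().max().item<float>() << std::endl;
  return 0;
}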
