Skip to content

Commit 1b496a7

Browse files
[SYCL] Fixed minor bug when enabling FP16 for non intel targets (#6464)
* moved INTEL_MKL guard from gemm_impl to gemm (wrapper) * Update ggml-sycl.cpp Co-authored-by: AidanBeltonS <[email protected]> --------- Co-authored-by: AidanBeltonS <[email protected]>
1 parent a307375 commit 1b496a7

File tree

1 file changed

+2
-19
lines changed

1 file changed

+2
-19
lines changed

ggml-sycl.cpp

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1664,24 +1664,6 @@ namespace dpct
16641664
const void *alpha, const void *a, int lda, const void *b,
16651665
int ldb, const void *beta, void *c, int ldc)
16661666
{
1667-
#ifndef __INTEL_MKL__
1668-
GGML_UNUSED(q);
1669-
GGML_UNUSED(a_trans);
1670-
GGML_UNUSED(b_trans);
1671-
GGML_UNUSED(m);
1672-
GGML_UNUSED(n);
1673-
GGML_UNUSED(k);
1674-
GGML_UNUSED(alpha);
1675-
GGML_UNUSED(a);
1676-
GGML_UNUSED(lda);
1677-
GGML_UNUSED(b);
1678-
GGML_UNUSED(ldb);
1679-
GGML_UNUSED(beta);
1680-
GGML_UNUSED(c);
1681-
GGML_UNUSED(ldc);
1682-
throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces "
1683-
"Project does not support this API.");
1684-
#else
16851667
Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
16861668
Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
16871669
auto data_a = get_memory<const Ta>(a);
@@ -1690,7 +1672,6 @@ namespace dpct
16901672
oneapi::mkl::blas::column_major::gemm(
16911673
q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
16921674
data_b, ldb, beta_value, data_c, ldc);
1693-
#endif
16941675
}
16951676

16961677
template <typename VecT, class BinaryOperation, class = void>
@@ -2330,6 +2311,7 @@ namespace dpct
23302311
lda, b, ldb, beta, c, ldc);
23312312
break;
23322313
}
2314+
#ifdef __INTEL_MKL__
23332315
case detail::get_type_combination_id(
23342316
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
23352317
library_data_t::real_float, library_data_t::real_float):
@@ -2391,6 +2373,7 @@ namespace dpct
23912373
q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc);
23922374
break;
23932375
}
2376+
#endif // __INTEL_MKL__
23942377
default:
23952378
throw std::runtime_error("the combination of data type is unsupported");
23962379
}

0 commit comments

Comments
 (0)