Skip to content

Commit 25db74d

Browse files
committed
Enforce dpas template arg checks, Fix dpasw(), Fix in-tree LIT test.
Signed-off-by: Vyacheslav N Klochkov <[email protected]>
1 parent: 7e9875a. Commit: 25db74d.

File tree

2 files changed, with 72 additions and 35 deletions.

sycl/include/sycl/ext/intel/esimd/xmx/dpas.hpp

Lines changed: 51 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -126,23 +126,43 @@ constexpr int verify_parameters_and_deduce_exec_size() {
126126

127127
if constexpr (APrecision == dpas_argument_type::FP16 ||
128128
BPrecision == dpas_argument_type::FP16) {
129-
static_assert(APrecision == BPrecision &&
130-
__ESIMD_DNS::is_type<T, float, sycl::half>() &&
131-
__ESIMD_DNS::is_type<CT, float, sycl::half>(),
132-
"Unsupported DPAS types! The supported types are:\n"
133-
" Result | C | B | A \n"
134-
" f, hf | f, hf | hf | hf \n");
129+
if constexpr (ExecutionSize == 8) {
130+
static_assert(APrecision == BPrecision &&
131+
__ESIMD_DNS::is_type<T, float>() &&
132+
__ESIMD_DNS::is_type<CT, float>(),
133+
"Unsupported DPAS types! The supported types are:\n"
134+
" Result | C | B | A \n"
135+
" f | f | hf | hf \n");
136+
} else {
137+
static_assert(APrecision == BPrecision &&
138+
__ESIMD_DNS::is_type<T, float, sycl::half>() &&
139+
__ESIMD_DNS::is_type<CT, float, sycl::half>(),
140+
"Unsupported DPAS types! The supported types are:\n"
141+
" Result | C | B | A \n"
142+
" f, hf | f, hf | hf | hf \n");
143+
}
135144
} else if constexpr (APrecision == dpas_argument_type::BF16 ||
136145
BPrecision == dpas_argument_type::BF16) {
137146
using bfloat16 = sycl::ext::oneapi::experimental::bfloat16;
138-
static_assert(APrecision == BPrecision &&
139-
__ESIMD_DNS::is_type<T, float, bfloat16>() &&
140-
__ESIMD_DNS::is_type<CT, float, bfloat16>(),
141-
"Unsupported DPAS types! The supported types are:\n"
142-
" Result | C | B | A \n"
143-
" f, bf | f, bf | bf | bf \n");
147+
if constexpr (ExecutionSize == 8) {
148+
static_assert(APrecision == BPrecision &&
149+
__ESIMD_DNS::is_type<T, float, bfloat16>() &&
150+
__ESIMD_DNS::is_type<CT, float, bfloat16>(),
151+
"Unsupported DPAS types! The supported types are:\n"
152+
" Result | C | B | A \n"
153+
" f | f | bf | bf \n");
154+
} else {
155+
static_assert(APrecision == BPrecision &&
156+
__ESIMD_DNS::is_type<T, float, bfloat16>() &&
157+
__ESIMD_DNS::is_type<CT, float, bfloat16>(),
158+
"Unsupported DPAS types! The supported types are:\n"
159+
" Result | C | B | A \n"
160+
" f, bf | f, bf | bf | bf \n");
161+
}
144162
} else if constexpr (APrecision == dpas_argument_type::TF32 ||
145163
BPrecision == dpas_argument_type::TF32) {
164+
static_assert(ExecutionSize == 16,
165+
"tf32 type can be used only with ExecutionSize=16");
146166
static_assert(APrecision == BPrecision && std::is_same_v<T, float> &&
147167
std::is_same_v<CT, float>,
148168
"Unsupported DPAS types! The supported types are:\n"
@@ -223,7 +243,7 @@ auto dpas(__ESIMD_NS::simd<BT, BN> B, __ESIMD_NS::simd<AT, AN> A) {
223243
detail::verify_parameters_and_deduce_exec_size<SystolicDepth, RepeatCount,
224244
T, T, BT, AT, BPrecision,
225245
APrecision, BN, AN>();
226-
// Result(_Mx_N) = A(_Mx_K) * B(_Kx_N) + C(_Mx_N)
246+
// Result(_Mx_N) = A(_Mx_K) * B(_Kx_N)
227247
// where:
228248
// _M = RepeatCount;
229249
// _K = SystolicDepth * OpsPerChannel;
@@ -237,8 +257,10 @@ auto dpas(__ESIMD_NS::simd<BT, BN> B, __ESIMD_NS::simd<AT, AN> A) {
237257

238258
constexpr int Info = (RepeatCount << 24) + (SystolicDepth << 16) +
239259
((int)APrecision << 8) + (int)BPrecision;
240-
return __esimd_dpas_nosrc0<Info, T, int, int, ResultN, BNCasted, ANCasted>(
241-
BCasted.data(), ACasted.data());
260+
__ESIMD_NS::simd<T, ResultN> Result =
261+
__esimd_dpas_nosrc0<Info, T, int, int, ResultN, BNCasted, ANCasted>(
262+
BCasted.data(), ACasted.data());
263+
return Result;
242264
}
243265

244266
/// DPAS (Dot Product Accumulate Systolic)
@@ -283,24 +305,32 @@ template <
283305
int SystolicDepth, int RepeatCount, typename T, typename BT, typename AT,
284306
dpas_argument_type BPrecision = detail::dpas_precision_from_type<BT>(),
285307
dpas_argument_type APrecision = detail::dpas_precision_from_type<AT>(),
286-
int N, int BN, int AN>
287-
__ESIMD_NS::simd<T, N> dpasw(__ESIMD_NS::simd<BT, BN> B,
288-
__ESIMD_NS::simd<AT, AN> A) {
308+
int BN, int AN>
309+
auto dpasw(__ESIMD_NS::simd<BT, BN> B, __ESIMD_NS::simd<AT, AN> A) {
289310

290311
constexpr bool IsDPASW = true;
291-
(void)detail::verify_parameters_and_deduce_exec_size<
312+
constexpr int ExecutionSize = detail::verify_parameters_and_deduce_exec_size<
292313
SystolicDepth, RepeatCount, T, T, BT, AT, BPrecision, APrecision, BN, AN,
293314
IsDPASW>();
294315

316+
// Result(_Mx_N) = A(_Mx_K) * B(_Kx_N)
317+
// where:
318+
// _M = RepeatCount;
319+
// _K = SystolicDepth * OpsPerChannel;
320+
// _N = ExecutionSize (unknown, but deducible), must be 8 or 16.
321+
constexpr int ResultN = RepeatCount * ExecutionSize;
322+
295323
constexpr int ANCasted = AN / (sizeof(int) / sizeof(AT));
296324
constexpr int BNCasted = BN / (sizeof(int) / sizeof(BT));
297325
__ESIMD_NS::simd<int, ANCasted> ACasted = A.template bit_cast_view<int>();
298326
__ESIMD_NS::simd<int, BNCasted> BCasted = B.template bit_cast_view<int>();
299327

300328
constexpr int Info = (RepeatCount << 24) + (SystolicDepth << 16) +
301329
((int)APrecision << 8) + (int)BPrecision;
302-
return __esimd_dpasw_nosrc0<Info, T, int, int, N, BNCasted, ANCasted>(
303-
BCasted.data(), ACasted.data());
330+
__ESIMD_NS::simd<T, ResultN> Result =
331+
__esimd_dpasw_nosrc0<Info, T, int, int, ResultN, BNCasted, ANCasted>(
332+
BCasted.data(), ACasted.data());
333+
return Result;
304334
}
305335

306336
/// @} sycl_esimd_xmx_systolic_array_api

sycl/test/esimd/dpas.cpp

Lines changed: 21 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
1-
// RUN: %clangxx -DESIMD_XE_HPC -O0 -fsycl -c -Xclang -emit-llvm %s -o %t
2-
// RUN: %clangxx -DESIMD_XE_HPC -O0 -fsycl -c -fsycl-device-only -Xclang -emit-llvm %s -o %t
1+
// RUN: %clangxx -O0 -fsycl -c -Xclang -emit-llvm %s -o %t
2+
// RUN: %clangxx -O0 -fsycl -c -fsycl-device-only -Xclang -emit-llvm %s -o %t
33
// RUN: sycl-post-link -split-esimd -lower-esimd -O0 -S %t -o %t.table
44
// RUN: FileCheck %s -input-file=%t_esimd_0.ll
55

@@ -27,13 +27,13 @@ void bar() {
2727
}
2828

2929
SYCL_ESIMD_FUNCTION SYCL_EXTERNAL void foo() {
30-
simd<short, 16> A_ACC = 7;
30+
simd<sycl::ext::oneapi::experimental::bfloat16, 16> A_ACC = 7;
3131
simd<int, 128> A_ISRC1 = 0;
3232
simd<int, 8> A_ISRC2 = 0;
3333
simd<float, 16> A_DST =
3434
dpas<argument_type::BF16, argument_type::BF16, float, 8, 1>(
3535
A_ACC, A_ISRC1, A_ISRC2);
36-
// CHECK: call <16 x float> @llvm.genx.dpas2.v16f32.v16i16.v128i32.v8i32(<16 x i16> {{[^,]+}}, <128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 9, i32 9, i32 8, i32 1, i32 1, i32 1)
36+
// CHECK: call <16 x float> @llvm.genx.dpas2.v16f32.v16i16.v128i32.v8i32(<16 x i16> {{[^,]+}}, <128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 9, i32 9, i32 8, i32 1, i32 1, i32 0)
3737

3838
simd<float, 16> B_ACC = 7;
3939
simd<int, 128> B_ISRC1 = 0;
@@ -49,16 +49,23 @@ SYCL_ESIMD_FUNCTION SYCL_EXTERNAL void foo() {
4949
C_ISRC1, C_ISRC2);
5050
// CHECK: call <16 x float> @llvm.genx.dpas.nosrc0.v16f32.v128i32.v8i32(<128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 {{[^,]+}})
5151

52-
simd<float, 16> D_ACC = 7;
53-
simd<int, 128> D_ISRC1 = 0;
54-
simd<int, 8> D_ISRC2 = 0;
55-
simd<float, 16> D_DST = dpasw<argument_type::BF16, argument_type::BF16, 8, 1>(
52+
simd<float, 8> D_ACC =
53+
7; // MxN: 1x8 floats (M=RepeatCount=1, N=ExecutionSize=8)
54+
simd<int, 64> D_ISRC1 =
55+
0; // KxN: 16x8 bf16: (K=SysDepth*OpsPerChan=8*2, N=ExecutionSize=8)
56+
simd<int, 4> D_ISRC2 =
57+
0; // MxK/2: 1x8 bf16: (M=RepeatCount=1, K=SysDepth*OpsPerChan=8*2)
58+
// Result is MxN: 1x8 floats
59+
simd<float, 8> D_DST = dpasw<argument_type::BF16, argument_type::BF16, 8, 1>(
5660
D_ACC, D_ISRC1, D_ISRC2);
57-
// CHECK: call <16 x float> @llvm.genx.dpasw.v16f32.v128i32.v8i32(<16 x float> {{[^,]+}}, <128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 {{[^,]+}})
61+
// CHECK: call <8 x float> @llvm.genx.dpasw.v8f32.v64i32.v4i32(<8 x float> {{[^,]+}}, <64 x i32> {{[^,]+}}, <4 x i32> {{[^,]+}}, i32 {{[^,]+}})
5862

59-
simd<int, 128> E_ISRC1 = 0;
60-
simd<int, 8> E_ISRC2 = 0;
61-
simd<float, 16> E_DST = dpasw2<argument_type::BF16, argument_type::BF16, 8, 1,
62-
float, int, int, 16>(E_ISRC1, E_ISRC2);
63-
// CHECK: call <16 x float> @llvm.genx.dpasw.nosrc0.v16f32.v128i32.v8i32(<128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 {{[^,]+}})
63+
simd<int, 64> E_ISRC1 =
64+
0; // KxN: 16x8 bf16: K=SysDepth*OPC=8*2, N=ExecutionSize=8
65+
simd<int, 4> E_ISRC2 =
66+
0; // MxK/2: 1x16/2 bf16: M=RepeatCount, K=SysDepth*OPC=8*2
67+
// Result is MxN: 1x8 floats
68+
simd<float, 8> E_DST = dpasw2<argument_type::BF16, argument_type::BF16, 8, 1,
69+
float, int, int, 8>(E_ISRC1, E_ISRC2);
70+
// CHECK: call <8 x float> @llvm.genx.dpasw.nosrc0.v8f32.v64i32.v4i32(<64 x i32> {{[^,]+}}, <4 x i32> {{[^,]+}}, i32 {{[^,]+}})
6471
}

0 commit comments

Comments
 (0)