Skip to content

Commit ec97c57

Browse files
authored
[SYCL][CUDA][MATRIX] Remove using namespace experimental from headers (#5217)
Removed `using namespace experimental` - replaced with fully qualified names. This PR fixes #5213. I chose to use fully qualified names instead of moving local `detail` to `sycl::ext::oneapi::experimental::matrix::detail`, although this second option seems the most sensible to me: however I didn't move `detail` at the moment for consistency with the intel matrix extension namespace use (which also coincides with the standard practice of wider dpc++ as far as I can tell). Please be aware that `using namespace experimental;` is also used on line 201 of matrix-aot-amx.hpp, which could lead to similar problems in the future. Signed-off-by: jack.kirk <[email protected]>
1 parent a55a713 commit ec97c57

File tree

2 files changed

+156
-85
lines changed

2 files changed

+156
-85
lines changed

sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp

Lines changed: 146 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -51,34 +51,45 @@ struct joint_matrix<
5151
} // namespace experimental::matrix
5252

5353
namespace detail {
54-
using namespace experimental;
5554

56-
template <typename T, matrix::matrix_use MT, size_t NumRows, size_t NumCols,
57-
matrix::matrix_layout Layout, access::address_space Space,
58-
typename Cond = void>
55+
template <typename T, sycl::ext::oneapi::experimental::matrix::matrix_use MT,
56+
size_t NumRows, size_t NumCols,
57+
sycl::ext::oneapi::experimental::matrix::matrix_layout Layout,
58+
access::address_space Space, typename Cond = void>
5959
struct joint_matrix_load_impl {
60-
void load(matrix::joint_matrix<T, MT, NumRows, NumCols, Layout> &res,
60+
void load(sycl::ext::oneapi::experimental::matrix::joint_matrix<
61+
T, MT, NumRows, NumCols, Layout> &res,
6162
multi_ptr<T, Space> src, size_t stride);
6263
};
6364

64-
template <matrix::matrix_layout Layout> constexpr int get_layout_id();
65+
template <sycl::ext::oneapi::experimental::matrix::matrix_layout Layout>
66+
constexpr int get_layout_id();
6567

66-
template <> constexpr int get_layout_id<matrix::matrix_layout::row_major>() {
68+
template <>
69+
constexpr int get_layout_id<
70+
sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major>() {
6771
return 0;
6872
}
6973

70-
template <> constexpr int get_layout_id<matrix::matrix_layout::col_major>() {
74+
template <>
75+
constexpr int get_layout_id<
76+
sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major>() {
7177
return 1;
7278
}
7379

74-
template <matrix::matrix_layout Layout, access::address_space Space>
80+
template <sycl::ext::oneapi::experimental::matrix::matrix_layout Layout,
81+
access::address_space Space>
7582
struct joint_matrix_load_impl<
76-
double, matrix::matrix_use::a, 8, 4, Layout, Space,
77-
typename std::enable_if_t<Layout == matrix::matrix_layout::row_major ||
78-
Layout == matrix::matrix_layout::col_major>> {
79-
void
80-
load(matrix::joint_matrix<double, matrix::matrix_use::a, 8, 4, Layout> &res,
81-
multi_ptr<double, Space> src, size_t stride) {
83+
double, sycl::ext::oneapi::experimental::matrix::matrix_use::a, 8, 4,
84+
Layout, Space,
85+
typename std::enable_if_t<Layout == sycl::ext::oneapi::experimental::
86+
matrix::matrix_layout::row_major ||
87+
Layout == sycl::ext::oneapi::experimental::
88+
matrix::matrix_layout::col_major>> {
89+
void load(sycl::ext::oneapi::experimental::matrix::joint_matrix<
90+
double, sycl::ext::oneapi::experimental::matrix::matrix_use::a,
91+
8, 4, Layout> &res,
92+
multi_ptr<double, Space> src, size_t stride) {
8293

8394
#ifdef __NVPTX__
8495
#ifdef __SYCL_DEVICE_ONLY__
@@ -88,14 +99,19 @@ struct joint_matrix_load_impl<
8899
}
89100
};
90101

91-
template <matrix::matrix_layout Layout, access::address_space Space>
102+
template <sycl::ext::oneapi::experimental::matrix::matrix_layout Layout,
103+
access::address_space Space>
92104
struct joint_matrix_load_impl<
93-
double, matrix::matrix_use::b, 4, 8, Layout, Space,
94-
typename std::enable_if_t<Layout == matrix::matrix_layout::row_major ||
95-
Layout == matrix::matrix_layout::col_major>> {
96-
void
97-
load(matrix::joint_matrix<double, matrix::matrix_use::b, 4, 8, Layout> &res,
98-
multi_ptr<double, Space> src, size_t stride) {
105+
double, sycl::ext::oneapi::experimental::matrix::matrix_use::b, 4, 8,
106+
Layout, Space,
107+
typename std::enable_if_t<Layout == sycl::ext::oneapi::experimental::
108+
matrix::matrix_layout::row_major ||
109+
Layout == sycl::ext::oneapi::experimental::
110+
matrix::matrix_layout::col_major>> {
111+
void load(sycl::ext::oneapi::experimental::matrix::joint_matrix<
112+
double, sycl::ext::oneapi::experimental::matrix::matrix_use::b,
113+
4, 8, Layout> &res,
114+
multi_ptr<double, Space> src, size_t stride) {
99115
#ifdef __NVPTX__
100116
#ifdef __SYCL_DEVICE_ONLY__
101117
__dmma_m8n8k4_ld_b(res.data, src.get(), stride, get_layout_id<Layout>());
@@ -104,14 +120,21 @@ struct joint_matrix_load_impl<
104120
}
105121
};
106122

107-
template <matrix::matrix_layout Layout, access::address_space Space>
123+
template <sycl::ext::oneapi::experimental::matrix::matrix_layout Layout,
124+
access::address_space Space>
108125
struct joint_matrix_load_impl<
109-
double, matrix::matrix_use::accumulator, 8, 8, Layout, Space,
110-
typename std::enable_if_t<Layout == matrix::matrix_layout::row_major ||
111-
Layout == matrix::matrix_layout::col_major>> {
112-
void load(matrix::joint_matrix<double, matrix::matrix_use::accumulator, 8, 8,
113-
Layout> &res,
114-
multi_ptr<double, Space> src, size_t stride) {
126+
double, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, 8,
127+
8, Layout, Space,
128+
typename std::enable_if_t<Layout == sycl::ext::oneapi::experimental::
129+
matrix::matrix_layout::row_major ||
130+
Layout == sycl::ext::oneapi::experimental::
131+
matrix::matrix_layout::col_major>> {
132+
void
133+
load(sycl::ext::oneapi::experimental::matrix::joint_matrix<
134+
double,
135+
sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, 8,
136+
8, Layout> &res,
137+
multi_ptr<double, Space> src, size_t stride) {
115138

116139
#ifdef __NVPTX__
117140
#ifdef __SYCL_DEVICE_ONLY__
@@ -122,22 +145,30 @@ struct joint_matrix_load_impl<
122145
};
123146

124147
template <typename T, size_t NumRows, size_t NumCols,
125-
matrix::matrix_layout Layout, access::address_space Space,
126-
typename Cond = void>
148+
sycl::ext::oneapi::experimental::matrix::matrix_layout Layout,
149+
access::address_space Space, typename Cond = void>
127150
struct joint_matrix_store_impl {
128-
void store(matrix::joint_matrix<T, matrix::matrix_use::accumulator, NumRows,
129-
NumCols, Layout> &src,
130-
multi_ptr<T, Space> dst, size_t stride);
151+
void
152+
store(sycl::ext::oneapi::experimental::matrix::joint_matrix<
153+
T, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
154+
NumRows, NumCols, Layout> &src,
155+
multi_ptr<T, Space> dst, size_t stride);
131156
};
132157

133-
template <matrix::matrix_layout Layout, access::address_space Space>
158+
template <sycl::ext::oneapi::experimental::matrix::matrix_layout Layout,
159+
access::address_space Space>
134160
struct joint_matrix_store_impl<
135161
double, 8, 8, Layout, Space,
136-
typename std::enable_if_t<Layout == matrix::matrix_layout::row_major ||
137-
Layout == matrix::matrix_layout::col_major>> {
138-
void store(matrix::joint_matrix<double, matrix::matrix_use::accumulator, 8, 8,
139-
Layout> &src,
140-
multi_ptr<double, Space> dst, size_t stride) {
162+
typename std::enable_if_t<Layout == sycl::ext::oneapi::experimental::
163+
matrix::matrix_layout::row_major ||
164+
Layout == sycl::ext::oneapi::experimental::
165+
matrix::matrix_layout::col_major>> {
166+
void
167+
store(sycl::ext::oneapi::experimental::matrix::joint_matrix<
168+
double,
169+
sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, 8,
170+
8, Layout> &src,
171+
multi_ptr<double, Space> dst, size_t stride) {
141172

142173
#ifdef __NVPTX__
143174
#ifdef __SYCL_DEVICE_ONLY__
@@ -149,60 +180,98 @@ struct joint_matrix_store_impl<
149180
};
150181

151182
template <typename T1, typename T2, std::size_t M, std::size_t K, std::size_t N,
152-
matrix::matrix_layout LayoutA, matrix::matrix_layout LayoutB,
153-
matrix::matrix_layout LayoutC, typename Cond = void>
183+
sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutA,
184+
sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutB,
185+
sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutC,
186+
typename Cond = void>
154187
struct joint_matrix_mad_impl {
155-
matrix::joint_matrix<T2, matrix::matrix_use::accumulator, M, N, LayoutC>
156-
mad(matrix::joint_matrix<T1, matrix::matrix_use::a, M, K, LayoutA> A,
157-
matrix::joint_matrix<T1, matrix::matrix_use::b, K, N, LayoutB> B,
158-
matrix::joint_matrix<T2, matrix::matrix_use::accumulator, M, N, LayoutC>
188+
sycl::ext::oneapi::experimental::matrix::joint_matrix<
189+
T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M,
190+
N, LayoutC>
191+
mad(sycl::ext::oneapi::experimental::matrix::joint_matrix<
192+
T1, sycl::ext::oneapi::experimental::matrix::matrix_use::a, M, K,
193+
LayoutA>
194+
A,
195+
sycl::ext::oneapi::experimental::matrix::joint_matrix<
196+
T1, sycl::ext::oneapi::experimental::matrix::matrix_use::b, K, N,
197+
LayoutB>
198+
B,
199+
sycl::ext::oneapi::experimental::matrix::joint_matrix<
200+
T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
201+
M, N, LayoutC>
159202
C);
160203
};
161204

162-
template <matrix::matrix_layout LayoutA, matrix::matrix_layout LayoutB>
205+
template <sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutA,
206+
sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutB>
163207
constexpr int get_layout_pair_id();
164208

165209
template <>
166-
constexpr int get_layout_pair_id<matrix::matrix_layout::row_major,
167-
matrix::matrix_layout::row_major>() {
210+
constexpr int get_layout_pair_id<
211+
sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major,
212+
sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major>() {
168213
return 0;
169214
}
170215

171216
template <>
172-
constexpr int get_layout_pair_id<matrix::matrix_layout::row_major,
173-
matrix::matrix_layout::col_major>() {
217+
constexpr int get_layout_pair_id<
218+
sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major,
219+
sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major>() {
174220
return 1;
175221
}
176222

177223
template <>
178-
constexpr int get_layout_pair_id<matrix::matrix_layout::col_major,
179-
matrix::matrix_layout::row_major>() {
224+
constexpr int get_layout_pair_id<
225+
sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major,
226+
sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major>() {
180227
return 2;
181228
}
182229

183230
template <>
184-
constexpr int get_layout_pair_id<matrix::matrix_layout::col_major,
185-
matrix::matrix_layout::col_major>() {
231+
constexpr int get_layout_pair_id<
232+
sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major,
233+
sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major>() {
186234
return 3;
187235
}
188236

189-
template <matrix::matrix_layout LayoutA, matrix::matrix_layout LayoutB,
190-
matrix::matrix_layout LayoutC>
237+
template <sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutA,
238+
sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutB,
239+
sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutC>
191240
struct joint_matrix_mad_impl<
192241
double, double, 8, 4, 8, LayoutA, LayoutB, LayoutC,
193-
typename std::enable_if_t<(LayoutA == matrix::matrix_layout::row_major ||
194-
LayoutA == matrix::matrix_layout::col_major) &&
195-
(LayoutB == matrix::matrix_layout::row_major ||
196-
LayoutB == matrix::matrix_layout::col_major) &&
197-
(LayoutC == matrix::matrix_layout::row_major ||
198-
LayoutC == matrix::matrix_layout::col_major)>> {
199-
matrix::joint_matrix<double, matrix::matrix_use::accumulator, 8, 8, LayoutC>
200-
mad(matrix::joint_matrix<double, matrix::matrix_use::a, 8, 4, LayoutA> A,
201-
matrix::joint_matrix<double, matrix::matrix_use::b, 4, 8, LayoutB> B,
202-
matrix::joint_matrix<double, matrix::matrix_use::accumulator, 8, 8,
203-
LayoutC>
242+
typename std::enable_if_t<
243+
(LayoutA == sycl::ext::oneapi::experimental::matrix::matrix_layout::
244+
row_major ||
245+
LayoutA == sycl::ext::oneapi::experimental::matrix::matrix_layout::
246+
col_major) &&
247+
(LayoutB == sycl::ext::oneapi::experimental::matrix::matrix_layout::
248+
row_major ||
249+
LayoutB == sycl::ext::oneapi::experimental::matrix::matrix_layout::
250+
col_major) &&
251+
(LayoutC == sycl::ext::oneapi::experimental::matrix::matrix_layout::
252+
row_major ||
253+
LayoutC == sycl::ext::oneapi::experimental::matrix::matrix_layout::
254+
col_major)>> {
255+
sycl::ext::oneapi::experimental::matrix::joint_matrix<
256+
double, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
257+
8, 8, LayoutC>
258+
mad(sycl::ext::oneapi::experimental::matrix::joint_matrix<
259+
double, sycl::ext::oneapi::experimental::matrix::matrix_use::a, 8, 4,
260+
LayoutA>
261+
A,
262+
sycl::ext::oneapi::experimental::matrix::joint_matrix<
263+
double, sycl::ext::oneapi::experimental::matrix::matrix_use::b, 4, 8,
264+
LayoutB>
265+
B,
266+
sycl::ext::oneapi::experimental::matrix::joint_matrix<
267+
double,
268+
sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, 8,
269+
8, LayoutC>
204270
C) {
205-
matrix::joint_matrix<double, matrix::matrix_use::accumulator, 8, 8, LayoutC>
271+
sycl::ext::oneapi::experimental::matrix::joint_matrix<
272+
double,
273+
sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, 8, 8,
274+
LayoutC>
206275
D;
207276

208277
#ifdef __NVPTX__
@@ -225,8 +294,9 @@ template <typename Group, typename T, matrix_use MT, size_t NumRows,
225294
void joint_matrix_load(
226295
Group sg, joint_matrix<T, MT, NumRows, NumCols, Layout, Group> &res,
227296
multi_ptr<T, Space> src, size_t stride) {
228-
detail::joint_matrix_load_impl<T, MT, NumRows, NumCols, Layout, Space>{}.load(
229-
res, src, stride);
297+
sycl::ext::oneapi::detail::joint_matrix_load_impl<T, MT, NumRows, NumCols,
298+
Layout, Space>{}
299+
.load(res, src, stride);
230300
}
231301

232302
template <typename Group, typename T, size_t NumRows, size_t NumCols,
@@ -235,8 +305,9 @@ void joint_matrix_store(Group sg,
235305
joint_matrix<T, matrix_use::accumulator, NumRows,
236306
NumCols, Layout, Group> &src,
237307
multi_ptr<T, Space> dst, size_t stride) {
238-
detail::joint_matrix_store_impl<T, NumRows, NumCols, Layout, Space>{}.store(
239-
src, dst, stride);
308+
sycl::ext::oneapi::detail::joint_matrix_store_impl<T, NumRows, NumCols,
309+
Layout, Space>{}
310+
.store(src, dst, stride);
240311
}
241312

242313
template <typename Group, typename T1, typename T2, std::size_t M,
@@ -247,8 +318,8 @@ joint_matrix_mad(
247318
Group sg, joint_matrix<T1, matrix_use::a, M, K, LayoutA, Group> A,
248319
joint_matrix<T1, matrix_use::b, K, N, LayoutB, Group> B,
249320
joint_matrix<T2, matrix_use::accumulator, M, N, LayoutC, Group> C) {
250-
return detail::joint_matrix_mad_impl<T1, T2, M, K, N, LayoutA, LayoutB,
251-
LayoutC>{}
321+
return sycl::ext::oneapi::detail::joint_matrix_mad_impl<
322+
T1, T2, M, K, N, LayoutA, LayoutB, LayoutC>{}
252323
.mad(A, B, C);
253324
}
254325

sycl/test/check_device_code/matrix/matrix-nvptx-double-test.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,15 @@ int main() {
5050
joint_matrix<double, matrix_use::b, K, N, matrix_layout::row_major>
5151
sub_b;
5252

53-
//CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.load.c.row.stride.f64.p1f64(double addrspace(1)* %add.ptr.i, i32 8) #{{.*}}
53+
//CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.load.c.row.stride.f64.p1f64(double addrspace(1)* %_arg_, i32 8) #{{.*}}
5454
joint_matrix_load(sg, sub_c, accC.get_pointer(), N);
55-
//CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.a.row.stride.f64.p1f64(double addrspace(1)* %add.ptr.i54, i32 4) #{{.*}}
55+
//CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.a.row.stride.f64.p1f64(double addrspace(1)* %_arg_4, i32 4) #{{.*}}
5656
joint_matrix_load(sg, sub_a, accA.get_pointer(), K);
57-
//CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.b.row.stride.f64.p1f64(double addrspace(1)* %add.ptr.i65, i32 8) #{{.*}}
57+
//CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.b.row.stride.f64.p1f64(double addrspace(1)* %_arg_9, i32 8) #{{.*}}
5858
joint_matrix_load(sg, sub_b, accB.get_pointer(), N);
59-
//CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.row.row.f64(double %11, double %12, double %9, double %10) #{{.*}}
59+
//CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.row.row.f64(double %3, double %4, double %1, double %2) #{{.*}}
6060
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
61-
//CHECK: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.row.stride.f64.p1f64(double addrspace(1)* %add.ptr.i76, double %14, double %15, i32 8) #{{.*}}
61+
//CHECK: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.row.stride.f64.p1f64(double addrspace(1)* %_arg_14, double %6, double %7, i32 8) #{{.*}}
6262
joint_matrix_store(sg, sub_c, accD.get_pointer(), N);
6363
});
6464
});
@@ -84,15 +84,15 @@ int main() {
8484
joint_matrix<double, matrix_use::b, K, N, matrix_layout::col_major>
8585
sub_b;
8686

87-
//CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.load.c.col.stride.f64.p1f64(double addrspace(1)* %add.ptr.i, i32 8) #{{.*}}
87+
//CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.load.c.col.stride.f64.p1f64(double addrspace(1)* %_arg_, i32 8) #{{.*}}
8888
joint_matrix_load(sg, sub_c, accC.get_pointer(), M);
89-
//CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.a.col.stride.f64.p1f64(double addrspace(1)* %add.ptr.i54, i32 8) #{{.*}}
89+
//CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.a.col.stride.f64.p1f64(double addrspace(1)* %_arg_4, i32 8) #{{.*}}
9090
joint_matrix_load(sg, sub_a, accA.get_pointer(), M);
91-
//CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.b.col.stride.f64.p1f64(double addrspace(1)* %add.ptr.i65, i32 4) #{{.*}}
91+
//CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.b.col.stride.f64.p1f64(double addrspace(1)* %_arg_9, i32 4) #{{.*}}
9292
joint_matrix_load(sg, sub_b, accB.get_pointer(), K);
93-
//CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.col.col.f64(double %11, double %12, double %9, double %10) #{{.*}}
93+
//CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.col.col.f64(double %3, double %4, double %1, double %2) #{{.*}}
9494
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
95-
//CHECK: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64.p1f64(double addrspace(1)* %add.ptr.i76, double %14, double %15, i32 8) #{{.*}}
95+
//CHECK: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64.p1f64(double addrspace(1)* %_arg_14, double %6, double %7, i32 8) #{{.*}}
9696
joint_matrix_store(sg, sub_c, accD.get_pointer(), M);
9797
});
9898
});

0 commit comments

Comments
 (0)