10
10
// ===----------------------------------------------------------------------===//
11
11
12
12
#include " Mapping.h"
13
+ #include " Interface.h"
13
14
#include " State.h"
14
15
#include " Types.h"
15
16
#include " Utils.h"
@@ -43,6 +44,12 @@ uint32_t getWorkgroupDim(uint32_t group_id, uint32_t grid_size,
43
44
return (r < group_size) ? r : group_size;
44
45
}
45
46
47
+ uint32_t getNumHardwareThreadsInBlock () {
48
+ return getWorkgroupDim (__builtin_amdgcn_workgroup_id_x (),
49
+ __builtin_amdgcn_grid_size_x (),
50
+ __builtin_amdgcn_workgroup_size_x ());
51
+ }
52
+
46
53
LaneMaskTy activemask () { return __builtin_amdgcn_read_exec (); }
47
54
48
55
LaneMaskTy lanemaskLT () {
@@ -67,13 +74,6 @@ uint32_t getThreadIdInWarp() {
67
74
68
75
uint32_t getThreadIdInBlock () { return __builtin_amdgcn_workitem_id_x (); }
69
76
70
- uint32_t getBlockSize () {
71
- // TODO: verify this logic for generic mode.
72
- return getWorkgroupDim (__builtin_amdgcn_workgroup_id_x (),
73
- __builtin_amdgcn_grid_size_x (),
74
- __builtin_amdgcn_workgroup_size_x ());
75
- }
76
-
77
77
uint32_t getKernelSize () { return __builtin_amdgcn_grid_size_x (); }
78
78
79
79
uint32_t getBlockId () { return __builtin_amdgcn_workgroup_id_x (); }
@@ -83,12 +83,8 @@ uint32_t getNumberOfBlocks() {
83
83
__builtin_amdgcn_workgroup_size_x ());
84
84
}
85
85
86
- uint32_t getNumberOfProcessorElements () {
87
- return getBlockSize ();
88
- }
89
-
90
86
uint32_t getWarpId () {
91
- return mapping ::getThreadIdInBlock () / mapping::getWarpSize ();
87
+ return impl ::getThreadIdInBlock () / mapping::getWarpSize ();
92
88
}
93
89
94
90
uint32_t getNumberOfWarpsInBlock () {
@@ -104,6 +100,10 @@ uint32_t getNumberOfWarpsInBlock() {
104
100
#pragma omp begin declare variant match( \
105
101
device = {arch (nvptx, nvptx64)}, implementation = {extension (match_any)})
106
102
103
+ uint32_t getNumHardwareThreadsInBlock () {
104
+ return __nvvm_read_ptx_sreg_ntid_x ();
105
+ }
106
+
107
107
constexpr const llvm::omp::GV &getGridValue () {
108
108
return llvm::omp::NVPTXGridValues;
109
109
}
@@ -126,29 +126,23 @@ LaneMaskTy lanemaskGT() {
126
126
return Res;
127
127
}
128
128
129
- uint32_t getThreadIdInWarp () {
130
- return mapping::getThreadIdInBlock () & (mapping::getWarpSize () - 1 );
131
- }
132
-
133
129
uint32_t getThreadIdInBlock () { return __nvvm_read_ptx_sreg_tid_x (); }
134
130
135
- uint32_t getBlockSize () {
136
- return __nvvm_read_ptx_sreg_ntid_x () -
137
- (!mapping::isSPMDMode () * mapping::getWarpSize ());
131
+ uint32_t getThreadIdInWarp () {
132
+ return impl::getThreadIdInBlock () & (mapping::getWarpSize () - 1 );
138
133
}
139
134
140
- uint32_t getKernelSize () { return __nvvm_read_ptx_sreg_nctaid_x (); }
135
+ uint32_t getKernelSize () {
136
+ return __nvvm_read_ptx_sreg_nctaid_x () *
137
+ mapping::getNumberOfProcessorElements ();
138
+ }
141
139
142
140
uint32_t getBlockId () { return __nvvm_read_ptx_sreg_ctaid_x (); }
143
141
144
142
uint32_t getNumberOfBlocks () { return __nvvm_read_ptx_sreg_nctaid_x (); }
145
143
146
- uint32_t getNumberOfProcessorElements () {
147
- return __nvvm_read_ptx_sreg_ntid_x ();
148
- }
149
-
150
144
uint32_t getWarpId () {
151
- return mapping ::getThreadIdInBlock () / mapping::getWarpSize ();
145
+ return impl ::getThreadIdInBlock () / mapping::getWarpSize ();
152
146
}
153
147
154
148
uint32_t getNumberOfWarpsInBlock () {
@@ -164,6 +158,10 @@ uint32_t getWarpSize() { return getGridValue().GV_Warp_Size; }
164
158
} // namespace impl
165
159
} // namespace _OMP
166
160
161
+ // / We have to be deliberate about the distinction of `mapping::` and `impl::`
162
+ // / below to avoid repeating assumptions or including irrelevant ones.
163
+ // /{
164
+
167
165
static bool isInLastWarp () {
168
166
uint32_t MainTId = (mapping::getNumberOfProcessorElements () - 1 ) &
169
167
~(mapping::getWarpSize () - 1 );
@@ -200,30 +198,60 @@ LaneMaskTy mapping::lanemaskLT() { return impl::lanemaskLT(); }
200
198
201
199
LaneMaskTy mapping::lanemaskGT () { return impl::lanemaskGT (); }
202
200
203
- uint32_t mapping::getThreadIdInWarp () { return impl::getThreadIdInWarp (); }
201
+ uint32_t mapping::getThreadIdInWarp () {
202
+ uint32_t ThreadIdInWarp = impl::getThreadIdInWarp ();
203
+ ASSERT (ThreadIdInWarp < impl::getWarpSize ());
204
+ return ThreadIdInWarp;
205
+ }
206
+
207
+ uint32_t mapping::getThreadIdInBlock () {
208
+ uint32_t ThreadIdInBlock = impl::getThreadIdInBlock ();
209
+ ASSERT (ThreadIdInBlock < impl::getNumHardwareThreadsInBlock ());
210
+ return ThreadIdInBlock;
211
+ }
204
212
205
- uint32_t mapping::getThreadIdInBlock () { return impl::getThreadIdInBlock (); }
213
+ uint32_t mapping::getWarpSize () { return impl::getWarpSize (); }
206
214
207
- uint32_t mapping::getBlockSize () { return impl::getBlockSize (); }
215
+ uint32_t mapping::getBlockSize () {
216
+ uint32_t BlockSize = mapping::getNumberOfProcessorElements () -
217
+ (!mapping::isSPMDMode () * impl::getWarpSize ());
218
+ return BlockSize;
219
+ }
208
220
209
221
uint32_t mapping::getKernelSize () { return impl::getKernelSize (); }
210
222
211
- uint32_t mapping::getBlockId () { return impl::getBlockId (); }
212
-
213
- uint32_t mapping::getNumberOfBlocks () { return impl::getNumberOfBlocks (); }
223
+ uint32_t mapping::getWarpId () {
224
+ uint32_t WarpID = impl::getWarpId ();
225
+ ASSERT (WarpID < impl::getNumberOfWarpsInBlock ());
226
+ return WarpID;
227
+ }
214
228
215
- uint32_t mapping::getNumberOfProcessorElements () {
216
- return impl::getNumberOfProcessorElements ();
229
+ uint32_t mapping::getBlockId () {
230
+ uint32_t BlockId = impl::getBlockId ();
231
+ ASSERT (BlockId < impl::getNumberOfBlocks ());
232
+ return BlockId;
217
233
}
218
234
219
- uint32_t mapping::getWarpId () { return impl::getWarpId (); }
235
+ uint32_t mapping::getNumberOfWarpsInBlock () {
236
+ uint32_t NumberOfWarpsInBlocks = impl::getNumberOfWarpsInBlock ();
237
+ ASSERT (impl::getWarpId () < NumberOfWarpsInBlocks);
238
+ return NumberOfWarpsInBlocks;
239
+ }
220
240
221
- uint32_t mapping::getWarpSize () { return impl::getWarpSize (); }
241
+ uint32_t mapping::getNumberOfBlocks () {
242
+ uint32_t NumberOfBlocks = impl::getNumberOfBlocks ();
243
+ ASSERT (impl::getBlockId () < NumberOfBlocks);
244
+ return NumberOfBlocks;
245
+ }
222
246
223
- uint32_t mapping::getNumberOfWarpsInBlock () {
224
- return impl::getNumberOfWarpsInBlock ();
247
+ uint32_t mapping::getNumberOfProcessorElements () {
248
+ uint32_t NumberOfProcessorElements = impl::getNumHardwareThreadsInBlock ();
249
+ ASSERT (impl::getThreadIdInBlock () < NumberOfProcessorElements);
250
+ return NumberOfProcessorElements;
225
251
}
226
252
253
+ // /}
254
+
227
255
// / Execution mode
228
256
// /
229
257
// /{
@@ -247,7 +275,7 @@ __attribute__((noinline)) uint32_t __kmpc_get_hardware_thread_id_in_block() {
247
275
248
276
__attribute__ ((noinline)) uint32_t __kmpc_get_hardware_num_threads_in_block() {
249
277
FunctionTracingRAII ();
250
- return mapping::getNumberOfProcessorElements ();
278
+ return impl::getNumHardwareThreadsInBlock ();
251
279
}
252
280
}
253
281
#pragma omp end declare target
0 commit comments