Skip to content

Commit 93bebdc

Browse files
committed
[OpenMP][NFCI] Cleanup new device RT mapping interface
Minimize the `impl` interface and clean up some uses of mapping functions.

Reviewed By: jhuber6

Differential Revision: https://reviews.llvm.org/D112154
1 parent dec15d9 commit 93bebdc

File tree

1 file changed

+66
-38
lines changed

1 file changed

+66
-38
lines changed

openmp/libomptarget/DeviceRTL/src/Mapping.cpp

Lines changed: 66 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010
//===----------------------------------------------------------------------===//
1111

1212
#include "Mapping.h"
13+
#include "Interface.h"
1314
#include "State.h"
1415
#include "Types.h"
1516
#include "Utils.h"
@@ -43,6 +44,12 @@ uint32_t getWorkgroupDim(uint32_t group_id, uint32_t grid_size,
4344
return (r < group_size) ? r : group_size;
4445
}
4546

47+
uint32_t getNumHardwareThreadsInBlock() {
48+
return getWorkgroupDim(__builtin_amdgcn_workgroup_id_x(),
49+
__builtin_amdgcn_grid_size_x(),
50+
__builtin_amdgcn_workgroup_size_x());
51+
}
52+
4653
LaneMaskTy activemask() { return __builtin_amdgcn_read_exec(); }
4754

4855
LaneMaskTy lanemaskLT() {
@@ -67,13 +74,6 @@ uint32_t getThreadIdInWarp() {
6774

6875
uint32_t getThreadIdInBlock() { return __builtin_amdgcn_workitem_id_x(); }
6976

70-
uint32_t getBlockSize() {
71-
// TODO: verify this logic for generic mode.
72-
return getWorkgroupDim(__builtin_amdgcn_workgroup_id_x(),
73-
__builtin_amdgcn_grid_size_x(),
74-
__builtin_amdgcn_workgroup_size_x());
75-
}
76-
7777
uint32_t getKernelSize() { return __builtin_amdgcn_grid_size_x(); }
7878

7979
uint32_t getBlockId() { return __builtin_amdgcn_workgroup_id_x(); }
@@ -83,12 +83,8 @@ uint32_t getNumberOfBlocks() {
8383
__builtin_amdgcn_workgroup_size_x());
8484
}
8585

86-
uint32_t getNumberOfProcessorElements() {
87-
return getBlockSize();
88-
}
89-
9086
uint32_t getWarpId() {
91-
return mapping::getThreadIdInBlock() / mapping::getWarpSize();
87+
return impl::getThreadIdInBlock() / mapping::getWarpSize();
9288
}
9389

9490
uint32_t getNumberOfWarpsInBlock() {
@@ -104,6 +100,10 @@ uint32_t getNumberOfWarpsInBlock() {
104100
#pragma omp begin declare variant match( \
105101
device = {arch(nvptx, nvptx64)}, implementation = {extension(match_any)})
106102

103+
uint32_t getNumHardwareThreadsInBlock() {
104+
return __nvvm_read_ptx_sreg_ntid_x();
105+
}
106+
107107
constexpr const llvm::omp::GV &getGridValue() {
108108
return llvm::omp::NVPTXGridValues;
109109
}
@@ -126,29 +126,23 @@ LaneMaskTy lanemaskGT() {
126126
return Res;
127127
}
128128

129-
uint32_t getThreadIdInWarp() {
130-
return mapping::getThreadIdInBlock() & (mapping::getWarpSize() - 1);
131-
}
132-
133129
uint32_t getThreadIdInBlock() { return __nvvm_read_ptx_sreg_tid_x(); }
134130

135-
uint32_t getBlockSize() {
136-
return __nvvm_read_ptx_sreg_ntid_x() -
137-
(!mapping::isSPMDMode() * mapping::getWarpSize());
131+
uint32_t getThreadIdInWarp() {
132+
return impl::getThreadIdInBlock() & (mapping::getWarpSize() - 1);
138133
}
139134

140-
uint32_t getKernelSize() { return __nvvm_read_ptx_sreg_nctaid_x(); }
135+
uint32_t getKernelSize() {
136+
return __nvvm_read_ptx_sreg_nctaid_x() *
137+
mapping::getNumberOfProcessorElements();
138+
}
141139

142140
uint32_t getBlockId() { return __nvvm_read_ptx_sreg_ctaid_x(); }
143141

144142
uint32_t getNumberOfBlocks() { return __nvvm_read_ptx_sreg_nctaid_x(); }
145143

146-
uint32_t getNumberOfProcessorElements() {
147-
return __nvvm_read_ptx_sreg_ntid_x();
148-
}
149-
150144
uint32_t getWarpId() {
151-
return mapping::getThreadIdInBlock() / mapping::getWarpSize();
145+
return impl::getThreadIdInBlock() / mapping::getWarpSize();
152146
}
153147

154148
uint32_t getNumberOfWarpsInBlock() {
@@ -164,6 +158,10 @@ uint32_t getWarpSize() { return getGridValue().GV_Warp_Size; }
164158
} // namespace impl
165159
} // namespace _OMP
166160

161+
/// We have to be deliberate about the distinction of `mapping::` and `impl::`
162+
/// below to avoid repeating assumptions or including irrelevant ones.
163+
///{
164+
167165
static bool isInLastWarp() {
168166
uint32_t MainTId = (mapping::getNumberOfProcessorElements() - 1) &
169167
~(mapping::getWarpSize() - 1);
@@ -200,30 +198,60 @@ LaneMaskTy mapping::lanemaskLT() { return impl::lanemaskLT(); }
200198

201199
LaneMaskTy mapping::lanemaskGT() { return impl::lanemaskGT(); }
202200

203-
uint32_t mapping::getThreadIdInWarp() { return impl::getThreadIdInWarp(); }
201+
uint32_t mapping::getThreadIdInWarp() {
202+
uint32_t ThreadIdInWarp = impl::getThreadIdInWarp();
203+
ASSERT(ThreadIdInWarp < impl::getWarpSize());
204+
return ThreadIdInWarp;
205+
}
206+
207+
uint32_t mapping::getThreadIdInBlock() {
208+
uint32_t ThreadIdInBlock = impl::getThreadIdInBlock();
209+
ASSERT(ThreadIdInBlock < impl::getNumHardwareThreadsInBlock());
210+
return ThreadIdInBlock;
211+
}
204212

205-
uint32_t mapping::getThreadIdInBlock() { return impl::getThreadIdInBlock(); }
213+
uint32_t mapping::getWarpSize() { return impl::getWarpSize(); }
206214

207-
uint32_t mapping::getBlockSize() { return impl::getBlockSize(); }
215+
uint32_t mapping::getBlockSize() {
216+
uint32_t BlockSize = mapping::getNumberOfProcessorElements() -
217+
(!mapping::isSPMDMode() * impl::getWarpSize());
218+
return BlockSize;
219+
}
208220

209221
uint32_t mapping::getKernelSize() { return impl::getKernelSize(); }
210222

211-
uint32_t mapping::getBlockId() { return impl::getBlockId(); }
212-
213-
uint32_t mapping::getNumberOfBlocks() { return impl::getNumberOfBlocks(); }
223+
uint32_t mapping::getWarpId() {
224+
uint32_t WarpID = impl::getWarpId();
225+
ASSERT(WarpID < impl::getNumberOfWarpsInBlock());
226+
return WarpID;
227+
}
214228

215-
uint32_t mapping::getNumberOfProcessorElements() {
216-
return impl::getNumberOfProcessorElements();
229+
uint32_t mapping::getBlockId() {
230+
uint32_t BlockId = impl::getBlockId();
231+
ASSERT(BlockId < impl::getNumberOfBlocks());
232+
return BlockId;
217233
}
218234

219-
uint32_t mapping::getWarpId() { return impl::getWarpId(); }
235+
uint32_t mapping::getNumberOfWarpsInBlock() {
236+
uint32_t NumberOfWarpsInBlocks = impl::getNumberOfWarpsInBlock();
237+
ASSERT(impl::getWarpId() < NumberOfWarpsInBlocks);
238+
return NumberOfWarpsInBlocks;
239+
}
220240

221-
uint32_t mapping::getWarpSize() { return impl::getWarpSize(); }
241+
uint32_t mapping::getNumberOfBlocks() {
242+
uint32_t NumberOfBlocks = impl::getNumberOfBlocks();
243+
ASSERT(impl::getBlockId() < NumberOfBlocks);
244+
return NumberOfBlocks;
245+
}
222246

223-
uint32_t mapping::getNumberOfWarpsInBlock() {
224-
return impl::getNumberOfWarpsInBlock();
247+
uint32_t mapping::getNumberOfProcessorElements() {
248+
uint32_t NumberOfProcessorElements = impl::getNumHardwareThreadsInBlock();
249+
ASSERT(impl::getThreadIdInBlock() < NumberOfProcessorElements);
250+
return NumberOfProcessorElements;
225251
}
226252

253+
///}
254+
227255
/// Execution mode
228256
///
229257
///{
@@ -247,7 +275,7 @@ __attribute__((noinline)) uint32_t __kmpc_get_hardware_thread_id_in_block() {
247275

248276
__attribute__((noinline)) uint32_t __kmpc_get_hardware_num_threads_in_block() {
249277
FunctionTracingRAII();
250-
return mapping::getNumberOfProcessorElements();
278+
return impl::getNumHardwareThreadsInBlock();
251279
}
252280
}
253281
#pragma omp end declare target

0 commit comments

Comments
 (0)