Skip to content

Commit 63590f8

Browse files
authored
[OpenMP] cleanup of the emissary code by creating a utility to extrac… (llvm#899)
2 parents 531b0b3 + 3c6129f commit 63590f8

File tree

8 files changed

+455
-346
lines changed

8 files changed

+455
-346
lines changed

offload/DeviceRTL/include/EmissaryIds.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,12 @@ typedef enum {
2929
_ockl_asan_report_idx,
3030
} offload_emis_print_t;
3131

32+
typedef enum {
33+
_MPI_INVALID,
34+
_MPI_Send_idx,
35+
_MPI_Recv_idx,
36+
} offload_emis_mpi_t;
37+
3238
/// The vargs function used by emissary API device stubs
3339
unsigned long long _emissary_exec(unsigned long long, ...);
3440

@@ -53,4 +59,28 @@ typedef enum {
5359
_FortranAStopStatementText_idx,
5460
} offload_emis_fortrt_idx;
5561

62+
// mpi.h (needed for MPI types) will not compile while building DeviceRTL,
63+
// So emissary stubs for MPI functions can NOT be in libomptarget.bc.
64+
// These are skipped whild building DeviceRTL because compilation of DeviceRTL
65+
// does not have include mpi.h. The user will build these stubs on their
66+
// device pass when they include EmissaryIds.h.
67+
68+
#if defined(__NVPTX__) || defined(__AMDGCN__)
69+
#if defined(__has_include)
70+
#if __has_include("mpi.h")
71+
#include "mpi.h"
72+
extern "C" int MPI_Send(const void *buf, int count, MPI_Datatype datatype,
73+
int dest, int tag, MPI_Comm comm) {
74+
return (int)_emissary_exec(_PACK_EMIS_IDS(EMIS_ID_MPI, _MPI_Send_idx), buf,
75+
count, datatype, dest, tag, comm);
76+
}
77+
extern "C" int MPI_Recv(void *buf, int count, MPI_Datatype datatype, int source,
78+
int tag, MPI_Comm comm, MPI_Status *st) {
79+
return (int)_emissary_exec(_PACK_EMIS_IDS(EMIS_ID_MPI, _MPI_Recv_idx), buf,
80+
count, datatype, source, tag, comm, st);
81+
}
82+
#endif
83+
#endif
84+
#endif
85+
5686
#endif // OFFLOAD_EMISSARY_IDS_H

offload/plugins-nextgen/common/CMakeLists.txt

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
11

22
if(OFFLOAD_ENABLE_EMISSARY_APIS)
3-
#set(emissary_include_dir ${CMAKE_CURRENT_SOURCE_DIR}/../../emissary/include)
43
set(emissary_sources
54
src/Emissary.cpp
65
src/EmissaryFortrt.cpp
76
src/EmissaryPrint.cpp
87
)
8+
set(OFFLOAD_EMISSARY_MPI_INCLUDE ""
9+
CACHE STRING "MPI include directory for building MPI Emissary API")
10+
# dont enable the MPI Emissary api unless we have an MPI include dir
11+
# that contains mpi.h needed to build the variadic wrappers.
12+
if(OFFLOAD_EMISSARY_MPI_INCLUDE)
13+
list(APPEND emissary_sources src/EmissaryMPI.cpp)
14+
endif()
915
endif()
1016

1117
# NOTE: Don't try to build `PluginInterface` using `add_llvm_library` because we
@@ -38,6 +44,9 @@ target_link_libraries(PluginCommon PRIVATE llvm-libc-common-utilities)
3844
if(OFFLOAD_ENABLE_EMISSARY_APIS)
3945
target_link_libraries(PluginCommon PRIVATE flang_rt.runtime
4046
-L${CMAKE_BINARY_DIR}/../../lib -L${CMAKE_INSTALL_PREFIX}/lib)
47+
if(OFFLOAD_EMISSARY_MPI_INCLUDE)
48+
target_include_directories(PluginCommon PUBLIC ${OFFLOAD_EMISSARY_MPI_INCLUDE})
49+
endif()
4150
endif()
4251
if(TARGET llvmlibc_rpc_server AND ${LIBOMPTARGET_GPU_LIBC_SUPPORT})
4352
target_link_libraries(PluginCommon PRIVATE llvmlibc_rpc_server)
@@ -76,7 +85,6 @@ target_include_directories(PluginCommon PUBLIC
7685
${LIBOMPTARGET_LLVM_INCLUDE_DIRS}
7786
${LIBOMPTARGET_BINARY_INCLUDE_DIR}
7887
${LIBOMPTARGET_INCLUDE_DIR}
79-
# ${emissary_include_dir}
8088
)
8189

8290
set_target_properties(PluginCommon PROPERTIES POSITION_INDEPENDENT_CODE ON)

offload/plugins-nextgen/common/include/Emissary.h

Lines changed: 195 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,41 +14,69 @@
1414
#ifndef OFFLOAD_EMISSARY_H
1515
#define OFFLOAD_EMISSARY_H
1616

17+
/// This structure is created by emisExtractArgBuf to make it easier
18+
/// to get values from the data buffer passed by rpc.
19+
typedef struct {
20+
unsigned int DataLen;
21+
unsigned int NumArgs;
22+
unsigned int emisid;
23+
unsigned int emisfnid;
24+
size_t data_not_used;
25+
char *keyptr;
26+
char *argptr;
27+
char *strptr;
28+
} emisArgBuf_t;
29+
1730
typedef unsigned long long emis_return_t;
18-
typedef uint64_t emis_uint64_t(void *, ...);
19-
typedef uint32_t emis_uint32_t(void *, ...);
31+
typedef emis_return_t emisfn_t(void *, ...);
2032

2133
// MAXVARGS is the maximum number of args in an emissary function
22-
// To increase this number be sure to updated _call_fnptr
34+
// To increase this number, update EmissaryCallFnptr below
2335
#define MAXVARGS 32
2436

2537
extern "C" {
26-
emis_return_t _emissary_execute(void *data);
27-
emis_return_t _emissary_execute_fortrt(uint32_t func_id, void *data,
28-
uint32_t sz);
29-
emis_return_t _emissary_execute_print(uint32_t func_id, void *data,
30-
uint32_t sz);
38+
39+
/// Called by rpc after receiving emissary argument buffer
40+
emis_return_t Emissary(char *data);
41+
42+
/// Called by Emissary for all Fortrt emissary functions
43+
emis_return_t EmissaryFortrt(char *data, emisArgBuf_t *ab);
44+
45+
/// Called by Emissary for all misc print functions
46+
emis_return_t EmissaryPrint(char *data, emisArgBuf_t *ab);
47+
48+
/// Called by Emissary for all MPI emissary API functions
49+
emis_return_t EmissaryMPI(char *data, emisArgBuf_t *ab);
50+
51+
/// Called by Emissary for all HDF5 Emissary API functions
52+
emis_return_t EmissaryHDF5(char *data, emisArgBuf_t *ab);
53+
54+
/// Called by Emissary to build the emisArgBuf_t structure from the emissary
55+
/// data buffer sent to the CPU by rpc. This buffer is created by clang CodeGen
56+
/// when variadic function _emissary_exec(...) is encountered when compiling
57+
// /the device stub for each emissary function.
58+
void emisExtractArgBuf(char *buf, emisArgBuf_t *ab);
3159

3260
/// Get uint32 value extended to uint64_t value from a char ptr
3361
uint64_t getuint32(char *val);
3462
/// Get uint64_t value from a char ptr
3563
uint64_t getuint64(char *val);
3664
/// Get a function pointer from a char ptr
3765
void *getfnptr(char *val);
38-
/// build argument array
39-
uint32_t _build_vargs_array(int NumArgs, char *keyptr, char *dataptr,
66+
67+
/// Builds the array of pointers passed to V_ functions
68+
uint32_t EmissaryBuildVargs(int NumArgs, char *keyptr, char *dataptr,
4069
char *strptr, size_t *data_not_used,
4170
uint64_t *a[MAXVARGS]);
4271

4372
} // end extern "C"
4473

45-
/// Make the vargs function call to the function pointer fnptr
46-
/// by casting fnptr to vfnptr. Return uint32_t return code
74+
/// Call the associated V_ function
4775
template <typename T, typename FT>
48-
extern uint32_t _call_fnptr(uint32_t NumArgs, void *fnptr,
49-
uint64_t *a[MAXVARGS], T *rv);
76+
extern T EmissaryCallFnptr(uint32_t NumArgs, void *fnptr,
77+
uint64_t *a[MAXVARGS]);
5078

51-
// Error return codes to _emissary_exec
79+
// Error return codes (deprecated)
5280
typedef enum service_rc {
5381
_RC_SUCCESS = 0,
5482
_RC_STATUS_UNKNOWN = 1,
@@ -110,4 +138,156 @@ enum TypeID {
110138
TargetExtTyID, ///< Target extension type
111139
};
112140

141+
template <typename T, typename FT>
142+
extern T EmissaryCallFnptr(uint32_t NumArgs, void *fnptr,
143+
uint64_t *a[MAXVARGS]) {
144+
T rv;
145+
FT *vfnptr = (FT *)fnptr;
146+
switch (NumArgs) {
147+
case 1:
148+
rv = (T)vfnptr(fnptr, a[0]);
149+
break;
150+
case 2:
151+
rv = (T)vfnptr(fnptr, a[0], a[1]);
152+
break;
153+
case 3:
154+
rv = (T)vfnptr(fnptr, a[0], a[1], a[2]);
155+
break;
156+
case 4:
157+
rv = (T)vfnptr(fnptr, a[0], a[1], a[2], a[3]);
158+
break;
159+
case 5:
160+
rv = (T)vfnptr(fnptr, a[0], a[1], a[2], a[3], a[4]);
161+
break;
162+
case 6:
163+
rv = (T)vfnptr(fnptr, a[0], a[1], a[2], a[3], a[4], a[5]);
164+
break;
165+
case 7:
166+
rv = (T)vfnptr(fnptr, a[0], a[1], a[2], a[3], a[4], a[5], a[6]);
167+
break;
168+
case 8:
169+
rv = (T)vfnptr(fnptr, a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7]);
170+
break;
171+
case 9:
172+
rv = (T)vfnptr(fnptr, a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8]);
173+
break;
174+
case 10:
175+
rv = (T)vfnptr(fnptr, a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8],
176+
a[9]);
177+
break;
178+
case 11:
179+
rv = (T)vfnptr(fnptr, a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8],
180+
a[9], a[10]);
181+
break;
182+
case 12:
183+
rv = (T)vfnptr(fnptr, a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8],
184+
a[9], a[10], a[11]);
185+
break;
186+
case 13:
187+
rv = (T)vfnptr(fnptr, a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8],
188+
a[9], a[10], a[11], a[12]);
189+
break;
190+
case 14:
191+
rv = (T)vfnptr(fnptr, a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8],
192+
a[9], a[10], a[11], a[12], a[13]);
193+
break;
194+
case 15:
195+
rv = (T)vfnptr(fnptr, a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8],
196+
a[9], a[10], a[11], a[12], a[13], a[14]);
197+
break;
198+
case 16:
199+
rv = (T)vfnptr(fnptr, a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8],
200+
a[9], a[10], a[11], a[12], a[13], a[14], a[15]);
201+
break;
202+
case 17:
203+
rv = (T)vfnptr(fnptr, a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8],
204+
a[9], a[10], a[11], a[12], a[13], a[14], a[15], a[16]);
205+
break;
206+
case 18:
207+
rv =
208+
(T)vfnptr(fnptr, a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8],
209+
a[9], a[10], a[11], a[12], a[13], a[14], a[15], a[16], a[17]);
210+
break;
211+
case 19:
212+
rv = (T)vfnptr(fnptr, a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8],
213+
a[9], a[10], a[11], a[12], a[13], a[14], a[15], a[16], a[17],
214+
a[18]);
215+
break;
216+
case 20:
217+
rv = (T)vfnptr(fnptr, a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8],
218+
a[9], a[10], a[11], a[12], a[13], a[14], a[15], a[16], a[17],
219+
a[18], a[19]);
220+
break;
221+
case 21:
222+
rv = (T)vfnptr(fnptr, a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8],
223+
a[9], a[10], a[11], a[12], a[13], a[14], a[15], a[16], a[17],
224+
a[18], a[19], a[20]);
225+
break;
226+
case 22:
227+
rv = (T)vfnptr(fnptr, a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8],
228+
a[9], a[10], a[11], a[12], a[13], a[14], a[15], a[16], a[17],
229+
a[18], a[19], a[20], a[21]);
230+
break;
231+
case 23:
232+
rv = (T)vfnptr(fnptr, a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8],
233+
a[9], a[10], a[11], a[12], a[13], a[14], a[15], a[16], a[17],
234+
a[18], a[19], a[20], a[21], a[22]);
235+
break;
236+
case 24:
237+
rv = (T)vfnptr(fnptr, a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8],
238+
a[9], a[10], a[11], a[12], a[13], a[14], a[15], a[16], a[17],
239+
a[18], a[19], a[20], a[21], a[22], a[23]);
240+
break;
241+
case 25:
242+
rv = (T)vfnptr(fnptr, a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8],
243+
a[9], a[10], a[11], a[12], a[13], a[14], a[15], a[16], a[17],
244+
a[18], a[19], a[20], a[21], a[22], a[23], a[24]);
245+
break;
246+
case 26:
247+
rv = (T)vfnptr(fnptr, a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8],
248+
a[9], a[10], a[11], a[12], a[13], a[14], a[15], a[16], a[17],
249+
a[18], a[19], a[20], a[21], a[22], a[23], a[24], a[25]);
250+
break;
251+
case 27:
252+
rv = (T)vfnptr(fnptr, a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8],
253+
a[9], a[10], a[11], a[12], a[13], a[14], a[15], a[16], a[17],
254+
a[18], a[19], a[20], a[21], a[22], a[23], a[24], a[25],
255+
a[26]);
256+
break;
257+
case 28:
258+
rv = (T)vfnptr(fnptr, a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8],
259+
a[9], a[10], a[11], a[12], a[13], a[14], a[15], a[16], a[17],
260+
a[18], a[19], a[20], a[21], a[22], a[23], a[24], a[25],
261+
a[26], a[27]);
262+
break;
263+
case 29:
264+
rv = (T)vfnptr(fnptr, a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8],
265+
a[9], a[10], a[11], a[12], a[13], a[14], a[15], a[16], a[17],
266+
a[18], a[19], a[20], a[21], a[22], a[23], a[24], a[25],
267+
a[26], a[27], a[28]);
268+
break;
269+
case 30:
270+
rv = (T)vfnptr(fnptr, a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8],
271+
a[9], a[10], a[11], a[12], a[13], a[14], a[15], a[16], a[17],
272+
a[18], a[19], a[20], a[21], a[22], a[23], a[24], a[25],
273+
a[26], a[27], a[28], a[29]);
274+
break;
275+
case 31:
276+
rv = (T)vfnptr(fnptr, a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8],
277+
a[9], a[10], a[11], a[12], a[13], a[14], a[15], a[16], a[17],
278+
a[18], a[19], a[20], a[21], a[22], a[23], a[24], a[25],
279+
a[26], a[27], a[28], a[29], a[30]);
280+
break;
281+
case 32:
282+
rv = (T)vfnptr(fnptr, a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8],
283+
a[9], a[10], a[11], a[12], a[13], a[14], a[15], a[16], a[17],
284+
a[18], a[19], a[20], a[21], a[22], a[23], a[24], a[25],
285+
a[26], a[27], a[28], a[29], a[30], a[31]);
286+
break;
287+
default:
288+
rv = 0;
289+
}
290+
return rv;
291+
}
292+
113293
#endif // OFFLOAD_EMISSARY_H

0 commit comments

Comments
 (0)