Skip to content

Commit a1a8bb1

Browse files
authored
[libc] Change RPC interface to not use device ids (#87087)
Summary: The current implementation of RPC tied everything to device IDs and forced us to do init / shutdown to manage some global state. This turned out to be a bad idea in situations where we want to track multiple heterogeneous devices that may report the same device ID in the same process. This patch changes the interface to instead create an opaque handle to the internal device and simply allocates it via `new`. The user will then take this device and store it to interface with the attached device. This interface puts the burden of tracking the device identifier to mapped devices onto the user, but in return heavily simplifies the implementation.
1 parent bdb60e6 commit a1a8bb1

File tree

8 files changed

+107
-181
lines changed

8 files changed

+107
-181
lines changed

libc/utils/gpu/loader/Loader.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,11 +108,11 @@ inline void handle_error(rpc_status_t) {
108108
}
109109

110110
template <uint32_t lane_size>
111-
inline void register_rpc_callbacks(uint32_t device_id) {
111+
inline void register_rpc_callbacks(rpc_device_t device) {
112112
static_assert(lane_size == 32 || lane_size == 64, "Invalid Lane size");
113113
// Register the ping test for the `libc` tests.
114114
rpc_register_callback(
115-
device_id, static_cast<rpc_opcode_t>(RPC_TEST_INCREMENT),
115+
device, static_cast<rpc_opcode_t>(RPC_TEST_INCREMENT),
116116
[](rpc_port_t port, void *data) {
117117
rpc_recv_and_send(
118118
port,
@@ -125,7 +125,7 @@ inline void register_rpc_callbacks(uint32_t device_id) {
125125

126126
// Register the interface test callbacks.
127127
rpc_register_callback(
128-
device_id, static_cast<rpc_opcode_t>(RPC_TEST_INTERFACE),
128+
device, static_cast<rpc_opcode_t>(RPC_TEST_INTERFACE),
129129
[](rpc_port_t port, void *data) {
130130
uint64_t cnt = 0;
131131
bool end_with_recv;
@@ -207,7 +207,7 @@ inline void register_rpc_callbacks(uint32_t device_id) {
207207

208208
// Register the stream test handler.
209209
rpc_register_callback(
210-
device_id, static_cast<rpc_opcode_t>(RPC_TEST_STREAM),
210+
device, static_cast<rpc_opcode_t>(RPC_TEST_STREAM),
211211
[](rpc_port_t port, void *data) {
212212
uint64_t sizes[lane_size] = {0};
213213
void *dst[lane_size] = {nullptr};

libc/utils/gpu/loader/amdgpu/Loader.cpp

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,8 @@ template <typename args_t>
153153
hsa_status_t launch_kernel(hsa_agent_t dev_agent, hsa_executable_t executable,
154154
hsa_amd_memory_pool_t kernargs_pool,
155155
hsa_amd_memory_pool_t coarsegrained_pool,
156-
hsa_queue_t *queue, const LaunchParameters &params,
156+
hsa_queue_t *queue, rpc_device_t device,
157+
const LaunchParameters &params,
157158
const char *kernel_name, args_t kernel_args) {
158159
// Look up the '_start' kernel in the loaded executable.
159160
hsa_executable_symbol_t symbol;
@@ -162,10 +163,9 @@ hsa_status_t launch_kernel(hsa_agent_t dev_agent, hsa_executable_t executable,
162163
return err;
163164

164165
// Register RPC callbacks for the malloc and free functions on HSA.
165-
uint32_t device_id = 0;
166166
auto tuple = std::make_tuple(dev_agent, coarsegrained_pool);
167167
rpc_register_callback(
168-
device_id, RPC_MALLOC,
168+
device, RPC_MALLOC,
169169
[](rpc_port_t port, void *data) {
170170
auto malloc_handler = [](rpc_buffer_t *buffer, void *data) -> void {
171171
auto &[dev_agent, pool] = *static_cast<decltype(tuple) *>(data);
@@ -182,7 +182,7 @@ hsa_status_t launch_kernel(hsa_agent_t dev_agent, hsa_executable_t executable,
182182
},
183183
&tuple);
184184
rpc_register_callback(
185-
device_id, RPC_FREE,
185+
device, RPC_FREE,
186186
[](rpc_port_t port, void *data) {
187187
auto free_handler = [](rpc_buffer_t *buffer, void *) {
188188
if (hsa_status_t err = hsa_amd_memory_pool_free(
@@ -284,12 +284,12 @@ hsa_status_t launch_kernel(hsa_agent_t dev_agent, hsa_executable_t executable,
284284
while (hsa_signal_wait_scacquire(
285285
packet->completion_signal, HSA_SIGNAL_CONDITION_EQ, 0,
286286
/*timeout_hint=*/1024, HSA_WAIT_STATE_ACTIVE) != 0)
287-
if (rpc_status_t err = rpc_handle_server(device_id))
287+
if (rpc_status_t err = rpc_handle_server(device))
288288
handle_error(err);
289289

290290
// Handle the server one more time in case the kernel exited with a pending
291291
// send still in flight.
292-
if (rpc_status_t err = rpc_handle_server(device_id))
292+
if (rpc_status_t err = rpc_handle_server(device))
293293
handle_error(err);
294294

295295
// Destroy the resources acquired to launch the kernel and return.
@@ -342,8 +342,6 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
342342
handle_error(err);
343343

344344
// Obtain a single agent for the device and host to use the HSA memory model.
345-
uint32_t num_devices = 1;
346-
uint32_t device_id = 0;
347345
hsa_agent_t dev_agent;
348346
hsa_agent_t host_agent;
349347
if (hsa_status_t err = get_agent<HSA_DEVICE_TYPE_GPU>(&dev_agent))
@@ -433,8 +431,6 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
433431
handle_error(err);
434432

435433
// Set up the RPC server.
436-
if (rpc_status_t err = rpc_init(num_devices))
437-
handle_error(err);
438434
auto tuple = std::make_tuple(dev_agent, finegrained_pool);
439435
auto rpc_alloc = [](uint64_t size, void *data) {
440436
auto &[dev_agent, finegrained_pool] = *static_cast<decltype(tuple) *>(data);
@@ -445,15 +441,16 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
445441
hsa_amd_agents_allow_access(1, &dev_agent, nullptr, dev_ptr);
446442
return dev_ptr;
447443
};
448-
if (rpc_status_t err = rpc_server_init(device_id, RPC_MAXIMUM_PORT_COUNT,
444+
rpc_device_t device;
445+
if (rpc_status_t err = rpc_server_init(&device, RPC_MAXIMUM_PORT_COUNT,
449446
wavefront_size, rpc_alloc, &tuple))
450447
handle_error(err);
451448

452449
// Register callbacks for the RPC unit tests.
453450
if (wavefront_size == 32)
454-
register_rpc_callbacks<32>(device_id);
451+
register_rpc_callbacks<32>(device);
455452
else if (wavefront_size == 64)
456-
register_rpc_callbacks<64>(device_id);
453+
register_rpc_callbacks<64>(device);
457454
else
458455
handle_error("Invalid wavefront size");
459456

@@ -483,10 +480,10 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
483480
handle_error(err);
484481

485482
void *rpc_client_buffer;
486-
if (hsa_status_t err = hsa_amd_memory_lock(
487-
const_cast<void *>(rpc_get_client_buffer(device_id)),
488-
rpc_get_client_size(),
489-
/*agents=*/nullptr, 0, &rpc_client_buffer))
483+
if (hsa_status_t err =
484+
hsa_amd_memory_lock(const_cast<void *>(rpc_get_client_buffer(device)),
485+
rpc_get_client_size(),
486+
/*agents=*/nullptr, 0, &rpc_client_buffer))
490487
handle_error(err);
491488

492489
// Copy the RPC client buffer to the address pointed to by the symbol.
@@ -496,7 +493,7 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
496493
handle_error(err);
497494

498495
if (hsa_status_t err = hsa_amd_memory_unlock(
499-
const_cast<void *>(rpc_get_client_buffer(device_id))))
496+
const_cast<void *>(rpc_get_client_buffer(device))))
500497
handle_error(err);
501498
if (hsa_status_t err = hsa_amd_memory_pool_free(rpc_client_host))
502499
handle_error(err);
@@ -549,13 +546,13 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
549546
begin_args_t init_args = {argc, dev_argv, dev_envp};
550547
if (hsa_status_t err = launch_kernel(
551548
dev_agent, executable, kernargs_pool, coarsegrained_pool, queue,
552-
single_threaded_params, "_begin.kd", init_args))
549+
device, single_threaded_params, "_begin.kd", init_args))
553550
handle_error(err);
554551

555552
start_args_t args = {argc, dev_argv, dev_envp, dev_ret};
556-
if (hsa_status_t err =
557-
launch_kernel(dev_agent, executable, kernargs_pool,
558-
coarsegrained_pool, queue, params, "_start.kd", args))
553+
if (hsa_status_t err = launch_kernel(dev_agent, executable, kernargs_pool,
554+
coarsegrained_pool, queue, device,
555+
params, "_start.kd", args))
559556
handle_error(err);
560557

561558
void *host_ret;
@@ -575,11 +572,11 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
575572
end_args_t fini_args = {ret};
576573
if (hsa_status_t err = launch_kernel(
577574
dev_agent, executable, kernargs_pool, coarsegrained_pool, queue,
578-
single_threaded_params, "_end.kd", fini_args))
575+
device, single_threaded_params, "_end.kd", fini_args))
579576
handle_error(err);
580577

581578
if (rpc_status_t err = rpc_server_shutdown(
582-
device_id, [](void *ptr, void *) { hsa_amd_memory_pool_free(ptr); },
579+
device, [](void *ptr, void *) { hsa_amd_memory_pool_free(ptr); },
583580
nullptr))
584581
handle_error(err);
585582

@@ -600,8 +597,6 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
600597
if (hsa_status_t err = hsa_code_object_destroy(object))
601598
handle_error(err);
602599

603-
if (rpc_status_t err = rpc_shutdown())
604-
handle_error(err);
605600
if (hsa_status_t err = hsa_shut_down())
606601
handle_error(err);
607602

libc/utils/gpu/loader/nvptx/Loader.cpp

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,8 @@ Expected<void *> get_ctor_dtor_array(const void *image, const size_t size,
154154

155155
template <typename args_t>
156156
CUresult launch_kernel(CUmodule binary, CUstream stream,
157-
const LaunchParameters &params, const char *kernel_name,
158-
args_t kernel_args) {
157+
rpc_device_t rpc_device, const LaunchParameters &params,
158+
const char *kernel_name, args_t kernel_args) {
159159
// look up the '_start' kernel in the loaded module.
160160
CUfunction function;
161161
if (CUresult err = cuModuleGetFunction(&function, binary, kernel_name))
@@ -175,11 +175,10 @@ CUresult launch_kernel(CUmodule binary, CUstream stream,
175175
handle_error(err);
176176

177177
// Register RPC callbacks for the malloc and free functions on HSA.
178-
uint32_t device_id = 0;
179-
register_rpc_callbacks<32>(device_id);
178+
register_rpc_callbacks<32>(rpc_device);
180179

181180
rpc_register_callback(
182-
device_id, RPC_MALLOC,
181+
rpc_device, RPC_MALLOC,
183182
[](rpc_port_t port, void *data) {
184183
auto malloc_handler = [](rpc_buffer_t *buffer, void *data) -> void {
185184
CUstream memory_stream = *static_cast<CUstream *>(data);
@@ -197,7 +196,7 @@ CUresult launch_kernel(CUmodule binary, CUstream stream,
197196
},
198197
&memory_stream);
199198
rpc_register_callback(
200-
device_id, RPC_FREE,
199+
rpc_device, RPC_FREE,
201200
[](rpc_port_t port, void *data) {
202201
auto free_handler = [](rpc_buffer_t *buffer, void *data) {
203202
CUstream memory_stream = *static_cast<CUstream *>(data);
@@ -219,12 +218,12 @@ CUresult launch_kernel(CUmodule binary, CUstream stream,
219218
// Wait until the kernel has completed execution on the device. Periodically
220219
// check the RPC client for work to be performed on the server.
221220
while (cuStreamQuery(stream) == CUDA_ERROR_NOT_READY)
222-
if (rpc_status_t err = rpc_handle_server(device_id))
221+
if (rpc_status_t err = rpc_handle_server(rpc_device))
223222
handle_error(err);
224223

225224
// Handle the server one more time in case the kernel exited with a pending
226225
// send still in flight.
227-
if (rpc_status_t err = rpc_handle_server(device_id))
226+
if (rpc_status_t err = rpc_handle_server(rpc_device))
228227
handle_error(err);
229228

230229
return CUDA_SUCCESS;
@@ -235,7 +234,6 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
235234
if (CUresult err = cuInit(0))
236235
handle_error(err);
237236
// Obtain the first device found on the system.
238-
uint32_t num_devices = 1;
239237
uint32_t device_id = 0;
240238
CUdevice device;
241239
if (CUresult err = cuDeviceGet(&device, device_id))
@@ -294,17 +292,15 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
294292
if (CUresult err = cuMemsetD32(dev_ret, 0, 1))
295293
handle_error(err);
296294

297-
if (rpc_status_t err = rpc_init(num_devices))
298-
handle_error(err);
299-
300295
uint32_t warp_size = 32;
301296
auto rpc_alloc = [](uint64_t size, void *) -> void * {
302297
void *dev_ptr;
303298
if (CUresult err = cuMemAllocHost(&dev_ptr, size))
304299
handle_error(err);
305300
return dev_ptr;
306301
};
307-
if (rpc_status_t err = rpc_server_init(device_id, RPC_MAXIMUM_PORT_COUNT,
302+
rpc_device_t rpc_device;
303+
if (rpc_status_t err = rpc_server_init(&rpc_device, RPC_MAXIMUM_PORT_COUNT,
308304
warp_size, rpc_alloc, nullptr))
309305
handle_error(err);
310306

@@ -321,19 +317,20 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
321317
cuMemcpyDtoH(&rpc_client_host, rpc_client_dev, sizeof(void *)))
322318
handle_error(err);
323319
if (CUresult err =
324-
cuMemcpyHtoD(rpc_client_host, rpc_get_client_buffer(device_id),
320+
cuMemcpyHtoD(rpc_client_host, rpc_get_client_buffer(rpc_device),
325321
rpc_get_client_size()))
326322
handle_error(err);
327323

328324
LaunchParameters single_threaded_params = {1, 1, 1, 1, 1, 1};
329325
begin_args_t init_args = {argc, dev_argv, dev_envp};
330-
if (CUresult err = launch_kernel(binary, stream, single_threaded_params,
331-
"_begin", init_args))
326+
if (CUresult err = launch_kernel(binary, stream, rpc_device,
327+
single_threaded_params, "_begin", init_args))
332328
handle_error(err);
333329

334330
start_args_t args = {argc, dev_argv, dev_envp,
335331
reinterpret_cast<void *>(dev_ret)};
336-
if (CUresult err = launch_kernel(binary, stream, params, "_start", args))
332+
if (CUresult err =
333+
launch_kernel(binary, stream, rpc_device, params, "_start", args))
337334
handle_error(err);
338335

339336
// Copy the return value back from the kernel and wait.
@@ -345,8 +342,8 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
345342
handle_error(err);
346343

347344
end_args_t fini_args = {host_ret};
348-
if (CUresult err = launch_kernel(binary, stream, single_threaded_params,
349-
"_end", fini_args))
345+
if (CUresult err = launch_kernel(binary, stream, rpc_device,
346+
single_threaded_params, "_end", fini_args))
350347
handle_error(err);
351348

352349
// Free the memory allocated for the device.
@@ -357,15 +354,13 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
357354
if (CUresult err = cuMemFreeHost(dev_argv))
358355
handle_error(err);
359356
if (rpc_status_t err = rpc_server_shutdown(
360-
device_id, [](void *ptr, void *) { cuMemFreeHost(ptr); }, nullptr))
357+
rpc_device, [](void *ptr, void *) { cuMemFreeHost(ptr); }, nullptr))
361358
handle_error(err);
362359

363360
// Destroy the context and the loaded binary.
364361
if (CUresult err = cuModuleUnload(binary))
365362
handle_error(err);
366363
if (CUresult err = cuDevicePrimaryCtxRelease(device))
367364
handle_error(err);
368-
if (rpc_status_t err = rpc_shutdown())
369-
handle_error(err);
370365
return host_ret;
371366
}

libc/utils/gpu/server/llvmlibc_rpc_server.h

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,8 @@ typedef enum {
2727
RPC_STATUS_SUCCESS = 0x0,
2828
RPC_STATUS_CONTINUE = 0x1,
2929
RPC_STATUS_ERROR = 0x1000,
30-
RPC_STATUS_OUT_OF_RANGE = 0x1001,
31-
RPC_STATUS_UNHANDLED_OPCODE = 0x1002,
32-
RPC_STATUS_INVALID_LANE_SIZE = 0x1003,
33-
RPC_STATUS_NOT_INITIALIZED = 0x1004,
30+
RPC_STATUS_UNHANDLED_OPCODE = 0x1001,
31+
RPC_STATUS_INVALID_LANE_SIZE = 0x1002,
3432
} rpc_status_t;
3533

3634
/// A struct containing an opaque handle to an RPC port. This is what allows the
@@ -45,6 +43,11 @@ typedef struct rpc_buffer_s {
4543
uint64_t data[8];
4644
} rpc_buffer_t;
4745

46+
/// An opaque handle to an RPC server that can be attached to a device.
47+
typedef struct rpc_device_s {
48+
uintptr_t handle;
49+
} rpc_device_t;
50+
4851
/// A function used to allocate \p bytes for use by the RPC server and client.
4952
/// The memory should support asynchronous and atomic access from both the
5053
/// client and server.
@@ -60,34 +63,28 @@ typedef void (*rpc_opcode_callback_ty)(rpc_port_t port, void *data);
6063
/// A callback function to use the port to receive or send a \p buffer.
6164
typedef void (*rpc_port_callback_ty)(rpc_buffer_t *buffer, void *data);
6265

63-
/// Initialize the rpc library for general use on \p num_devices.
64-
rpc_status_t rpc_init(uint32_t num_devices);
65-
66-
/// Shut down the rpc interface.
67-
rpc_status_t rpc_shutdown(void);
68-
69-
/// Initialize the server for a given device.
70-
rpc_status_t rpc_server_init(uint32_t device_id, uint64_t num_ports,
66+
/// Initialize the server for a given device and return it in \p device.
67+
rpc_status_t rpc_server_init(rpc_device_t *rpc_device, uint64_t num_ports,
7168
uint32_t lane_size, rpc_alloc_ty alloc,
7269
void *data);
7370

7471
/// Shut down the server for a given device.
75-
rpc_status_t rpc_server_shutdown(uint32_t device_id, rpc_free_ty dealloc,
72+
rpc_status_t rpc_server_shutdown(rpc_device_t rpc_device, rpc_free_ty dealloc,
7673
void *data);
7774

7875
/// Queries the RPC clients at least once and performs server-side work if there
7976
/// are any active requests. Runs until all work on the server is completed.
80-
rpc_status_t rpc_handle_server(uint32_t device_id);
77+
rpc_status_t rpc_handle_server(rpc_device_t rpc_device);
8178

8279
/// Register a callback to handle an opcode from the RPC client. The associated
8380
/// data must remain accessible as long as the user intends to handle the server
8481
/// with this callback.
85-
rpc_status_t rpc_register_callback(uint32_t device_id, uint16_t opcode,
82+
rpc_status_t rpc_register_callback(rpc_device_t rpc_device, uint16_t opcode,
8683
rpc_opcode_callback_ty callback, void *data);
8784

8885
/// Obtain a pointer to a local client buffer that can be copied directly to the
8986
/// other process using the address stored at the rpc client symbol name.
90-
const void *rpc_get_client_buffer(uint32_t device_id);
87+
const void *rpc_get_client_buffer(rpc_device_t device);
9188

9289
/// Returns the size of the client in bytes to be used for a memory copy.
9390
uint64_t rpc_get_client_size();

0 commit comments

Comments
 (0)