@@ -153,7 +153,8 @@ template <typename args_t>
153
153
hsa_status_t launch_kernel (hsa_agent_t dev_agent, hsa_executable_t executable,
154
154
hsa_amd_memory_pool_t kernargs_pool,
155
155
hsa_amd_memory_pool_t coarsegrained_pool,
156
- hsa_queue_t *queue, const LaunchParameters ¶ms,
156
+ hsa_queue_t *queue, rpc_device_t device,
157
+ const LaunchParameters ¶ms,
157
158
const char *kernel_name, args_t kernel_args) {
158
159
// Look up the '_start' kernel in the loaded executable.
159
160
hsa_executable_symbol_t symbol;
@@ -162,10 +163,9 @@ hsa_status_t launch_kernel(hsa_agent_t dev_agent, hsa_executable_t executable,
162
163
return err;
163
164
164
165
// Register RPC callbacks for the malloc and free functions on HSA.
165
- uint32_t device_id = 0 ;
166
166
auto tuple = std::make_tuple (dev_agent, coarsegrained_pool);
167
167
rpc_register_callback (
168
- device_id , RPC_MALLOC,
168
+ device , RPC_MALLOC,
169
169
[](rpc_port_t port, void *data) {
170
170
auto malloc_handler = [](rpc_buffer_t *buffer, void *data) -> void {
171
171
auto &[dev_agent, pool] = *static_cast <decltype (tuple) *>(data);
@@ -182,7 +182,7 @@ hsa_status_t launch_kernel(hsa_agent_t dev_agent, hsa_executable_t executable,
182
182
},
183
183
&tuple);
184
184
rpc_register_callback (
185
- device_id , RPC_FREE,
185
+ device , RPC_FREE,
186
186
[](rpc_port_t port, void *data) {
187
187
auto free_handler = [](rpc_buffer_t *buffer, void *) {
188
188
if (hsa_status_t err = hsa_amd_memory_pool_free (
@@ -284,12 +284,12 @@ hsa_status_t launch_kernel(hsa_agent_t dev_agent, hsa_executable_t executable,
284
284
while (hsa_signal_wait_scacquire (
285
285
packet->completion_signal , HSA_SIGNAL_CONDITION_EQ, 0 ,
286
286
/* timeout_hint=*/ 1024 , HSA_WAIT_STATE_ACTIVE) != 0 )
287
- if (rpc_status_t err = rpc_handle_server (device_id ))
287
+ if (rpc_status_t err = rpc_handle_server (device ))
288
288
handle_error (err);
289
289
290
290
// Handle the server one more time in case the kernel exited with a pending
291
291
// send still in flight.
292
- if (rpc_status_t err = rpc_handle_server (device_id ))
292
+ if (rpc_status_t err = rpc_handle_server (device ))
293
293
handle_error (err);
294
294
295
295
// Destroy the resources acquired to launch the kernel and return.
@@ -342,8 +342,6 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
342
342
handle_error (err);
343
343
344
344
// Obtain a single agent for the device and host to use the HSA memory model.
345
- uint32_t num_devices = 1 ;
346
- uint32_t device_id = 0 ;
347
345
hsa_agent_t dev_agent;
348
346
hsa_agent_t host_agent;
349
347
if (hsa_status_t err = get_agent<HSA_DEVICE_TYPE_GPU>(&dev_agent))
@@ -433,8 +431,6 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
433
431
handle_error (err);
434
432
435
433
// Set up the RPC server.
436
- if (rpc_status_t err = rpc_init (num_devices))
437
- handle_error (err);
438
434
auto tuple = std::make_tuple (dev_agent, finegrained_pool);
439
435
auto rpc_alloc = [](uint64_t size, void *data) {
440
436
auto &[dev_agent, finegrained_pool] = *static_cast <decltype (tuple) *>(data);
@@ -445,15 +441,16 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
445
441
hsa_amd_agents_allow_access (1 , &dev_agent, nullptr , dev_ptr);
446
442
return dev_ptr;
447
443
};
448
- if (rpc_status_t err = rpc_server_init (device_id, RPC_MAXIMUM_PORT_COUNT,
444
+ rpc_device_t device;
445
+ if (rpc_status_t err = rpc_server_init (&device, RPC_MAXIMUM_PORT_COUNT,
449
446
wavefront_size, rpc_alloc, &tuple))
450
447
handle_error (err);
451
448
452
449
// Register callbacks for the RPC unit tests.
453
450
if (wavefront_size == 32 )
454
- register_rpc_callbacks<32 >(device_id );
451
+ register_rpc_callbacks<32 >(device );
455
452
else if (wavefront_size == 64 )
456
- register_rpc_callbacks<64 >(device_id );
453
+ register_rpc_callbacks<64 >(device );
457
454
else
458
455
handle_error (" Invalid wavefront size" );
459
456
@@ -483,10 +480,10 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
483
480
handle_error (err);
484
481
485
482
void *rpc_client_buffer;
486
- if (hsa_status_t err = hsa_amd_memory_lock (
487
- const_cast <void *>(rpc_get_client_buffer (device_id )),
488
- rpc_get_client_size (),
489
- /* agents=*/ nullptr , 0 , &rpc_client_buffer))
483
+ if (hsa_status_t err =
484
+ hsa_amd_memory_lock ( const_cast <void *>(rpc_get_client_buffer (device )),
485
+ rpc_get_client_size (),
486
+ /* agents=*/ nullptr , 0 , &rpc_client_buffer))
490
487
handle_error (err);
491
488
492
489
// Copy the RPC client buffer to the address pointed to by the symbol.
@@ -496,7 +493,7 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
496
493
handle_error (err);
497
494
498
495
if (hsa_status_t err = hsa_amd_memory_unlock (
499
- const_cast <void *>(rpc_get_client_buffer (device_id ))))
496
+ const_cast <void *>(rpc_get_client_buffer (device ))))
500
497
handle_error (err);
501
498
if (hsa_status_t err = hsa_amd_memory_pool_free (rpc_client_host))
502
499
handle_error (err);
@@ -549,13 +546,13 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
549
546
begin_args_t init_args = {argc, dev_argv, dev_envp};
550
547
if (hsa_status_t err = launch_kernel (
551
548
dev_agent, executable, kernargs_pool, coarsegrained_pool, queue,
552
- single_threaded_params, " _begin.kd" , init_args))
549
+ device, single_threaded_params, " _begin.kd" , init_args))
553
550
handle_error (err);
554
551
555
552
start_args_t args = {argc, dev_argv, dev_envp, dev_ret};
556
- if (hsa_status_t err =
557
- launch_kernel (dev_agent, executable, kernargs_pool ,
558
- coarsegrained_pool, queue, params, " _start.kd" , args))
553
+ if (hsa_status_t err = launch_kernel (dev_agent, executable, kernargs_pool,
554
+ coarsegrained_pool, queue, device ,
555
+ params, " _start.kd" , args))
559
556
handle_error (err);
560
557
561
558
void *host_ret;
@@ -575,11 +572,11 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
575
572
end_args_t fini_args = {ret};
576
573
if (hsa_status_t err = launch_kernel (
577
574
dev_agent, executable, kernargs_pool, coarsegrained_pool, queue,
578
- single_threaded_params, " _end.kd" , fini_args))
575
+ device, single_threaded_params, " _end.kd" , fini_args))
579
576
handle_error (err);
580
577
581
578
if (rpc_status_t err = rpc_server_shutdown (
582
- device_id , [](void *ptr, void *) { hsa_amd_memory_pool_free (ptr); },
579
+ device , [](void *ptr, void *) { hsa_amd_memory_pool_free (ptr); },
583
580
nullptr ))
584
581
handle_error (err);
585
582
@@ -600,8 +597,6 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
600
597
if (hsa_status_t err = hsa_code_object_destroy (object))
601
598
handle_error (err);
602
599
603
- if (rpc_status_t err = rpc_shutdown ())
604
- handle_error (err);
605
600
if (hsa_status_t err = hsa_shut_down ())
606
601
handle_error (err);
607
602
0 commit comments