Skip to content

Commit e537c83

Browse files
committed
[libc] Add basic support for calling host functions from the GPU
This patch adds the `rpc_host_call` function as a GPU extension. This is exported from the `libc` project to use the RPC interface to call a function pointer via RPC and copying the arguments by-value. The interface can only support a single void pointer argument much like pthreads. The function call here is the bare-bones version of what's required for OpenMP reverse offloading. Full support will require interfacing with the mapping table, nowait support, etc. I decided to test this interface in `libomptarget` as that will be the primary consumer and it would be more difficult to make a test in `libc` due to the testing infrastructure not really having a concept of the "host" as it runs directly on the GPU as if it were a CPU target. Reviewed By: jplehr Differential Revision: https://reviews.llvm.org/D155003
1 parent 68cd1db commit e537c83

File tree

8 files changed

+138
-4
lines changed

8 files changed

+138
-4
lines changed

libc/config/gpu/entrypoints.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ set(TARGET_LIBC_ENTRYPOINTS
9797

9898
# gpu/rpc.h entrypoints
9999
libc.src.gpu.rpc_reset
100+
libc.src.gpu.rpc_host_call
100101
)
101102

102103
set(TARGET_LIBM_ENTRYPOINTS

libc/include/llvm-libc-types/rpc_opcodes_t.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ typedef enum : unsigned short {
1919
RPC_CLOSE_FILE = 6,
2020
RPC_MALLOC = 7,
2121
RPC_FREE = 8,
22+
RPC_HOST_CALL = 9,
2223
// TODO: Move these out of here and handle them with custom handlers in the
2324
// loader.
2425
RPC_TEST_INCREMENT = 1000,

libc/spec/gpu_ext.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@ def GPUExtensions : StandardSpec<"GPUExtensions"> {
1010
RetValSpec<VoidType>,
1111
[ArgSpec<UnsignedIntType>, ArgSpec<VoidPtr>]
1212
>,
13+
FunctionSpec<
14+
"rpc_host_call",
15+
RetValSpec<VoidType>,
16+
[ArgSpec<VoidPtr>, ArgSpec<VoidPtr>, ArgSpec<SizeTType>]
17+
>,
1318
]
1419
>;
1520
let Headers = [

libc/src/gpu/CMakeLists.txt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,14 @@ add_entrypoint_object(
88
libc.src.__support.RPC.rpc_client
99
libc.src.__support.GPU.utils
1010
)
11+
12+
add_entrypoint_object(
13+
rpc_host_call
14+
SRCS
15+
rpc_host_call.cpp
16+
HDRS
17+
rpc_host_call.h
18+
DEPENDS
19+
libc.src.__support.RPC.rpc_client
20+
libc.src.__support.GPU.utils
21+
)

libc/src/gpu/rpc_host_call.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
//===---------- GPU implementation of the external RPC call function ------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "src/gpu/rpc_host_call.h"
10+
11+
#include "llvm-libc-types/rpc_opcodes_t.h"
12+
#include "src/__support/GPU/utils.h"
13+
#include "src/__support/RPC/rpc_client.h"
14+
#include "src/__support/common.h"
15+
16+
namespace __llvm_libc {
17+
18+
// Calls the function pointer `fn` on the RPC server, handing it a copy of the
19+
// `size` bytes at `data`. `fn` must be a valid function pointer on the server.
20+
LLVM_LIBC_FUNCTION(void, rpc_host_call, (void *fn, void *data, size_t size)) {
21+
// Open a port for the RPC_HOST_CALL opcode.
rpc::Client::Port port = rpc::client.open<RPC_HOST_CALL>();
22+
// Send the argument payload first; the server handler receives it into a
// freshly allocated buffer before it learns which function to call.
port.send_n(data, size);
23+
// Send the function pointer, encoded as an integer in the first data slot.
port.send([=](rpc::Buffer *buffer) {
24+
buffer->data[0] = reinterpret_cast<uintptr_t>(fn);
25+
});
26+
// Receive the server's reply and ignore its contents. The RPC_HOST_CALL
// handler in rpc_server.cpp replies only after it has invoked the function.
port.recv([](rpc::Buffer *) {});
27+
port.close();
28+
}
29+
30+
} // namespace __llvm_libc

libc/src/gpu/rpc_host_call.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
//===-- Implementation header for RPC functions -----------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef LLVM_LIBC_SRC_GPU_RPC_HOST_CALL_H
10+
#define LLVM_LIBC_SRC_GPU_RPC_HOST_CALL_H
11+
12+
#include <stddef.h> // size_t
13+
14+
namespace __llvm_libc {
15+
16+
// Calls the function pointer `fn` on the host through the RPC interface. The
// `size` bytes at `buffer` are sent along with it and passed to `fn` as its
// single pointer argument.
void rpc_host_call(void *fn, void *buffer, size_t size);
17+
18+
} // namespace __llvm_libc
19+
20+
#endif // LLVM_LIBC_SRC_GPU_RPC_HOST_CALL_H

libc/utils/gpu/server/rpc_server.cpp

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,18 @@ struct Server {
129129
});
130130
break;
131131
}
132+
// Handles a host call request: receives the argument payload, invokes the
// function pointer sent by the client on it, then replies and frees the payload.
case RPC_HOST_CALL: {
133+
// One slot per lane, indexed by the `id` passed to the callbacks below.
uint64_t sizes[rpc::MAX_LANE_SIZE] = {0};
134+
void *args[rpc::MAX_LANE_SIZE] = {nullptr};
135+
// Receive the by-value argument data into newly allocated char arrays.
port->recv_n(args, sizes, [&](uint64_t size) { return new char[size]; });
136+
// The first data slot holds the function pointer; call it on the payload.
port->recv([&](rpc::Buffer *buffer, uint32_t id) {
137+
reinterpret_cast<void (*)(void *)>(buffer->data[0])(args[id]);
138+
});
139+
// Reply to the client and release the payload. The buffers were created
// with `new char[]`, so they must be released through a `char *`; an array
// delete through a different element type is undefined behaviour.
port->send([&](rpc::Buffer *, uint32_t id) {
140+
delete[] static_cast<char *>(args[id]);
141+
});
142+
break;
143+
}
132144
// TODO: Move handling of these test cases to the loader implementation.
133145
case RPC_TEST_INCREMENT: {
134146
port->recv_and_send([](rpc::Buffer *buffer) {
@@ -341,7 +353,7 @@ uint64_t rpc_get_client_size() { return sizeof(rpc::Client); }
341353
using ServerPort = std::variant<rpc::Server<1>::Port *, rpc::Server<32>::Port *,
342354
rpc::Server<64>::Port *>;
343355

344-
ServerPort getPort(rpc_port_t ref) {
356+
ServerPort get_port(rpc_port_t ref) {
345357
if (ref.lane_size == 1)
346358
return reinterpret_cast<rpc::Server<1>::Port *>(ref.handle);
347359
else if (ref.lane_size == 32)
@@ -353,7 +365,7 @@ ServerPort getPort(rpc_port_t ref) {
353365
}
354366

355367
void rpc_send(rpc_port_t ref, rpc_port_callback_ty callback, void *data) {
356-
auto port = getPort(ref);
368+
auto port = get_port(ref);
357369
std::visit(
358370
[=](auto &port) {
359371
port->send([=](rpc::Buffer *buffer) {
@@ -364,7 +376,7 @@ void rpc_send(rpc_port_t ref, rpc_port_callback_ty callback, void *data) {
364376
}
365377

366378
void rpc_recv(rpc_port_t ref, rpc_port_callback_ty callback, void *data) {
367-
auto port = getPort(ref);
379+
auto port = get_port(ref);
368380
std::visit(
369381
[=](auto &port) {
370382
port->recv([=](rpc::Buffer *buffer) {
@@ -376,7 +388,7 @@ void rpc_recv(rpc_port_t ref, rpc_port_callback_ty callback, void *data) {
376388

377389
void rpc_recv_and_send(rpc_port_t ref, rpc_port_callback_ty callback,
378390
void *data) {
379-
auto port = getPort(ref);
391+
auto port = get_port(ref);
380392
std::visit(
381393
[=](auto &port) {
382394
port->recv_and_send([=](rpc::Buffer *buffer) {
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// RUN: %libomptarget-compile-run-and-check-generic
2+
3+
// REQUIRES: libc
4+
5+
#include <assert.h>
6+
#include <omp.h>
7+
#include <stdio.h>
8+
9+
#pragma omp begin declare variant match(device = {kind(gpu)})
10+
// Extension provided by the 'libc' project.
11+
void rpc_host_call(void *fn, void *args, size_t size);
12+
#pragma omp declare target to(rpc_host_call) device_type(nohost)
13+
#pragma omp end declare variant
14+
15+
#pragma omp begin declare variant match(device = {kind(cpu)})
16+
// Dummy host implementation to make this work for all targets. It ignores
// `size` and simply calls `fn` directly with `args`.
17+
void rpc_host_call(void *fn, void *args, size_t size) {
18+
((void (*)(void *))fn)(args);
19+
}
20+
#pragma omp end declare variant
21+
22+
// Argument payload copied to the host callback: the calling thread's number
// and its team number, as filled in by main().
typedef struct args_s {
23+
int thread_id;
24+
int block_id;
25+
} args_t;
26+
27+
// CHECK-DAG: Thread: 0, Block: 0
28+
// CHECK-DAG: Thread: 1, Block: 0
29+
// CHECK-DAG: Thread: 0, Block: 1
30+
// CHECK-DAG: Thread: 1, Block: 1
31+
// CHECK-DAG: Thread: 0, Block: 2
32+
// CHECK-DAG: Thread: 1, Block: 2
33+
// CHECK-DAG: Thread: 0, Block: 3
34+
// CHECK-DAG: Thread: 1, Block: 3
35+
// Callback that main() passes to rpc_host_call. It asserts that it is running
// on the initial (host) device and prints the ids stored in the args_t that
// `data` points to.
void foo(void *data) {
36+
assert(omp_is_initial_device() && "Not executing on host?");
37+
args_t *args = (args_t *)data;
38+
printf("Thread: %d, Block: %d\n", args->thread_id, args->block_id);
39+
}
40+
41+
void *fn_ptr = NULL;
42+
#pragma omp declare target to(fn_ptr)
43+
44+
// Publishes the host address of `foo` to the device, then has every thread of
// a 4-team, 2-thread target region call it back through rpc_host_call.
int main() {
45+
// Store the host function's address in the declare-target variable.
fn_ptr = (void *)&foo;
46+
// Copy the updated pointer value to the device.
#pragma omp target update to(fn_ptr)
47+
48+
#pragma omp target teams num_teams(4)
49+
#pragma omp parallel num_threads(2)
50+
{
51+
// Per-thread payload: {thread number, team number}, matching args_t.
args_t args = {omp_get_thread_num(), omp_get_team_num()};
52+
// Pass the payload by address together with its size.
rpc_host_call(fn_ptr, &args, sizeof(args_t));
53+
}
54+
}

0 commit comments

Comments
 (0)