Skip to content

Commit 97c64a5

Browse files
committed
win32 support
1 parent 9f9b410 commit 97c64a5

File tree

4 files changed

+135
-56
lines changed

4 files changed

+135
-56
lines changed

CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,10 @@ endif()
497497
if (LLAMA_RPC)
498498
add_compile_definitions(GGML_USE_RPC)
499499

500+
if (WIN32)
501+
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ws2_32)
502+
endif()
503+
500504
set(GGML_HEADERS_RPC ggml-rpc.h)
501505
set(GGML_SOURCES_RPC ggml-rpc.cpp)
502506
endif()

examples/rpc/rpc-server.cpp

Lines changed: 47 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,17 @@
99
#include "ggml-rpc.h"
1010
#include <memory>
1111
#include <string>
12-
#include <sys/types.h>
13-
#include <sys/socket.h>
14-
#include <netinet/in.h>
15-
#include <netinet/tcp.h>
16-
#include <arpa/inet.h>
12+
#ifndef _WIN32
13+
# include <sys/socket.h>
14+
# include <sys/types.h>
15+
# include <arpa/inet.h>
16+
# include <netinet/in.h>
17+
# include <netinet/tcp.h>
18+
# include <netdb.h>
19+
# include <unistd.h>
20+
#endif
1721
#include <stdio.h>
1822
#include <stdlib.h>
19-
#include <unistd.h>
2023

2124
static ggml_backend_t create_backend() {
2225
ggml_backend_t backend = NULL;
@@ -52,10 +55,24 @@ static void get_backend_memory(size_t * free_mem, size_t * total_mem) {
5255
#endif
5356
}
5457

55-
static int create_server_socket(const char * host, int port) {
56-
int sockfd = socket(AF_INET, SOCK_STREAM, 0);
58+
// Wraps a raw socket descriptor in a shared socket_t.
// Returns nullptr when `fd` is the platform's invalid value
// (INVALID_SOCKET on Windows, any negative descriptor elsewhere).
static std::shared_ptr<socket_t> make_socket(sockfd_t fd) {
#ifdef _WIN32
    const bool is_valid = (fd != INVALID_SOCKET);
#else
    const bool is_valid = (fd >= 0);
#endif
    if (!is_valid) {
        return nullptr;
    }
    return std::make_shared<socket_t>(fd);
}
70+
71+
static std::shared_ptr<socket_t> create_server_socket(const char * host, int port) {
72+
auto sockfd = socket(AF_INET, SOCK_STREAM, 0);
73+
auto sock = make_socket(sockfd);
5774
// make_socket() has already applied the platform-specific validity check.
// Testing the wrapper instead of `sockfd < 0` matters on Windows, where
// sockfd_t is SOCKET (unsigned) and failure is INVALID_SOCKET, so a
// "< 0" comparison can never be true.
if (sock == nullptr) {
58-
return -1;
75+
return nullptr;
5976
}
6077

6178
struct sockaddr_in serv_addr;
@@ -64,16 +81,24 @@ static int create_server_socket(const char * host, int port) {
6481
serv_addr.sin_port = htons(port);
6582

6683
if (bind(sockfd, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) {
67-
return -1;
84+
return nullptr;
6885
}
6986
if (listen(sockfd, 5) < 0) {
70-
return -1;
87+
return nullptr;
7188
}
72-
return sockfd;
89+
return sock;
7390
}
7491

7592
int main(int argc, char * argv[])
7693
{
94+
#ifdef _WIN32
95+
WSADATA wsaData;
96+
int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
97+
if (res != 0) {
98+
fprintf(stderr, "WSAStartup failed: %d\n", res);
99+
return 1;
100+
}
101+
#endif
77102
if (argc < 3) {
78103
fprintf(stderr, "Usage: %s <host> <port>\n", argv[0]);
79104
return 1;
@@ -88,33 +113,33 @@ int main(int argc, char * argv[])
88113
}
89114

90115
printf("Starting RPC server on %s:%d\n", host, port);
91-
int server_socket = create_server_socket(host, port);
92-
if (server_socket < 0) {
116+
auto server_socket = create_server_socket(host, port);
117+
if (server_socket == nullptr) {
93118
fprintf(stderr, "Failed to create server socket\n");
94119
return 1;
95120
}
96121
while (true) {
97-
struct sockaddr_in cli_addr;
98-
socklen_t clilen = sizeof(cli_addr);
99-
int client_socket = accept(server_socket, (struct sockaddr *) &cli_addr, &clilen);
100-
if (client_socket < 0) {
122+
auto client_socket_fd = accept(server_socket->fd, NULL, NULL);
123+
auto client_socket = make_socket(client_socket_fd);
124+
if (client_socket == nullptr) {
101125
fprintf(stderr, "Failed to accept client connection\n");
102126
return 1;
103127
}
104128
// set TCP_NODELAY to disable Nagle's algorithm
105129
int flag = 1;
106-
int ret = setsockopt(client_socket, IPPROTO_TCP, TCP_NODELAY, (char *) &flag, sizeof(int));
130+
int ret = setsockopt(client_socket->fd, IPPROTO_TCP, TCP_NODELAY, (char *) &flag, sizeof(int));
107131
if (ret < 0) {
108132
fprintf(stderr, "Failed to set TCP_NODELAY\n");
109-
close(client_socket);
110133
continue;
111134
}
112135
size_t free_mem, total_mem;
113136
get_backend_memory(&free_mem, &total_mem);
114137
printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
115-
rpc_serve_client(backend, client_socket, free_mem, total_mem);
138+
rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
116139
printf("Client connection closed\n");
117-
close(client_socket);
118140
}
141+
#ifdef _WIN32
142+
WSACleanup();
143+
#endif
119144
return 0;
120145
}

ggml-rpc.cpp

Lines changed: 59 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,19 @@
22
#include "ggml.h"
33
#include "ggml-backend-impl.h"
44

5-
#include <memory>
65
#include <string>
76
#include <vector>
87
#include <unordered_map>
98
#include <unordered_set>
10-
#include <sys/socket.h>
11-
#include <sys/types.h>
12-
#include <netinet/in.h>
13-
#include <netinet/tcp.h>
14-
#include <netdb.h>
9+
#ifndef _WIN32
10+
# include <sys/socket.h>
11+
# include <sys/types.h>
12+
# include <netinet/in.h>
13+
# include <netinet/tcp.h>
14+
# include <netdb.h>
15+
# include <unistd.h>
16+
#endif
1517
#include <string.h>
16-
#include <unistd.h>
1718

1819
#define UNUSED GGML_UNUSED
1920

@@ -24,23 +25,19 @@
2425
#define GGML_PRINT_DEBUG(...)
2526
#endif
2627

27-
// RPC data structures
28+
#ifdef _WIN32
29+
using ssize_t = __int64;
30+
#endif
2831

29-
struct sockfd {
30-
int fd;
31-
sockfd(int fd) : fd(fd) {}
32-
~sockfd() {
33-
close(fd);
34-
}
35-
};
32+
// RPC data structures
3633

3734
static ggml_guid_t ggml_backend_rpc_guid() {
3835
static ggml_guid guid = {0x99, 0x68, 0x5b, 0x6c, 0xd2, 0x83, 0x3d, 0x24, 0x25, 0x36, 0x72, 0xe1, 0x5b, 0x0e, 0x14, 0x03};
3936
return &guid;
4037
}
4138

4239
struct ggml_backend_rpc_buffer_type_context {
43-
std::shared_ptr<sockfd> sock;
40+
std::shared_ptr<socket_t> sock;
4441
std::string name;
4542
size_t alignment;
4643
size_t max_size;
@@ -49,27 +46,47 @@ struct ggml_backend_rpc_buffer_type_context {
4946
struct ggml_backend_rpc_context {
5047
std::string endpoint;
5148
std::string name;
52-
std::shared_ptr<sockfd> sock;
49+
std::shared_ptr<socket_t> sock;
5350
ggml_backend_buffer_type_t buft;
5451
};
5552

5653
struct ggml_backend_rpc_buffer_context {
57-
std::shared_ptr<sockfd> sock;
54+
std::shared_ptr<socket_t> sock;
5855
std::unordered_map<ggml_backend_buffer_t, void *> base_cache;
5956
uint64_t remote_ptr;
6057
std::string name;
6158
};
6259

63-
6460
// RPC helper functions
6561

66-
static std::shared_ptr<sockfd> socket_connect(const char * host, int port) {
62+
// Closes the wrapped descriptor when the socket_t is destroyed:
// close() on non-Windows platforms, closesocket() on Windows.
socket_t::~socket_t() {
#ifndef _WIN32
    close(fd);
#else
    closesocket(fd);
#endif
}
69+
70+
// Wraps a raw socket descriptor in a shared socket_t, or yields nullptr if
// the descriptor is invalid for the current platform (INVALID_SOCKET on
// Windows, a negative value elsewhere).
static std::shared_ptr<socket_t> make_socket(sockfd_t fd) {
#ifdef _WIN32
    const bool usable = (fd != INVALID_SOCKET);
#else
    const bool usable = (fd >= 0);
#endif
    if (usable) {
        return std::make_shared<socket_t>(fd);
    }
    return nullptr;
}
82+
83+
static std::shared_ptr<socket_t> socket_connect(const char * host, int port) {
6784
struct sockaddr_in addr;
68-
int sock = socket(AF_INET, SOCK_STREAM, 0);
69-
if (sock < 0) {
85+
auto sockfd = socket(AF_INET, SOCK_STREAM, 0);
86+
auto sock_ptr = make_socket(sockfd);
87+
if (sock_ptr == nullptr) {
7088
return nullptr;
7189
}
72-
auto sock_ptr = std::make_shared<sockfd>(sock);
7390
// set TCP_NODELAY to disable Nagle's algorithm
7491
int flag = 1;
7592
int ret = setsockopt(sock_ptr->fd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int));
@@ -83,17 +100,17 @@ static std::shared_ptr<sockfd> socket_connect(const char * host, int port) {
83100
fprintf(stderr, "Cannot resolve host '%s'\n", host);
84101
return nullptr;
85102
}
86-
bcopy((char *)server->h_addr, (char *)&addr.sin_addr.s_addr, server->h_length);
103+
memcpy(&addr.sin_addr.s_addr, server->h_addr, server->h_length);
87104
if (connect(sock_ptr->fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
88105
return nullptr;
89106
}
90107
return sock_ptr;
91108
}
92109

93-
static bool send_data(int sockfd, const void * data, size_t size) {
110+
static bool send_data(sockfd_t sockfd, const void * data, size_t size) {
94111
size_t bytes_sent = 0;
95112
while (bytes_sent < size) {
96-
ssize_t n = send(sockfd, (const uint8_t *)data + bytes_sent, size - bytes_sent, 0);
113+
ssize_t n = send(sockfd, (const char *)data + bytes_sent, size - bytes_sent, 0);
97114
if (n < 0) {
98115
return false;
99116
}
@@ -102,10 +119,10 @@ static bool send_data(int sockfd, const void * data, size_t size) {
102119
return true;
103120
}
104121

105-
static bool recv_data(int sockfd, void * data, size_t size) {
122+
static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
106123
size_t bytes_recv = 0;
107124
while (bytes_recv < size) {
108-
ssize_t n = recv(sockfd, (uint8_t *)data + bytes_recv, size - bytes_recv, 0);
125+
ssize_t n = recv(sockfd, (char *)data + bytes_recv, size - bytes_recv, 0);
109126
if (n <= 0) {
110127
return false;
111128
}
@@ -116,7 +133,7 @@ static bool recv_data(int sockfd, void * data, size_t size) {
116133

117134
// RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
118135
// RPC response: | response_size (8 bytes) | response_data (response_size bytes) |
119-
static bool send_rpc_cmd(const std::shared_ptr<sockfd> & sock, enum rpc_cmd cmd, const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
136+
static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
120137
uint8_t cmd_byte = cmd;
121138
if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) {
122139
return false;
@@ -348,7 +365,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
348365
return buffer;
349366
}
350367

351-
static size_t get_alignment(const std::shared_ptr<sockfd> & sock) {
368+
static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
352369
// input serialization format: | 0 bytes |
353370
std::vector<uint8_t> input;
354371
std::vector<uint8_t> output;
@@ -366,7 +383,7 @@ GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_
366383
return buft_ctx->alignment;
367384
}
368385

369-
static size_t get_max_size(const std::shared_ptr<sockfd> & sock) {
386+
static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
370387
// input serialization format: | 0 bytes |
371388
std::vector<uint8_t> input;
372389
std::vector<uint8_t> output;
@@ -522,6 +539,15 @@ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const std::string & endpoint) {
522539
if (instances.find(endpoint) != instances.end()) {
523540
return instances[endpoint];
524541
}
542+
#ifdef _WIN32
543+
{
544+
WSADATA wsaData;
545+
int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
546+
if (res != 0) {
547+
return nullptr;
548+
}
549+
}
550+
#endif
525551
GGML_PRINT_DEBUG("Connecting to %s\n", endpoint.c_str());
526552
// split the endpoint into host and port
527553
size_t pos = endpoint.find(":");
@@ -565,7 +591,7 @@ GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) {
565591
return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid());
566592
}
567593

568-
static void get_device_memory(const std::shared_ptr<sockfd> & sock, size_t * free, size_t * total) {
594+
static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * free, size_t * total) {
569595
// input serialization format: | 0 bytes |
570596
std::vector<uint8_t> input;
571597
std::vector<uint8_t> output;
@@ -779,7 +805,7 @@ static void rpc_graph_compute(ggml_backend_t backend, const std::vector<uint8_t>
779805
ggml_free(ctx);
780806
}
781807

782-
void rpc_serve_client(ggml_backend_t backend, int sockfd, size_t free_mem, size_t total_mem) {
808+
void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
783809
while (true) {
784810
uint8_t cmd;
785811
if (!recv_data(sockfd, &cmd, 1)) {

ggml-rpc.h

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,35 @@
33
#include "ggml.h"
44
#include "ggml-backend.h"
55
#include <string>
6+
#include <memory>
7+
#ifdef _WIN32
8+
# define WIN32_LEAN_AND_MEAN
9+
# ifndef NOMINMAX
10+
# define NOMINMAX
11+
# endif
12+
# include <windows.h>
13+
# include <winsock2.h>
14+
#endif
615

716
#ifdef __cplusplus
817
extern "C" {
918
#endif
1019

20+
// cross-platform socket fd
#ifdef _WIN32
typedef SOCKET sockfd_t;
#else
typedef int sockfd_t;
#endif

// cross-platform socket
// Owning wrapper around a native socket descriptor. The destructor (defined
// in ggml-rpc.cpp) closes `fd`, so exactly one socket_t may exist per
// descriptor; share it through std::shared_ptr<socket_t>.
struct socket_t {
    sockfd_t fd;

    // explicit: an implicit sockfd_t -> socket_t conversion would create a
    // temporary whose destructor closes the caller's descriptor
    explicit socket_t(sockfd_t fd) : fd(fd) {}
    ~socket_t();

    // non-copyable: a copy would make two objects close the same descriptor
    socket_t(const socket_t &) = delete;
    socket_t & operator=(const socket_t &) = delete;
};
33+
34+
1135
// ggml_tensor is serialized into rpc_tensor
1236
struct rpc_tensor {
1337
uint64_t id;
@@ -50,7 +74,7 @@ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const
5074

5175
GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const std::string & endpoint, size_t * free, size_t * total);
5276

53-
GGML_API GGML_CALL void rpc_serve_client(ggml_backend_t backend, int sockfd, size_t free_mem, size_t total_mem);
77+
GGML_API GGML_CALL void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem);
5478

5579
#ifdef __cplusplus
5680
}

0 commit comments

Comments
 (0)