2
2
#include " ggml.h"
3
3
#include " ggml-backend-impl.h"
4
4
5
- #include < memory>
6
5
#include < string>
7
6
#include < vector>
8
7
#include < unordered_map>
9
8
#include < unordered_set>
10
- #include < sys/socket.h>
11
- #include < sys/types.h>
12
- #include < netinet/in.h>
13
- #include < netinet/tcp.h>
14
- #include < netdb.h>
9
+ #ifndef _WIN32
10
+ # include < sys/socket.h>
11
+ # include < sys/types.h>
12
+ # include < netinet/in.h>
13
+ # include < netinet/tcp.h>
14
+ # include < netdb.h>
15
+ # include < unistd.h>
16
+ #endif
15
17
#include < string.h>
16
- #include < unistd.h>
17
18
18
19
#define UNUSED GGML_UNUSED
19
20
24
25
#define GGML_PRINT_DEBUG (...)
25
26
#endif
26
27
27
- // RPC data structures
28
+ #ifdef _WIN32
29
+ using ssize_t = __int64;
30
+ #endif
28
31
29
- struct sockfd {
30
- int fd;
31
- sockfd (int fd) : fd(fd) {}
32
- ~sockfd () {
33
- close (fd);
34
- }
35
- };
32
+ // RPC data structures
36
33
37
34
static ggml_guid_t ggml_backend_rpc_guid () {
38
35
static ggml_guid guid = {0x99 , 0x68 , 0x5b , 0x6c , 0xd2 , 0x83 , 0x3d , 0x24 , 0x25 , 0x36 , 0x72 , 0xe1 , 0x5b , 0x0e , 0x14 , 0x03 };
39
36
return &guid;
40
37
}
41
38
42
39
struct ggml_backend_rpc_buffer_type_context {
43
- std::shared_ptr<sockfd > sock;
40
+ std::shared_ptr<socket_t > sock;
44
41
std::string name;
45
42
size_t alignment;
46
43
size_t max_size;
@@ -49,27 +46,47 @@ struct ggml_backend_rpc_buffer_type_context {
49
46
struct ggml_backend_rpc_context {
50
47
std::string endpoint;
51
48
std::string name;
52
- std::shared_ptr<sockfd > sock;
49
+ std::shared_ptr<socket_t > sock;
53
50
ggml_backend_buffer_type_t buft;
54
51
};
55
52
56
53
struct ggml_backend_rpc_buffer_context {
57
- std::shared_ptr<sockfd > sock;
54
+ std::shared_ptr<socket_t > sock;
58
55
std::unordered_map<ggml_backend_buffer_t , void *> base_cache;
59
56
uint64_t remote_ptr;
60
57
std::string name;
61
58
};
62
59
63
-
64
60
// RPC helper functions
65
61
66
- static std::shared_ptr<sockfd> socket_connect (const char * host, int port) {
62
+ socket_t ::~socket_t () {
63
+ #ifdef _WIN32
64
+ closesocket (this ->fd );
65
+ #else
66
+ close (this ->fd );
67
+ #endif
68
+ }
69
+
70
+ static std::shared_ptr<socket_t > make_socket (sockfd_t fd) {
71
+ #ifdef _WIN32
72
+ if (fd == INVALID_SOCKET) {
73
+ return nullptr ;
74
+ }
75
+ #else
76
+ if (fd < 0 ) {
77
+ return nullptr ;
78
+ }
79
+ #endif
80
+ return std::make_shared<socket_t >(fd);
81
+ }
82
+
83
+ static std::shared_ptr<socket_t > socket_connect (const char * host, int port) {
67
84
struct sockaddr_in addr;
68
- int sock = socket (AF_INET, SOCK_STREAM, 0 );
69
- if (sock < 0 ) {
85
+ auto sockfd = socket (AF_INET, SOCK_STREAM, 0 );
86
+ auto sock_ptr = make_socket (sockfd);
87
+ if (sock_ptr == nullptr ) {
70
88
return nullptr ;
71
89
}
72
- auto sock_ptr = std::make_shared<sockfd>(sock);
73
90
// set TCP_NODELAY to disable Nagle's algorithm
74
91
int flag = 1 ;
75
92
int ret = setsockopt (sock_ptr->fd , IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof (int ));
@@ -83,17 +100,17 @@ static std::shared_ptr<sockfd> socket_connect(const char * host, int port) {
83
100
fprintf (stderr, " Cannot resolve host '%s'\n " , host);
84
101
return nullptr ;
85
102
}
86
- bcopy (( char *)server-> h_addr , ( char *) &addr.sin_addr .s_addr , server->h_length );
103
+ memcpy ( &addr.sin_addr .s_addr , server-> h_addr , server->h_length );
87
104
if (connect (sock_ptr->fd , (struct sockaddr *)&addr, sizeof (addr)) < 0 ) {
88
105
return nullptr ;
89
106
}
90
107
return sock_ptr;
91
108
}
92
109
93
- static bool send_data (int sockfd, const void * data, size_t size) {
110
+ static bool send_data (sockfd_t sockfd, const void * data, size_t size) {
94
111
size_t bytes_sent = 0 ;
95
112
while (bytes_sent < size) {
96
- ssize_t n = send (sockfd, (const uint8_t *)data + bytes_sent, size - bytes_sent, 0 );
113
+ ssize_t n = send (sockfd, (const char *)data + bytes_sent, size - bytes_sent, 0 );
97
114
if (n < 0 ) {
98
115
return false ;
99
116
}
@@ -102,10 +119,10 @@ static bool send_data(int sockfd, const void * data, size_t size) {
102
119
return true ;
103
120
}
104
121
105
- static bool recv_data (int sockfd, void * data, size_t size) {
122
+ static bool recv_data (sockfd_t sockfd, void * data, size_t size) {
106
123
size_t bytes_recv = 0 ;
107
124
while (bytes_recv < size) {
108
- ssize_t n = recv (sockfd, (uint8_t *)data + bytes_recv, size - bytes_recv, 0 );
125
+ ssize_t n = recv (sockfd, (char *)data + bytes_recv, size - bytes_recv, 0 );
109
126
if (n <= 0 ) {
110
127
return false ;
111
128
}
@@ -116,7 +133,7 @@ static bool recv_data(int sockfd, void * data, size_t size) {
116
133
117
134
// RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
118
135
// RPC response: | response_size (8 bytes) | response_data (response_size bytes) |
119
- static bool send_rpc_cmd (const std::shared_ptr<sockfd > & sock, enum rpc_cmd cmd, const std::vector<uint8_t > & input, std::vector<uint8_t > & output) {
136
+ static bool send_rpc_cmd (const std::shared_ptr<socket_t > & sock, enum rpc_cmd cmd, const std::vector<uint8_t > & input, std::vector<uint8_t > & output) {
120
137
uint8_t cmd_byte = cmd;
121
138
if (!send_data (sock->fd , &cmd_byte, sizeof (cmd_byte))) {
122
139
return false ;
@@ -348,7 +365,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
348
365
return buffer;
349
366
}
350
367
351
- static size_t get_alignment (const std::shared_ptr<sockfd > & sock) {
368
+ static size_t get_alignment (const std::shared_ptr<socket_t > & sock) {
352
369
// input serialization format: | 0 bytes |
353
370
std::vector<uint8_t > input;
354
371
std::vector<uint8_t > output;
@@ -366,7 +383,7 @@ GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_
366
383
return buft_ctx->alignment ;
367
384
}
368
385
369
- static size_t get_max_size (const std::shared_ptr<sockfd > & sock) {
386
+ static size_t get_max_size (const std::shared_ptr<socket_t > & sock) {
370
387
// input serialization format: | 0 bytes |
371
388
std::vector<uint8_t > input;
372
389
std::vector<uint8_t > output;
@@ -522,6 +539,15 @@ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const std::string & endpoint) {
522
539
if (instances.find (endpoint) != instances.end ()) {
523
540
return instances[endpoint];
524
541
}
542
+ #ifdef _WIN32
543
+ {
544
+ WSADATA wsaData;
545
+ int res = WSAStartup (MAKEWORD (2 , 2 ), &wsaData);
546
+ if (res != 0 ) {
547
+ return nullptr ;
548
+ }
549
+ }
550
+ #endif
525
551
GGML_PRINT_DEBUG (" Connecting to %s\n " , endpoint.c_str ());
526
552
// split the endpoint into host and port
527
553
size_t pos = endpoint.find (" :" );
@@ -565,7 +591,7 @@ GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) {
565
591
return backend != NULL && ggml_guid_matches (backend->guid , ggml_backend_rpc_guid ());
566
592
}
567
593
568
- static void get_device_memory (const std::shared_ptr<sockfd > & sock, size_t * free, size_t * total) {
594
+ static void get_device_memory (const std::shared_ptr<socket_t > & sock, size_t * free, size_t * total) {
569
595
// input serialization format: | 0 bytes |
570
596
std::vector<uint8_t > input;
571
597
std::vector<uint8_t > output;
@@ -779,7 +805,7 @@ static void rpc_graph_compute(ggml_backend_t backend, const std::vector<uint8_t>
779
805
ggml_free (ctx);
780
806
}
781
807
782
- void rpc_serve_client (ggml_backend_t backend, int sockfd, size_t free_mem, size_t total_mem) {
808
+ void rpc_serve_client (ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
783
809
while (true ) {
784
810
uint8_t cmd;
785
811
if (!recv_data (sockfd, &cmd, 1 )) {
0 commit comments