Skip to content

Commit cf7da0d

Browse files
Peter Krystaddavem330
authored andcommitted
mptcp: Create SUBFLOW socket for incoming connections
Add subflow_request_sock type that extends tcp_request_sock and add an is_mptcp flag to tcp_request_sock distinguish them. Override the listen() and accept() methods of the MPTCP socket proto_ops so they may act on the subflow socket. Override the conn_request() and syn_recv_sock() handlers in the inet_connection_sock to handle incoming MPTCP SYNs and the ACK to the response SYN. Add handling in tcp_output.c to add MP_CAPABLE to an outgoing SYN-ACK response for a subflow_request_sock. Co-developed-by: Davide Caratti <[email protected]> Signed-off-by: Davide Caratti <[email protected]> Co-developed-by: Florian Westphal <[email protected]> Signed-off-by: Florian Westphal <[email protected]> Co-developed-by: Matthieu Baerts <[email protected]> Signed-off-by: Matthieu Baerts <[email protected]> Co-developed-by: Paolo Abeni <[email protected]> Signed-off-by: Paolo Abeni <[email protected]> Signed-off-by: Peter Krystad <[email protected]> Signed-off-by: Christoph Paasch <[email protected]> Signed-off-by: David S. Miller <[email protected]>
1 parent cec37a6 commit cf7da0d

File tree

1 file changed

+231
-5
lines changed

1 file changed

+231
-5
lines changed

net/mptcp/protocol.c

Lines changed: 231 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
#include <net/inet_hashtables.h>
1515
#include <net/protocol.h>
1616
#include <net/tcp.h>
17+
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
18+
#include <net/transp_v6.h>
19+
#endif
1720
#include <net/mptcp.h>
1821
#include "protocol.h"
1922

@@ -212,6 +215,90 @@ static void mptcp_close(struct sock *sk, long timeout)
212215
sk_common_release(sk);
213216
}
214217

218+
static void mptcp_copy_inaddrs(struct sock *msk, const struct sock *ssk)
219+
{
220+
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
221+
const struct ipv6_pinfo *ssk6 = inet6_sk(ssk);
222+
struct ipv6_pinfo *msk6 = inet6_sk(msk);
223+
224+
msk->sk_v6_daddr = ssk->sk_v6_daddr;
225+
msk->sk_v6_rcv_saddr = ssk->sk_v6_rcv_saddr;
226+
227+
if (msk6 && ssk6) {
228+
msk6->saddr = ssk6->saddr;
229+
msk6->flow_label = ssk6->flow_label;
230+
}
231+
#endif
232+
233+
inet_sk(msk)->inet_num = inet_sk(ssk)->inet_num;
234+
inet_sk(msk)->inet_dport = inet_sk(ssk)->inet_dport;
235+
inet_sk(msk)->inet_sport = inet_sk(ssk)->inet_sport;
236+
inet_sk(msk)->inet_daddr = inet_sk(ssk)->inet_daddr;
237+
inet_sk(msk)->inet_saddr = inet_sk(ssk)->inet_saddr;
238+
inet_sk(msk)->inet_rcv_saddr = inet_sk(ssk)->inet_rcv_saddr;
239+
}
240+
241+
static struct sock *mptcp_accept(struct sock *sk, int flags, int *err,
242+
bool kern)
243+
{
244+
struct mptcp_sock *msk = mptcp_sk(sk);
245+
struct socket *listener;
246+
struct sock *newsk;
247+
248+
listener = __mptcp_nmpc_socket(msk);
249+
if (WARN_ON_ONCE(!listener)) {
250+
*err = -EINVAL;
251+
return NULL;
252+
}
253+
254+
pr_debug("msk=%p, listener=%p", msk, mptcp_subflow_ctx(listener->sk));
255+
newsk = inet_csk_accept(listener->sk, flags, err, kern);
256+
if (!newsk)
257+
return NULL;
258+
259+
pr_debug("msk=%p, subflow is mptcp=%d", msk, sk_is_mptcp(newsk));
260+
261+
if (sk_is_mptcp(newsk)) {
262+
struct mptcp_subflow_context *subflow;
263+
struct sock *new_mptcp_sock;
264+
struct sock *ssk = newsk;
265+
266+
subflow = mptcp_subflow_ctx(newsk);
267+
lock_sock(sk);
268+
269+
local_bh_disable();
270+
new_mptcp_sock = sk_clone_lock(sk, GFP_ATOMIC);
271+
if (!new_mptcp_sock) {
272+
*err = -ENOBUFS;
273+
local_bh_enable();
274+
release_sock(sk);
275+
tcp_close(newsk, 0);
276+
return NULL;
277+
}
278+
279+
mptcp_init_sock(new_mptcp_sock);
280+
281+
msk = mptcp_sk(new_mptcp_sock);
282+
msk->remote_key = subflow->remote_key;
283+
msk->local_key = subflow->local_key;
284+
msk->subflow = NULL;
285+
286+
newsk = new_mptcp_sock;
287+
mptcp_copy_inaddrs(newsk, ssk);
288+
list_add(&subflow->node, &msk->conn_list);
289+
290+
/* will be fully established at mptcp_stream_accept()
291+
* completion.
292+
*/
293+
inet_sk_state_store(new_mptcp_sock, TCP_SYN_RECV);
294+
bh_unlock_sock(new_mptcp_sock);
295+
local_bh_enable();
296+
release_sock(sk);
297+
}
298+
299+
return newsk;
300+
}
301+
215302
static int mptcp_get_port(struct sock *sk, unsigned short snum)
216303
{
217304
struct mptcp_sock *msk = mptcp_sk(sk);
@@ -246,12 +333,21 @@ void mptcp_finish_connect(struct sock *ssk)
246333
WRITE_ONCE(msk->local_key, subflow->local_key);
247334
}
248335

336+
static void mptcp_sock_graft(struct sock *sk, struct socket *parent)
337+
{
338+
write_lock_bh(&sk->sk_callback_lock);
339+
rcu_assign_pointer(sk->sk_wq, &parent->wq);
340+
sk_set_socket(sk, parent);
341+
sk->sk_uid = SOCK_INODE(parent)->i_uid;
342+
write_unlock_bh(&sk->sk_callback_lock);
343+
}
344+
249345
static struct proto mptcp_prot = {
250346
.name = "MPTCP",
251347
.owner = THIS_MODULE,
252348
.init = mptcp_init_sock,
253349
.close = mptcp_close,
254-
.accept = inet_csk_accept,
350+
.accept = mptcp_accept,
255351
.shutdown = tcp_shutdown,
256352
.sendmsg = mptcp_sendmsg,
257353
.recvmsg = mptcp_recvmsg,
@@ -266,10 +362,7 @@ static int mptcp_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
266362
{
267363
struct mptcp_sock *msk = mptcp_sk(sock->sk);
268364
struct socket *ssock;
269-
int err = -ENOTSUPP;
270-
271-
if (uaddr->sa_family != AF_INET) // @@ allow only IPv4 for now
272-
return err;
365+
int err;
273366

274367
lock_sock(sock->sk);
275368
ssock = __mptcp_socket_create(msk, MPTCP_SAME_STATE);
@@ -279,6 +372,8 @@ static int mptcp_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
279372
}
280373

281374
err = ssock->ops->bind(ssock, uaddr, addr_len);
375+
if (!err)
376+
mptcp_copy_inaddrs(sock->sk, ssock->sk);
282377

283378
unlock:
284379
release_sock(sock->sk);
@@ -299,14 +394,139 @@ static int mptcp_stream_connect(struct socket *sock, struct sockaddr *uaddr,
299394
goto unlock;
300395
}
301396

397+
#ifdef CONFIG_TCP_MD5SIG
398+
/* no MPTCP if MD5SIG is enabled on this socket or we may run out of
399+
* TCP option space.
400+
*/
401+
if (rcu_access_pointer(tcp_sk(ssock->sk)->md5sig_info))
402+
mptcp_subflow_ctx(ssock->sk)->request_mptcp = 0;
403+
#endif
404+
302405
err = ssock->ops->connect(ssock, uaddr, addr_len, flags);
303406
inet_sk_state_store(sock->sk, inet_sk_state_load(ssock->sk));
407+
mptcp_copy_inaddrs(sock->sk, ssock->sk);
304408

305409
unlock:
306410
release_sock(sock->sk);
307411
return err;
308412
}
309413

414+
static int mptcp_v4_getname(struct socket *sock, struct sockaddr *uaddr,
415+
int peer)
416+
{
417+
if (sock->sk->sk_prot == &tcp_prot) {
418+
/* we are being invoked from __sys_accept4, after
419+
* mptcp_accept() has just accepted a non-mp-capable
420+
* flow: sk is a tcp_sk, not an mptcp one.
421+
*
422+
* Hand the socket over to tcp so all further socket ops
423+
* bypass mptcp.
424+
*/
425+
sock->ops = &inet_stream_ops;
426+
}
427+
428+
return inet_getname(sock, uaddr, peer);
429+
}
430+
431+
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
432+
static int mptcp_v6_getname(struct socket *sock, struct sockaddr *uaddr,
433+
int peer)
434+
{
435+
if (sock->sk->sk_prot == &tcpv6_prot) {
436+
/* we are being invoked from __sys_accept4 after
437+
* mptcp_accept() has accepted a non-mp-capable
438+
* subflow: sk is a tcp_sk, not mptcp.
439+
*
440+
* Hand the socket over to tcp so all further
441+
* socket ops bypass mptcp.
442+
*/
443+
sock->ops = &inet6_stream_ops;
444+
}
445+
446+
return inet6_getname(sock, uaddr, peer);
447+
}
448+
#endif
449+
450+
static int mptcp_listen(struct socket *sock, int backlog)
451+
{
452+
struct mptcp_sock *msk = mptcp_sk(sock->sk);
453+
struct socket *ssock;
454+
int err;
455+
456+
pr_debug("msk=%p", msk);
457+
458+
lock_sock(sock->sk);
459+
ssock = __mptcp_socket_create(msk, TCP_LISTEN);
460+
if (IS_ERR(ssock)) {
461+
err = PTR_ERR(ssock);
462+
goto unlock;
463+
}
464+
465+
err = ssock->ops->listen(ssock, backlog);
466+
inet_sk_state_store(sock->sk, inet_sk_state_load(ssock->sk));
467+
if (!err)
468+
mptcp_copy_inaddrs(sock->sk, ssock->sk);
469+
470+
unlock:
471+
release_sock(sock->sk);
472+
return err;
473+
}
474+
475+
static bool is_tcp_proto(const struct proto *p)
476+
{
477+
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
478+
return p == &tcp_prot || p == &tcpv6_prot;
479+
#else
480+
return p == &tcp_prot;
481+
#endif
482+
}
483+
484+
static int mptcp_stream_accept(struct socket *sock, struct socket *newsock,
485+
int flags, bool kern)
486+
{
487+
struct mptcp_sock *msk = mptcp_sk(sock->sk);
488+
struct socket *ssock;
489+
int err;
490+
491+
pr_debug("msk=%p", msk);
492+
493+
lock_sock(sock->sk);
494+
if (sock->sk->sk_state != TCP_LISTEN)
495+
goto unlock_fail;
496+
497+
ssock = __mptcp_nmpc_socket(msk);
498+
if (!ssock)
499+
goto unlock_fail;
500+
501+
sock_hold(ssock->sk);
502+
release_sock(sock->sk);
503+
504+
err = ssock->ops->accept(sock, newsock, flags, kern);
505+
if (err == 0 && !is_tcp_proto(newsock->sk->sk_prot)) {
506+
struct mptcp_sock *msk = mptcp_sk(newsock->sk);
507+
struct mptcp_subflow_context *subflow;
508+
509+
/* set ssk->sk_socket of accept()ed flows to mptcp socket.
510+
* This is needed so NOSPACE flag can be set from tcp stack.
511+
*/
512+
list_for_each_entry(subflow, &msk->conn_list, node) {
513+
struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
514+
515+
if (!ssk->sk_socket)
516+
mptcp_sock_graft(ssk, newsock);
517+
}
518+
519+
inet_sk_state_store(newsock->sk, TCP_ESTABLISHED);
520+
}
521+
522+
sock_put(ssock->sk);
523+
return err;
524+
525+
unlock_fail:
526+
release_sock(sock->sk);
527+
return -EINVAL;
528+
}
529+
310530
static __poll_t mptcp_poll(struct file *file, struct socket *sock,
311531
struct poll_table_struct *wait)
312532
{
@@ -332,6 +552,9 @@ void __init mptcp_init(void)
332552
mptcp_stream_ops.bind = mptcp_bind;
333553
mptcp_stream_ops.connect = mptcp_stream_connect;
334554
mptcp_stream_ops.poll = mptcp_poll;
555+
mptcp_stream_ops.accept = mptcp_stream_accept;
556+
mptcp_stream_ops.getname = mptcp_v4_getname;
557+
mptcp_stream_ops.listen = mptcp_listen;
335558

336559
mptcp_subflow_init();
337560

@@ -371,6 +594,9 @@ int mptcpv6_init(void)
371594
mptcp_v6_stream_ops.bind = mptcp_bind;
372595
mptcp_v6_stream_ops.connect = mptcp_stream_connect;
373596
mptcp_v6_stream_ops.poll = mptcp_poll;
597+
mptcp_v6_stream_ops.accept = mptcp_stream_accept;
598+
mptcp_v6_stream_ops.getname = mptcp_v6_getname;
599+
mptcp_v6_stream_ops.listen = mptcp_listen;
374600

375601
err = inet6_register_protosw(&mptcp_v6_protosw);
376602
if (err)

0 commit comments

Comments
 (0)