14
14
#include <net/inet_hashtables.h>
15
15
#include <net/protocol.h>
16
16
#include <net/tcp.h>
17
+ #if IS_ENABLED (CONFIG_MPTCP_IPV6 )
18
+ #include <net/transp_v6.h>
19
+ #endif
17
20
#include <net/mptcp.h>
18
21
#include "protocol.h"
19
22
@@ -212,6 +215,90 @@ static void mptcp_close(struct sock *sk, long timeout)
212
215
sk_common_release (sk );
213
216
}
214
217
218
/* Copy the address/port state of the subflow socket @ssk onto the MPTCP
 * socket @msk (local/remote address, source/destination port, and — when
 * IPv6 is enabled — the v6 addresses and flow label), so the MPTCP socket
 * reflects the addresses of the underlying TCP subflow.
 */
static void mptcp_copy_inaddrs(struct sock *msk, const struct sock *ssk)
{
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
	const struct ipv6_pinfo *ssk6 = inet6_sk(ssk);
	struct ipv6_pinfo *msk6 = inet6_sk(msk);

	msk->sk_v6_daddr = ssk->sk_v6_daddr;
	msk->sk_v6_rcv_saddr = ssk->sk_v6_rcv_saddr;

	/* inet6_sk() yields NULL for non-IPv6 sockets; only copy the
	 * pinfo fields when both sides actually carry IPv6 state.
	 */
	if (msk6 && ssk6) {
		msk6->saddr = ssk6->saddr;
		msk6->flow_label = ssk6->flow_label;
	}
#endif

	inet_sk(msk)->inet_num = inet_sk(ssk)->inet_num;
	inet_sk(msk)->inet_dport = inet_sk(ssk)->inet_dport;
	inet_sk(msk)->inet_sport = inet_sk(ssk)->inet_sport;
	inet_sk(msk)->inet_daddr = inet_sk(ssk)->inet_daddr;
	inet_sk(msk)->inet_saddr = inet_sk(ssk)->inet_saddr;
	inet_sk(msk)->inet_rcv_saddr = inet_sk(ssk)->inet_rcv_saddr;
}
240
+
241
/* proto ->accept for MPTCP sockets.
 *
 * Accepts a connection on the first (TCP) subflow listener.  If the peer
 * negotiated MPTCP, a new MPTCP socket is cloned from the listening one and
 * returned, with the accepted TCP sock attached as its first subflow;
 * otherwise the plain TCP sock is returned as-is.
 *
 * Returns the new sock, or NULL with *err set on failure.
 */
static struct sock *mptcp_accept(struct sock *sk, int flags, int *err,
				 bool kern)
{
	struct mptcp_sock *msk = mptcp_sk(sk);
	struct socket *listener;
	struct sock *newsk;

	/* the actual listening is delegated to the in-kernel subflow socket */
	listener = __mptcp_nmpc_socket(msk);
	if (WARN_ON_ONCE(!listener)) {
		*err = -EINVAL;
		return NULL;
	}

	pr_debug("msk=%p, listener=%p", msk, mptcp_subflow_ctx(listener->sk));
	newsk = inet_csk_accept(listener->sk, flags, err, kern);
	if (!newsk)
		return NULL;

	pr_debug("msk=%p, subflow is mptcp=%d", msk, sk_is_mptcp(newsk));

	if (sk_is_mptcp(newsk)) {
		struct mptcp_subflow_context *subflow;
		struct sock *new_mptcp_sock;
		struct sock *ssk = newsk; /* keep a handle on the subflow */

		subflow = mptcp_subflow_ctx(newsk);
		lock_sock(sk);

		/* sk_clone_lock() requires softirqs disabled and returns the
		 * clone bh-locked; it is unlocked just before local_bh_enable()
		 * below.
		 */
		local_bh_disable();
		new_mptcp_sock = sk_clone_lock(sk, GFP_ATOMIC);
		if (!new_mptcp_sock) {
			*err = -ENOBUFS;
			local_bh_enable();
			release_sock(sk);
			/* drop the accepted TCP sock, we cannot wrap it */
			tcp_close(newsk, 0);
			return NULL;
		}

		mptcp_init_sock(new_mptcp_sock);

		/* transfer the crypto material negotiated on the subflow */
		msk = mptcp_sk(new_mptcp_sock);
		msk->remote_key = subflow->remote_key;
		msk->local_key = subflow->local_key;
		msk->subflow = NULL;

		newsk = new_mptcp_sock;
		mptcp_copy_inaddrs(newsk, ssk);
		list_add(&subflow->node, &msk->conn_list);

		/* will be fully established at mptcp_stream_accept()
		 * completion.
		 */
		inet_sk_state_store(new_mptcp_sock, TCP_SYN_RECV);
		bh_unlock_sock(new_mptcp_sock);
		local_bh_enable();
		release_sock(sk);
	}

	return newsk;
}
301
+
215
302
static int mptcp_get_port (struct sock * sk , unsigned short snum )
216
303
{
217
304
struct mptcp_sock * msk = mptcp_sk (sk );
@@ -246,12 +333,21 @@ void mptcp_finish_connect(struct sock *ssk)
246
333
WRITE_ONCE (msk -> local_key , subflow -> local_key );
247
334
}
248
335
336
/* Graft sock @sk onto the user-space-visible socket @parent: point its wait
 * queue and socket back-reference at @parent and inherit the owning uid.
 * All updates are done under sk_callback_lock so concurrent readers (e.g.
 * wakeup paths) see a consistent wq/socket pair.
 */
static void mptcp_sock_graft(struct sock *sk, struct socket *parent)
{
	write_lock_bh(&sk->sk_callback_lock);
	rcu_assign_pointer(sk->sk_wq, &parent->wq);
	sk_set_socket(sk, parent);
	sk->sk_uid = SOCK_INODE(parent)->i_uid;
	write_unlock_bh(&sk->sk_callback_lock);
}
344
+
249
345
static struct proto mptcp_prot = {
250
346
.name = "MPTCP" ,
251
347
.owner = THIS_MODULE ,
252
348
.init = mptcp_init_sock ,
253
349
.close = mptcp_close ,
254
- .accept = inet_csk_accept ,
350
+ .accept = mptcp_accept ,
255
351
.shutdown = tcp_shutdown ,
256
352
.sendmsg = mptcp_sendmsg ,
257
353
.recvmsg = mptcp_recvmsg ,
@@ -266,10 +362,7 @@ static int mptcp_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
266
362
{
267
363
struct mptcp_sock * msk = mptcp_sk (sock -> sk );
268
364
struct socket * ssock ;
269
- int err = - ENOTSUPP ;
270
-
271
- if (uaddr -> sa_family != AF_INET ) // @@ allow only IPv4 for now
272
- return err ;
365
+ int err ;
273
366
274
367
lock_sock (sock -> sk );
275
368
ssock = __mptcp_socket_create (msk , MPTCP_SAME_STATE );
@@ -279,6 +372,8 @@ static int mptcp_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
279
372
}
280
373
281
374
err = ssock -> ops -> bind (ssock , uaddr , addr_len );
375
+ if (!err )
376
+ mptcp_copy_inaddrs (sock -> sk , ssock -> sk );
282
377
283
378
unlock :
284
379
release_sock (sock -> sk );
@@ -299,14 +394,139 @@ static int mptcp_stream_connect(struct socket *sock, struct sockaddr *uaddr,
299
394
goto unlock ;
300
395
}
301
396
397
+ #ifdef CONFIG_TCP_MD5SIG
398
+ /* no MPTCP if MD5SIG is enabled on this socket or we may run out of
399
+ * TCP option space.
400
+ */
401
+ if (rcu_access_pointer (tcp_sk (ssock -> sk )-> md5sig_info ))
402
+ mptcp_subflow_ctx (ssock -> sk )-> request_mptcp = 0 ;
403
+ #endif
404
+
302
405
err = ssock -> ops -> connect (ssock , uaddr , addr_len , flags );
303
406
inet_sk_state_store (sock -> sk , inet_sk_state_load (ssock -> sk ));
407
+ mptcp_copy_inaddrs (sock -> sk , ssock -> sk );
304
408
305
409
unlock :
306
410
release_sock (sock -> sk );
307
411
return err ;
308
412
}
309
413
414
+ static int mptcp_v4_getname (struct socket * sock , struct sockaddr * uaddr ,
415
+ int peer )
416
+ {
417
+ if (sock -> sk -> sk_prot == & tcp_prot ) {
418
+ /* we are being invoked from __sys_accept4, after
419
+ * mptcp_accept() has just accepted a non-mp-capable
420
+ * flow: sk is a tcp_sk, not an mptcp one.
421
+ *
422
+ * Hand the socket over to tcp so all further socket ops
423
+ * bypass mptcp.
424
+ */
425
+ sock -> ops = & inet_stream_ops ;
426
+ }
427
+
428
+ return inet_getname (sock , uaddr , peer );
429
+ }
430
+
431
+ #if IS_ENABLED (CONFIG_MPTCP_IPV6 )
432
+ static int mptcp_v6_getname (struct socket * sock , struct sockaddr * uaddr ,
433
+ int peer )
434
+ {
435
+ if (sock -> sk -> sk_prot == & tcpv6_prot ) {
436
+ /* we are being invoked from __sys_accept4 after
437
+ * mptcp_accept() has accepted a non-mp-capable
438
+ * subflow: sk is a tcp_sk, not mptcp.
439
+ *
440
+ * Hand the socket over to tcp so all further
441
+ * socket ops bypass mptcp.
442
+ */
443
+ sock -> ops = & inet6_stream_ops ;
444
+ }
445
+
446
+ return inet6_getname (sock , uaddr , peer );
447
+ }
448
+ #endif
449
+
450
+ static int mptcp_listen (struct socket * sock , int backlog )
451
+ {
452
+ struct mptcp_sock * msk = mptcp_sk (sock -> sk );
453
+ struct socket * ssock ;
454
+ int err ;
455
+
456
+ pr_debug ("msk=%p" , msk );
457
+
458
+ lock_sock (sock -> sk );
459
+ ssock = __mptcp_socket_create (msk , TCP_LISTEN );
460
+ if (IS_ERR (ssock )) {
461
+ err = PTR_ERR (ssock );
462
+ goto unlock ;
463
+ }
464
+
465
+ err = ssock -> ops -> listen (ssock , backlog );
466
+ inet_sk_state_store (sock -> sk , inet_sk_state_load (ssock -> sk ));
467
+ if (!err )
468
+ mptcp_copy_inaddrs (sock -> sk , ssock -> sk );
469
+
470
+ unlock :
471
+ release_sock (sock -> sk );
472
+ return err ;
473
+ }
474
+
475
+ static bool is_tcp_proto (const struct proto * p )
476
+ {
477
+ #if IS_ENABLED (CONFIG_MPTCP_IPV6 )
478
+ return p == & tcp_prot || p == & tcpv6_prot ;
479
+ #else
480
+ return p == & tcp_prot ;
481
+ #endif
482
+ }
483
+
484
+ static int mptcp_stream_accept (struct socket * sock , struct socket * newsock ,
485
+ int flags , bool kern )
486
+ {
487
+ struct mptcp_sock * msk = mptcp_sk (sock -> sk );
488
+ struct socket * ssock ;
489
+ int err ;
490
+
491
+ pr_debug ("msk=%p" , msk );
492
+
493
+ lock_sock (sock -> sk );
494
+ if (sock -> sk -> sk_state != TCP_LISTEN )
495
+ goto unlock_fail ;
496
+
497
+ ssock = __mptcp_nmpc_socket (msk );
498
+ if (!ssock )
499
+ goto unlock_fail ;
500
+
501
+ sock_hold (ssock -> sk );
502
+ release_sock (sock -> sk );
503
+
504
+ err = ssock -> ops -> accept (sock , newsock , flags , kern );
505
+ if (err == 0 && !is_tcp_proto (newsock -> sk -> sk_prot )) {
506
+ struct mptcp_sock * msk = mptcp_sk (newsock -> sk );
507
+ struct mptcp_subflow_context * subflow ;
508
+
509
+ /* set ssk->sk_socket of accept()ed flows to mptcp socket.
510
+ * This is needed so NOSPACE flag can be set from tcp stack.
511
+ */
512
+ list_for_each_entry (subflow , & msk -> conn_list , node ) {
513
+ struct sock * ssk = mptcp_subflow_tcp_sock (subflow );
514
+
515
+ if (!ssk -> sk_socket )
516
+ mptcp_sock_graft (ssk , newsock );
517
+ }
518
+
519
+ inet_sk_state_store (newsock -> sk , TCP_ESTABLISHED );
520
+ }
521
+
522
+ sock_put (ssock -> sk );
523
+ return err ;
524
+
525
+ unlock_fail :
526
+ release_sock (sock -> sk );
527
+ return - EINVAL ;
528
+ }
529
+
310
530
static __poll_t mptcp_poll (struct file * file , struct socket * sock ,
311
531
struct poll_table_struct * wait )
312
532
{
@@ -332,6 +552,9 @@ void __init mptcp_init(void)
332
552
mptcp_stream_ops .bind = mptcp_bind ;
333
553
mptcp_stream_ops .connect = mptcp_stream_connect ;
334
554
mptcp_stream_ops .poll = mptcp_poll ;
555
+ mptcp_stream_ops .accept = mptcp_stream_accept ;
556
+ mptcp_stream_ops .getname = mptcp_v4_getname ;
557
+ mptcp_stream_ops .listen = mptcp_listen ;
335
558
336
559
mptcp_subflow_init ();
337
560
@@ -371,6 +594,9 @@ int mptcpv6_init(void)
371
594
mptcp_v6_stream_ops .bind = mptcp_bind ;
372
595
mptcp_v6_stream_ops .connect = mptcp_stream_connect ;
373
596
mptcp_v6_stream_ops .poll = mptcp_poll ;
597
+ mptcp_v6_stream_ops .accept = mptcp_stream_accept ;
598
+ mptcp_v6_stream_ops .getname = mptcp_v6_getname ;
599
+ mptcp_v6_stream_ops .listen = mptcp_listen ;
374
600
375
601
err = inet6_register_protosw (& mptcp_v6_protosw );
376
602
if (err )
0 commit comments