@@ -873,6 +873,81 @@ static bool mptcp_frag_can_collapse_to(const struct mptcp_sock *msk,
873
873
df -> data_seq + df -> data_len == msk -> write_seq ;
874
874
}
875
875
876
+ static int mptcp_wmem_with_overhead (int size )
877
+ {
878
+ return size + ((sizeof (struct mptcp_data_frag ) * size ) >> PAGE_SHIFT );
879
+ }
880
+
881
+ static void __mptcp_wmem_reserve (struct sock * sk , int size )
882
+ {
883
+ int amount = mptcp_wmem_with_overhead (size );
884
+ struct mptcp_sock * msk = mptcp_sk (sk );
885
+
886
+ WARN_ON_ONCE (msk -> wmem_reserved );
887
+ if (amount <= sk -> sk_forward_alloc )
888
+ goto reserve ;
889
+
890
+ /* under memory pressure try to reserve at most a single page
891
+ * otherwise try to reserve the full estimate and fallback
892
+ * to a single page before entering the error path
893
+ */
894
+ if ((tcp_under_memory_pressure (sk ) && amount > PAGE_SIZE ) ||
895
+ !sk_wmem_schedule (sk , amount )) {
896
+ if (amount <= PAGE_SIZE )
897
+ goto nomem ;
898
+
899
+ amount = PAGE_SIZE ;
900
+ if (!sk_wmem_schedule (sk , amount ))
901
+ goto nomem ;
902
+ }
903
+
904
+ reserve :
905
+ msk -> wmem_reserved = amount ;
906
+ sk -> sk_forward_alloc -= amount ;
907
+ return ;
908
+
909
+ nomem :
910
+ /* we will wait for memory on next allocation */
911
+ msk -> wmem_reserved = -1 ;
912
+ }
913
+
914
+ static void __mptcp_update_wmem (struct sock * sk )
915
+ {
916
+ struct mptcp_sock * msk = mptcp_sk (sk );
917
+
918
+ if (!msk -> wmem_reserved )
919
+ return ;
920
+
921
+ if (msk -> wmem_reserved < 0 )
922
+ msk -> wmem_reserved = 0 ;
923
+ if (msk -> wmem_reserved > 0 ) {
924
+ sk -> sk_forward_alloc += msk -> wmem_reserved ;
925
+ msk -> wmem_reserved = 0 ;
926
+ }
927
+ }
928
+
929
+ static bool mptcp_wmem_alloc (struct sock * sk , int size )
930
+ {
931
+ struct mptcp_sock * msk = mptcp_sk (sk );
932
+
933
+ /* check for pre-existing error condition */
934
+ if (msk -> wmem_reserved < 0 )
935
+ return false;
936
+
937
+ if (msk -> wmem_reserved >= size )
938
+ goto account ;
939
+
940
+ if (!sk_wmem_schedule (sk , size ))
941
+ return false;
942
+
943
+ sk -> sk_forward_alloc -= size ;
944
+ msk -> wmem_reserved += size ;
945
+
946
+ account :
947
+ msk -> wmem_reserved -= size ;
948
+ return true;
949
+ }
950
+
876
951
static void dfrag_uncharge (struct sock * sk , int len )
877
952
{
878
953
sk_mem_uncharge (sk , len );
@@ -930,7 +1005,7 @@ static void mptcp_clean_una(struct sock *sk)
930
1005
}
931
1006
932
1007
out :
933
- if (cleaned )
1008
+ if (cleaned && tcp_under_memory_pressure ( sk ) )
934
1009
sk_mem_reclaim_partial (sk );
935
1010
}
936
1011
@@ -1307,7 +1382,7 @@ static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
1307
1382
if (msg -> msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL ))
1308
1383
return - EOPNOTSUPP ;
1309
1384
1310
- lock_sock (sk );
1385
+ mptcp_lock_sock (sk , __mptcp_wmem_reserve ( sk , len ) );
1311
1386
1312
1387
timeo = sock_sndtimeo (sk , msg -> msg_flags & MSG_DONTWAIT );
1313
1388
@@ -1356,11 +1431,12 @@ static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
1356
1431
offset = dfrag -> offset + dfrag -> data_len ;
1357
1432
psize = pfrag -> size - offset ;
1358
1433
psize = min_t (size_t , psize , msg_data_left (msg ));
1359
- if (!sk_wmem_schedule (sk , psize + frag_truesize ))
1434
+ if (!mptcp_wmem_alloc (sk , psize + frag_truesize ))
1360
1435
goto wait_for_memory ;
1361
1436
1362
1437
if (copy_page_from_iter (dfrag -> page , offset , psize ,
1363
1438
& msg -> msg_iter ) != psize ) {
1439
+ msk -> wmem_reserved += psize + frag_truesize ;
1364
1440
ret = - EFAULT ;
1365
1441
goto out ;
1366
1442
}
@@ -1376,7 +1452,6 @@ static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
1376
1452
* Note: we charge such data both to sk and ssk
1377
1453
*/
1378
1454
sk_wmem_queued_add (sk , frag_truesize );
1379
- sk -> sk_forward_alloc -= frag_truesize ;
1380
1455
if (!dfrag_collapsed ) {
1381
1456
get_page (dfrag -> page );
1382
1457
list_add_tail (& dfrag -> list , & msk -> rtx_queue );
@@ -2003,6 +2078,7 @@ static int __mptcp_init_sock(struct sock *sk)
2003
2078
INIT_WORK (& msk -> work , mptcp_worker );
2004
2079
msk -> out_of_order_queue = RB_ROOT ;
2005
2080
msk -> first_pending = NULL ;
2081
+ msk -> wmem_reserved = 0 ;
2006
2082
2007
2083
msk -> ack_hint = NULL ;
2008
2084
msk -> first = NULL ;
@@ -2197,6 +2273,7 @@ static void __mptcp_destroy_sock(struct sock *sk)
2197
2273
2198
2274
sk -> sk_prot -> destroy (sk );
2199
2275
2276
+ WARN_ON_ONCE (msk -> wmem_reserved );
2200
2277
sk_stream_kill_queues (sk );
2201
2278
xfrm_sk_free_policy (sk );
2202
2279
sk_refcnt_debug_release (sk );
@@ -2542,13 +2619,14 @@ static int mptcp_getsockopt(struct sock *sk, int level, int optname,
2542
2619
2543
2620
#define MPTCP_DEFERRED_ALL (TCPF_WRITE_TIMER_DEFERRED)
2544
2621
2545
- /* this is very alike tcp_release_cb() but we must handle differently a
2546
- * different set of events
2547
- */
2622
+ /* processes deferred events and flush wmem */
2548
2623
static void mptcp_release_cb (struct sock * sk )
2549
2624
{
2550
2625
unsigned long flags , nflags ;
2551
2626
2627
+ /* clear any wmem reservation and errors */
2628
+ __mptcp_update_wmem (sk );
2629
+
2552
2630
do {
2553
2631
flags = sk -> sk_tsq_flags ;
2554
2632
if (!(flags & MPTCP_DEFERRED_ALL ))
0 commit comments