Skip to content

Commit 419ce13

Browse files
Paolo Abeni authored and kuba-moo committed
tcp: allow again tcp_disconnect() when threads are waiting
As reported by Tom, .NET and applications built on top of it rely on connect(AF_UNSPEC) to asynchronously cancel pending I/O operations on TCP sockets. The blamed commit below caused a regression, as such cancellation can now fail. As suggested by Eric, this change addresses the problem by explicitly causing blocking I/O operations to terminate immediately (with an error) when a concurrent disconnect() is executed. Instead of tracking the number of threads blocked on a given socket, track the number of disconnect() calls issued on that socket. If that counter changes while a blocking operation releases and re-acquires the socket lock, error out the current operation. Fixes: 4faeee0 ("tcp: deny tcp_disconnect() when threads are waiting") Reported-by: Tom Deseyn <[email protected]> Closes: https://bugzilla.redhat.com/show_bug.cgi?id=1886305 Suggested-by: Eric Dumazet <[email protected]> Signed-off-by: Paolo Abeni <[email protected]> Reviewed-by: Eric Dumazet <[email protected]> Link: https://lore.kernel.org/r/f3b95e47e3dbed840960548aebaa8d954372db41.1697008693.git.pabeni@redhat.com Signed-off-by: Jakub Kicinski <[email protected]>
1 parent 242e345 commit 419ce13

File tree

10 files changed

+80
-45
lines changed

10 files changed

+80
-45
lines changed

drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls_io.c

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -911,7 +911,7 @@ static int csk_wait_memory(struct chtls_dev *cdev,
911911
struct sock *sk, long *timeo_p)
912912
{
913913
DEFINE_WAIT_FUNC(wait, woken_wake_function);
914-
int err = 0;
914+
int ret, err = 0;
915915
long current_timeo;
916916
long vm_wait = 0;
917917
bool noblock;
@@ -942,10 +942,13 @@ static int csk_wait_memory(struct chtls_dev *cdev,
942942

943943
set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
944944
sk->sk_write_pending++;
945-
sk_wait_event(sk, &current_timeo, sk->sk_err ||
946-
(sk->sk_shutdown & SEND_SHUTDOWN) ||
947-
(csk_mem_free(cdev, sk) && !vm_wait), &wait);
945+
ret = sk_wait_event(sk, &current_timeo, sk->sk_err ||
946+
(sk->sk_shutdown & SEND_SHUTDOWN) ||
947+
(csk_mem_free(cdev, sk) && !vm_wait),
948+
&wait);
948949
sk->sk_write_pending--;
950+
if (ret < 0)
951+
goto do_error;
949952

950953
if (vm_wait) {
951954
vm_wait -= current_timeo;
@@ -1348,6 +1351,7 @@ static int chtls_pt_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
13481351
int copied = 0;
13491352
int target;
13501353
long timeo;
1354+
int ret;
13511355

13521356
buffers_freed = 0;
13531357

@@ -1423,7 +1427,11 @@ static int chtls_pt_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
14231427
if (copied >= target)
14241428
break;
14251429
chtls_cleanup_rbuf(sk, copied);
1426-
sk_wait_data(sk, &timeo, NULL);
1430+
ret = sk_wait_data(sk, &timeo, NULL);
1431+
if (ret < 0) {
1432+
copied = copied ? : ret;
1433+
goto unlock;
1434+
}
14271435
continue;
14281436
found_ok_skb:
14291437
if (!skb->len) {
@@ -1518,6 +1526,8 @@ static int chtls_pt_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
15181526

15191527
if (buffers_freed)
15201528
chtls_cleanup_rbuf(sk, copied);
1529+
1530+
unlock:
15211531
release_sock(sk);
15221532
return copied;
15231533
}
@@ -1534,6 +1544,7 @@ static int peekmsg(struct sock *sk, struct msghdr *msg,
15341544
int copied = 0;
15351545
size_t avail; /* amount of available data in current skb */
15361546
long timeo;
1547+
int ret;
15371548

15381549
lock_sock(sk);
15391550
timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
@@ -1585,7 +1596,12 @@ static int peekmsg(struct sock *sk, struct msghdr *msg,
15851596
release_sock(sk);
15861597
lock_sock(sk);
15871598
} else {
1588-
sk_wait_data(sk, &timeo, NULL);
1599+
ret = sk_wait_data(sk, &timeo, NULL);
1600+
if (ret < 0) {
1601+
/* here 'copied' is 0 due to previous checks */
1602+
copied = ret;
1603+
break;
1604+
}
15891605
}
15901606

15911607
if (unlikely(peek_seq != tp->copied_seq)) {
@@ -1656,6 +1672,7 @@ int chtls_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
16561672
int copied = 0;
16571673
long timeo;
16581674
int target; /* Read at least this many bytes */
1675+
int ret;
16591676

16601677
buffers_freed = 0;
16611678

@@ -1747,7 +1764,11 @@ int chtls_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
17471764
if (copied >= target)
17481765
break;
17491766
chtls_cleanup_rbuf(sk, copied);
1750-
sk_wait_data(sk, &timeo, NULL);
1767+
ret = sk_wait_data(sk, &timeo, NULL);
1768+
if (ret < 0) {
1769+
copied = copied ? : ret;
1770+
goto unlock;
1771+
}
17511772
continue;
17521773

17531774
found_ok_skb:
@@ -1816,6 +1837,7 @@ int chtls_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
18161837
if (buffers_freed)
18171838
chtls_cleanup_rbuf(sk, copied);
18181839

1840+
unlock:
18191841
release_sock(sk);
18201842
return copied;
18211843
}

include/net/sock.h

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ struct sk_filter;
336336
* @sk_cgrp_data: cgroup data for this cgroup
337337
* @sk_memcg: this socket's memory cgroup association
338338
* @sk_write_pending: a write to stream socket waits to start
339-
* @sk_wait_pending: number of threads blocked on this socket
339+
* @sk_disconnects: number of disconnect operations performed on this sock
340340
* @sk_state_change: callback to indicate change in the state of the sock
341341
* @sk_data_ready: callback to indicate there is data to be processed
342342
* @sk_write_space: callback to indicate there is bf sending space available
@@ -429,7 +429,7 @@ struct sock {
429429
unsigned int sk_napi_id;
430430
#endif
431431
int sk_rcvbuf;
432-
int sk_wait_pending;
432+
int sk_disconnects;
433433

434434
struct sk_filter __rcu *sk_filter;
435435
union {
@@ -1189,8 +1189,7 @@ static inline void sock_rps_reset_rxhash(struct sock *sk)
11891189
}
11901190

11911191
#define sk_wait_event(__sk, __timeo, __condition, __wait) \
1192-
({ int __rc; \
1193-
__sk->sk_wait_pending++; \
1192+
({ int __rc, __dis = __sk->sk_disconnects; \
11941193
release_sock(__sk); \
11951194
__rc = __condition; \
11961195
if (!__rc) { \
@@ -1200,8 +1199,7 @@ static inline void sock_rps_reset_rxhash(struct sock *sk)
12001199
} \
12011200
sched_annotate_sleep(); \
12021201
lock_sock(__sk); \
1203-
__sk->sk_wait_pending--; \
1204-
__rc = __condition; \
1202+
__rc = __dis == __sk->sk_disconnects ? __condition : -EPIPE; \
12051203
__rc; \
12061204
})
12071205

net/core/stream.c

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ EXPORT_SYMBOL(sk_stream_wait_close);
117117
*/
118118
int sk_stream_wait_memory(struct sock *sk, long *timeo_p)
119119
{
120-
int err = 0;
120+
int ret, err = 0;
121121
long vm_wait = 0;
122122
long current_timeo = *timeo_p;
123123
DEFINE_WAIT_FUNC(wait, woken_wake_function);
@@ -142,11 +142,13 @@ int sk_stream_wait_memory(struct sock *sk, long *timeo_p)
142142

143143
set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
144144
sk->sk_write_pending++;
145-
sk_wait_event(sk, &current_timeo, READ_ONCE(sk->sk_err) ||
146-
(READ_ONCE(sk->sk_shutdown) & SEND_SHUTDOWN) ||
147-
(sk_stream_memory_free(sk) &&
148-
!vm_wait), &wait);
145+
ret = sk_wait_event(sk, &current_timeo, READ_ONCE(sk->sk_err) ||
146+
(READ_ONCE(sk->sk_shutdown) & SEND_SHUTDOWN) ||
147+
(sk_stream_memory_free(sk) && !vm_wait),
148+
&wait);
149149
sk->sk_write_pending--;
150+
if (ret < 0)
151+
goto do_error;
150152

151153
if (vm_wait) {
152154
vm_wait -= current_timeo;

net/ipv4/af_inet.c

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -597,7 +597,6 @@ static long inet_wait_for_connect(struct sock *sk, long timeo, int writebias)
597597

598598
add_wait_queue(sk_sleep(sk), &wait);
599599
sk->sk_write_pending += writebias;
600-
sk->sk_wait_pending++;
601600

602601
/* Basic assumption: if someone sets sk->sk_err, he _must_
603602
* change state of the socket from TCP_SYN_*.
@@ -613,7 +612,6 @@ static long inet_wait_for_connect(struct sock *sk, long timeo, int writebias)
613612
}
614613
remove_wait_queue(sk_sleep(sk), &wait);
615614
sk->sk_write_pending -= writebias;
616-
sk->sk_wait_pending--;
617615
return timeo;
618616
}
619617

@@ -642,6 +640,7 @@ int __inet_stream_connect(struct socket *sock, struct sockaddr *uaddr,
642640
return -EINVAL;
643641

644642
if (uaddr->sa_family == AF_UNSPEC) {
643+
sk->sk_disconnects++;
645644
err = sk->sk_prot->disconnect(sk, flags);
646645
sock->state = err ? SS_DISCONNECTING : SS_UNCONNECTED;
647646
goto out;
@@ -696,6 +695,7 @@ int __inet_stream_connect(struct socket *sock, struct sockaddr *uaddr,
696695
int writebias = (sk->sk_protocol == IPPROTO_TCP) &&
697696
tcp_sk(sk)->fastopen_req &&
698697
tcp_sk(sk)->fastopen_req->data ? 1 : 0;
698+
int dis = sk->sk_disconnects;
699699

700700
/* Error code is set above */
701701
if (!timeo || !inet_wait_for_connect(sk, timeo, writebias))
@@ -704,6 +704,11 @@ int __inet_stream_connect(struct socket *sock, struct sockaddr *uaddr,
704704
err = sock_intr_errno(timeo);
705705
if (signal_pending(current))
706706
goto out;
707+
708+
if (dis != sk->sk_disconnects) {
709+
err = -EPIPE;
710+
goto out;
711+
}
707712
}
708713

709714
/* Connection was closed by RST, timeout, ICMP error
@@ -725,6 +730,7 @@ int __inet_stream_connect(struct socket *sock, struct sockaddr *uaddr,
725730
sock_error:
726731
err = sock_error(sk) ? : -ECONNABORTED;
727732
sock->state = SS_UNCONNECTED;
733+
sk->sk_disconnects++;
728734
if (sk->sk_prot->disconnect(sk, flags))
729735
sock->state = SS_DISCONNECTING;
730736
goto out;

net/ipv4/inet_connection_sock.c

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1145,7 +1145,6 @@ struct sock *inet_csk_clone_lock(const struct sock *sk,
11451145
if (newsk) {
11461146
struct inet_connection_sock *newicsk = inet_csk(newsk);
11471147

1148-
newsk->sk_wait_pending = 0;
11491148
inet_sk_set_state(newsk, TCP_SYN_RECV);
11501149
newicsk->icsk_bind_hash = NULL;
11511150
newicsk->icsk_bind2_hash = NULL;

net/ipv4/tcp.c

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -831,7 +831,9 @@ ssize_t tcp_splice_read(struct socket *sock, loff_t *ppos,
831831
*/
832832
if (!skb_queue_empty(&sk->sk_receive_queue))
833833
break;
834-
sk_wait_data(sk, &timeo, NULL);
834+
ret = sk_wait_data(sk, &timeo, NULL);
835+
if (ret < 0)
836+
break;
835837
if (signal_pending(current)) {
836838
ret = sock_intr_errno(timeo);
837839
break;
@@ -2442,7 +2444,11 @@ static int tcp_recvmsg_locked(struct sock *sk, struct msghdr *msg, size_t len,
24422444
__sk_flush_backlog(sk);
24432445
} else {
24442446
tcp_cleanup_rbuf(sk, copied);
2445-
sk_wait_data(sk, &timeo, last);
2447+
err = sk_wait_data(sk, &timeo, last);
2448+
if (err < 0) {
2449+
err = copied ? : err;
2450+
goto out;
2451+
}
24462452
}
24472453

24482454
if ((flags & MSG_PEEK) &&
@@ -2966,12 +2972,6 @@ int tcp_disconnect(struct sock *sk, int flags)
29662972
int old_state = sk->sk_state;
29672973
u32 seq;
29682974

2969-
/* Deny disconnect if other threads are blocked in sk_wait_event()
2970-
* or inet_wait_for_connect().
2971-
*/
2972-
if (sk->sk_wait_pending)
2973-
return -EBUSY;
2974-
29752975
if (old_state != TCP_CLOSE)
29762976
tcp_set_state(sk, TCP_CLOSE);
29772977

net/ipv4/tcp_bpf.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,8 @@ static int tcp_bpf_recvmsg_parser(struct sock *sk,
307307
}
308308

309309
data = tcp_msg_wait_data(sk, psock, timeo);
310+
if (data < 0)
311+
return data;
310312
if (data && !sk_psock_queue_empty(psock))
311313
goto msg_bytes_ready;
312314
copied = -EAGAIN;
@@ -351,6 +353,8 @@ static int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
351353

352354
timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
353355
data = tcp_msg_wait_data(sk, psock, timeo);
356+
if (data < 0)
357+
return data;
354358
if (data) {
355359
if (!sk_psock_queue_empty(psock))
356360
goto msg_bytes_ready;

net/mptcp/protocol.c

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3098,12 +3098,6 @@ static int mptcp_disconnect(struct sock *sk, int flags)
30983098
{
30993099
struct mptcp_sock *msk = mptcp_sk(sk);
31003100

3101-
/* Deny disconnect if other threads are blocked in sk_wait_event()
3102-
* or inet_wait_for_connect().
3103-
*/
3104-
if (sk->sk_wait_pending)
3105-
return -EBUSY;
3106-
31073101
/* We are on the fastopen error path. We can't call straight into the
31083102
* subflows cleanup code due to lock nesting (we are already under
31093103
* msk->firstsocket lock).
@@ -3173,7 +3167,6 @@ struct sock *mptcp_sk_clone_init(const struct sock *sk,
31733167
inet_sk(nsk)->pinet6 = mptcp_inet6_sk(nsk);
31743168
#endif
31753169

3176-
nsk->sk_wait_pending = 0;
31773170
__mptcp_init_sock(nsk);
31783171

31793172
msk = mptcp_sk(nsk);

net/tls/tls_main.c

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,8 @@ void update_sk_prot(struct sock *sk, struct tls_context *ctx)
139139

140140
int wait_on_pending_writer(struct sock *sk, long *timeo)
141141
{
142-
int rc = 0;
143142
DEFINE_WAIT_FUNC(wait, woken_wake_function);
143+
int ret, rc = 0;
144144

145145
add_wait_queue(sk_sleep(sk), &wait);
146146
while (1) {
@@ -154,9 +154,13 @@ int wait_on_pending_writer(struct sock *sk, long *timeo)
154154
break;
155155
}
156156

157-
if (sk_wait_event(sk, timeo,
158-
!READ_ONCE(sk->sk_write_pending), &wait))
157+
ret = sk_wait_event(sk, timeo,
158+
!READ_ONCE(sk->sk_write_pending), &wait);
159+
if (ret) {
160+
if (ret < 0)
161+
rc = ret;
159162
break;
163+
}
160164
}
161165
remove_wait_queue(sk_sleep(sk), &wait);
162166
return rc;

net/tls/tls_sw.c

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1291,6 +1291,7 @@ tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock,
12911291
struct tls_context *tls_ctx = tls_get_ctx(sk);
12921292
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
12931293
DEFINE_WAIT_FUNC(wait, woken_wake_function);
1294+
int ret = 0;
12941295
long timeo;
12951296

12961297
timeo = sock_rcvtimeo(sk, nonblock);
@@ -1302,6 +1303,9 @@ tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock,
13021303
if (sk->sk_err)
13031304
return sock_error(sk);
13041305

1306+
if (ret < 0)
1307+
return ret;
1308+
13051309
if (!skb_queue_empty(&sk->sk_receive_queue)) {
13061310
tls_strp_check_rcv(&ctx->strp);
13071311
if (tls_strp_msg_ready(ctx))
@@ -1320,10 +1324,10 @@ tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock,
13201324
released = true;
13211325
add_wait_queue(sk_sleep(sk), &wait);
13221326
sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
1323-
sk_wait_event(sk, &timeo,
1324-
tls_strp_msg_ready(ctx) ||
1325-
!sk_psock_queue_empty(psock),
1326-
&wait);
1327+
ret = sk_wait_event(sk, &timeo,
1328+
tls_strp_msg_ready(ctx) ||
1329+
!sk_psock_queue_empty(psock),
1330+
&wait);
13271331
sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
13281332
remove_wait_queue(sk_sleep(sk), &wait);
13291333

@@ -1852,6 +1856,7 @@ static int tls_rx_reader_acquire(struct sock *sk, struct tls_sw_context_rx *ctx,
18521856
bool nonblock)
18531857
{
18541858
long timeo;
1859+
int ret;
18551860

18561861
timeo = sock_rcvtimeo(sk, nonblock);
18571862

@@ -1861,14 +1866,16 @@ static int tls_rx_reader_acquire(struct sock *sk, struct tls_sw_context_rx *ctx,
18611866
ctx->reader_contended = 1;
18621867

18631868
add_wait_queue(&ctx->wq, &wait);
1864-
sk_wait_event(sk, &timeo,
1865-
!READ_ONCE(ctx->reader_present), &wait);
1869+
ret = sk_wait_event(sk, &timeo,
1870+
!READ_ONCE(ctx->reader_present), &wait);
18661871
remove_wait_queue(&ctx->wq, &wait);
18671872

18681873
if (timeo <= 0)
18691874
return -EAGAIN;
18701875
if (signal_pending(current))
18711876
return sock_intr_errno(timeo);
1877+
if (ret < 0)
1878+
return ret;
18721879
}
18731880

18741881
WRITE_ONCE(ctx->reader_present, 1);

0 commit comments

Comments
 (0)