113
113
#include <trace/events/skb.h>
114
114
#include <net/busy_poll.h>
115
115
#include "udp_impl.h"
116
+ #include <net/sock_reuseport.h>
116
117
117
118
struct udp_table udp_table __read_mostly ;
118
119
EXPORT_SYMBOL (udp_table );
@@ -137,7 +138,8 @@ static int udp_lib_lport_inuse(struct net *net, __u16 num,
137
138
unsigned long * bitmap ,
138
139
struct sock * sk ,
139
140
int (* saddr_comp )(const struct sock * sk1 ,
140
- const struct sock * sk2 ),
141
+ const struct sock * sk2 ,
142
+ bool match_wildcard ),
141
143
unsigned int log )
142
144
{
143
145
struct sock * sk2 ;
@@ -152,8 +154,9 @@ static int udp_lib_lport_inuse(struct net *net, __u16 num,
152
154
(!sk2 -> sk_bound_dev_if || !sk -> sk_bound_dev_if ||
153
155
sk2 -> sk_bound_dev_if == sk -> sk_bound_dev_if ) &&
154
156
(!sk2 -> sk_reuseport || !sk -> sk_reuseport ||
157
+ rcu_access_pointer (sk -> sk_reuseport_cb ) ||
155
158
!uid_eq (uid , sock_i_uid (sk2 ))) &&
156
- saddr_comp (sk , sk2 )) {
159
+ saddr_comp (sk , sk2 , true )) {
157
160
if (!bitmap )
158
161
return 1 ;
159
162
__set_bit (udp_sk (sk2 )-> udp_port_hash >> log , bitmap );
@@ -170,7 +173,8 @@ static int udp_lib_lport_inuse2(struct net *net, __u16 num,
170
173
struct udp_hslot * hslot2 ,
171
174
struct sock * sk ,
172
175
int (* saddr_comp )(const struct sock * sk1 ,
173
- const struct sock * sk2 ))
176
+ const struct sock * sk2 ,
177
+ bool match_wildcard ))
174
178
{
175
179
struct sock * sk2 ;
176
180
struct hlist_nulls_node * node ;
@@ -186,8 +190,9 @@ static int udp_lib_lport_inuse2(struct net *net, __u16 num,
186
190
(!sk2 -> sk_bound_dev_if || !sk -> sk_bound_dev_if ||
187
191
sk2 -> sk_bound_dev_if == sk -> sk_bound_dev_if ) &&
188
192
(!sk2 -> sk_reuseport || !sk -> sk_reuseport ||
193
+ rcu_access_pointer (sk -> sk_reuseport_cb ) ||
189
194
!uid_eq (uid , sock_i_uid (sk2 ))) &&
190
- saddr_comp (sk , sk2 )) {
195
+ saddr_comp (sk , sk2 , true )) {
191
196
res = 1 ;
192
197
break ;
193
198
}
@@ -196,6 +201,35 @@ static int udp_lib_lport_inuse2(struct net *net, __u16 num,
196
201
return res ;
197
202
}
198
203
204
+ static int udp_reuseport_add_sock (struct sock * sk , struct udp_hslot * hslot ,
205
+ int (* saddr_same )(const struct sock * sk1 ,
206
+ const struct sock * sk2 ,
207
+ bool match_wildcard ))
208
+ {
209
+ struct net * net = sock_net (sk );
210
+ struct hlist_nulls_node * node ;
211
+ kuid_t uid = sock_i_uid (sk );
212
+ struct sock * sk2 ;
213
+
214
+ sk_nulls_for_each (sk2 , node , & hslot -> head ) {
215
+ if (net_eq (sock_net (sk2 ), net ) &&
216
+ sk2 != sk &&
217
+ sk2 -> sk_family == sk -> sk_family &&
218
+ ipv6_only_sock (sk2 ) == ipv6_only_sock (sk ) &&
219
+ (udp_sk (sk2 )-> udp_port_hash == udp_sk (sk )-> udp_port_hash ) &&
220
+ (sk2 -> sk_bound_dev_if == sk -> sk_bound_dev_if ) &&
221
+ sk2 -> sk_reuseport && uid_eq (uid , sock_i_uid (sk2 )) &&
222
+ (* saddr_same )(sk , sk2 , false)) {
223
+ return reuseport_add_sock (sk , sk2 );
224
+ }
225
+ }
226
+
227
+ /* Initial allocation may have already happened via setsockopt */
228
+ if (!rcu_access_pointer (sk -> sk_reuseport_cb ))
229
+ return reuseport_alloc (sk );
230
+ return 0 ;
231
+ }
232
+
199
233
/**
200
234
* udp_lib_get_port - UDP/-Lite port lookup for IPv4 and IPv6
201
235
*
@@ -207,7 +241,8 @@ static int udp_lib_lport_inuse2(struct net *net, __u16 num,
207
241
*/
208
242
int udp_lib_get_port (struct sock * sk , unsigned short snum ,
209
243
int (* saddr_comp )(const struct sock * sk1 ,
210
- const struct sock * sk2 ),
244
+ const struct sock * sk2 ,
245
+ bool match_wildcard ),
211
246
unsigned int hash2_nulladdr )
212
247
{
213
248
struct udp_hslot * hslot , * hslot2 ;
@@ -290,6 +325,14 @@ int udp_lib_get_port(struct sock *sk, unsigned short snum,
290
325
udp_sk (sk )-> udp_port_hash = snum ;
291
326
udp_sk (sk )-> udp_portaddr_hash ^= snum ;
292
327
if (sk_unhashed (sk )) {
328
+ if (sk -> sk_reuseport &&
329
+ udp_reuseport_add_sock (sk , hslot , saddr_comp )) {
330
+ inet_sk (sk )-> inet_num = 0 ;
331
+ udp_sk (sk )-> udp_port_hash = 0 ;
332
+ udp_sk (sk )-> udp_portaddr_hash ^= snum ;
333
+ goto fail_unlock ;
334
+ }
335
+
293
336
sk_nulls_add_node_rcu (sk , & hslot -> head );
294
337
hslot -> count ++ ;
295
338
sock_prot_inuse_add (sock_net (sk ), sk -> sk_prot , 1 );
@@ -309,13 +352,22 @@ int udp_lib_get_port(struct sock *sk, unsigned short snum,
309
352
}
310
353
EXPORT_SYMBOL (udp_lib_get_port );
311
354
312
- static int ipv4_rcv_saddr_equal (const struct sock * sk1 , const struct sock * sk2 )
355
+ /* match_wildcard == true: 0.0.0.0 equals to any IPv4 addresses
356
+ * match_wildcard == false: addresses must be exactly the same, i.e.
357
+ * 0.0.0.0 only equals to 0.0.0.0
358
+ */
359
+ static int ipv4_rcv_saddr_equal (const struct sock * sk1 , const struct sock * sk2 ,
360
+ bool match_wildcard )
313
361
{
314
362
struct inet_sock * inet1 = inet_sk (sk1 ), * inet2 = inet_sk (sk2 );
315
363
316
- return (!ipv6_only_sock (sk2 ) &&
317
- (!inet1 -> inet_rcv_saddr || !inet2 -> inet_rcv_saddr ||
318
- inet1 -> inet_rcv_saddr == inet2 -> inet_rcv_saddr ));
364
+ if (!ipv6_only_sock (sk2 )) {
365
+ if (inet1 -> inet_rcv_saddr == inet2 -> inet_rcv_saddr )
366
+ return 1 ;
367
+ if (!inet1 -> inet_rcv_saddr || !inet2 -> inet_rcv_saddr )
368
+ return match_wildcard ;
369
+ }
370
+ return 0 ;
319
371
}
320
372
321
373
static u32 udp4_portaddr_hash (const struct net * net , __be32 saddr ,
@@ -459,8 +511,14 @@ static struct sock *udp4_lib_lookup2(struct net *net,
459
511
badness = score ;
460
512
reuseport = sk -> sk_reuseport ;
461
513
if (reuseport ) {
514
+ struct sock * sk2 ;
462
515
hash = udp_ehashfn (net , daddr , hnum ,
463
516
saddr , sport );
517
+ sk2 = reuseport_select_sock (sk , hash );
518
+ if (sk2 ) {
519
+ result = sk2 ;
520
+ goto found ;
521
+ }
464
522
matches = 1 ;
465
523
}
466
524
} else if (score == badness && reuseport ) {
@@ -478,6 +536,7 @@ static struct sock *udp4_lib_lookup2(struct net *net,
478
536
if (get_nulls_value (node ) != slot2 )
479
537
goto begin ;
480
538
if (result ) {
539
+ found :
481
540
if (unlikely (!atomic_inc_not_zero_hint (& result -> sk_refcnt , 2 )))
482
541
result = NULL ;
483
542
else if (unlikely (compute_score2 (result , net , saddr , sport ,
@@ -540,8 +599,14 @@ struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr,
540
599
badness = score ;
541
600
reuseport = sk -> sk_reuseport ;
542
601
if (reuseport ) {
602
+ struct sock * sk2 ;
543
603
hash = udp_ehashfn (net , daddr , hnum ,
544
604
saddr , sport );
605
+ sk2 = reuseport_select_sock (sk , hash );
606
+ if (sk2 ) {
607
+ result = sk2 ;
608
+ goto found ;
609
+ }
545
610
matches = 1 ;
546
611
}
547
612
} else if (score == badness && reuseport ) {
@@ -560,6 +625,7 @@ struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr,
560
625
goto begin ;
561
626
562
627
if (result ) {
628
+ found :
563
629
if (unlikely (!atomic_inc_not_zero_hint (& result -> sk_refcnt , 2 )))
564
630
result = NULL ;
565
631
else if (unlikely (compute_score (result , net , saddr , hnum , sport ,
@@ -587,7 +653,8 @@ static inline struct sock *__udp4_lib_lookup_skb(struct sk_buff *skb,
587
653
/*
 * Forwards all arguments unchanged to __udp4_lib_lookup(), passing the
 * global udp_table as the table to search, and returns its result.
 */
struct sock *udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport,
			     __be32 daddr, __be16 dport, int dif)
{
	struct sock *found;

	found = __udp4_lib_lookup(net, saddr, sport, daddr, dport, dif,
				  &udp_table);
	return found;
}
592
659
EXPORT_SYMBOL_GPL (udp4_lib_lookup );
593
660
@@ -1398,6 +1465,8 @@ void udp_lib_unhash(struct sock *sk)
1398
1465
hslot2 = udp_hashslot2 (udptable , udp_sk (sk )-> udp_portaddr_hash );
1399
1466
1400
1467
spin_lock_bh (& hslot -> lock );
1468
+ if (rcu_access_pointer (sk -> sk_reuseport_cb ))
1469
+ reuseport_detach_sock (sk );
1401
1470
if (sk_nulls_del_node_init_rcu (sk )) {
1402
1471
hslot -> count -- ;
1403
1472
inet_sk (sk )-> inet_num = 0 ;
@@ -1425,22 +1494,28 @@ void udp_lib_rehash(struct sock *sk, u16 newhash)
1425
1494
hslot2 = udp_hashslot2 (udptable , udp_sk (sk )-> udp_portaddr_hash );
1426
1495
nhslot2 = udp_hashslot2 (udptable , newhash );
1427
1496
udp_sk (sk )-> udp_portaddr_hash = newhash ;
1428
- if (hslot2 != nhslot2 ) {
1497
+
1498
+ if (hslot2 != nhslot2 ||
1499
+ rcu_access_pointer (sk -> sk_reuseport_cb )) {
1429
1500
hslot = udp_hashslot (udptable , sock_net (sk ),
1430
1501
udp_sk (sk )-> udp_port_hash );
1431
1502
/* we must lock primary chain too */
1432
1503
spin_lock_bh (& hslot -> lock );
1433
-
1434
- spin_lock (& hslot2 -> lock );
1435
- hlist_nulls_del_init_rcu (& udp_sk (sk )-> udp_portaddr_node );
1436
- hslot2 -> count -- ;
1437
- spin_unlock (& hslot2 -> lock );
1438
-
1439
- spin_lock (& nhslot2 -> lock );
1440
- hlist_nulls_add_head_rcu (& udp_sk (sk )-> udp_portaddr_node ,
1441
- & nhslot2 -> head );
1442
- nhslot2 -> count ++ ;
1443
- spin_unlock (& nhslot2 -> lock );
1504
+ if (rcu_access_pointer (sk -> sk_reuseport_cb ))
1505
+ reuseport_detach_sock (sk );
1506
+
1507
+ if (hslot2 != nhslot2 ) {
1508
+ spin_lock (& hslot2 -> lock );
1509
+ hlist_nulls_del_init_rcu (& udp_sk (sk )-> udp_portaddr_node );
1510
+ hslot2 -> count -- ;
1511
+ spin_unlock (& hslot2 -> lock );
1512
+
1513
+ spin_lock (& nhslot2 -> lock );
1514
+ hlist_nulls_add_head_rcu (& udp_sk (sk )-> udp_portaddr_node ,
1515
+ & nhslot2 -> head );
1516
+ nhslot2 -> count ++ ;
1517
+ spin_unlock (& nhslot2 -> lock );
1518
+ }
1444
1519
1445
1520
spin_unlock_bh (& hslot -> lock );
1446
1521
}
0 commit comments