Skip to content

Commit 4638de5

Browse files
kmaloordavem330
authored andcommitted
mptcp: handle local addrs announced by userspace PMs
This change adds an internal function to store/retrieve local addrs announced by userspace PM implementations to/from its kernel context. The function addresses the requirements of three scenarios: 1) ADD_ADDR announcements (which require that a local id be provided), 2) retrieving the local id associated with an address, and also where one may need to be assigned, and 3) reissuance of ADD_ADDRs when there's a successful match of addr/id. The list of all stored local addr entries is held under the MPTCP sock structure. Memory for these entries is allocated from the sock option buffer, so the list of addrs is bounded by optmem_max. The list if not released via REMOVE_ADDR signals is ultimately freed when the sock is destructed. Acked-by: Paolo Abeni <[email protected]> Signed-off-by: Kishen Maloor <[email protected]> Signed-off-by: Mat Martineau <[email protected]> Signed-off-by: David S. Miller <[email protected]>
1 parent f43f0cd commit 4638de5

File tree

6 files changed

+113
-26
lines changed

6 files changed

+113
-26
lines changed

net/mptcp/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
obj-$(CONFIG_MPTCP) += mptcp.o
33

44
mptcp-y := protocol.o subflow.o options.o token.o crypto.o ctrl.o pm.o diag.o \
5-
mib.o pm_netlink.o sockopt.o
5+
mib.o pm_netlink.o sockopt.o pm_userspace.o
66

77
obj-$(CONFIG_SYN_COOKIES) += syncookies.o
88
obj-$(CONFIG_INET_MPTCP_DIAG) += mptcp_diag.o

net/mptcp/pm.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,7 @@ void mptcp_pm_data_init(struct mptcp_sock *msk)
469469
{
470470
spin_lock_init(&msk->pm.lock);
471471
INIT_LIST_HEAD(&msk->pm.anno_list);
472+
INIT_LIST_HEAD(&msk->pm.userspace_pm_local_addr_list);
472473
mptcp_pm_data_reset(msk);
473474
}
474475

net/mptcp/pm_netlink.c

Lines changed: 17 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,6 @@ static struct genl_family mptcp_genl_family;
2222

2323
static int pm_nl_pernet_id;
2424

25-
struct mptcp_pm_addr_entry {
26-
struct list_head list;
27-
struct mptcp_addr_info addr;
28-
u8 flags;
29-
int ifindex;
30-
struct socket *lsk;
31-
};
32-
3325
struct mptcp_pm_add_entry {
3426
struct list_head list;
3527
struct mptcp_addr_info addr;
@@ -66,8 +58,8 @@ pm_nl_get_pernet_from_msk(const struct mptcp_sock *msk)
6658
return pm_nl_get_pernet(sock_net((struct sock *)msk));
6759
}
6860

69-
static bool addresses_equal(const struct mptcp_addr_info *a,
70-
const struct mptcp_addr_info *b, bool use_port)
61+
bool mptcp_addresses_equal(const struct mptcp_addr_info *a,
62+
const struct mptcp_addr_info *b, bool use_port)
7163
{
7264
bool addr_equals = false;
7365

@@ -131,7 +123,7 @@ static bool lookup_subflow_by_saddr(const struct list_head *list,
131123
skc = (struct sock_common *)mptcp_subflow_tcp_sock(subflow);
132124

133125
local_address(skc, &cur);
134-
if (addresses_equal(&cur, saddr, saddr->port))
126+
if (mptcp_addresses_equal(&cur, saddr, saddr->port))
135127
return true;
136128
}
137129

@@ -149,7 +141,7 @@ static bool lookup_subflow_by_daddr(const struct list_head *list,
149141
skc = (struct sock_common *)mptcp_subflow_tcp_sock(subflow);
150142

151143
remote_address(skc, &cur);
152-
if (addresses_equal(&cur, daddr, daddr->port))
144+
if (mptcp_addresses_equal(&cur, daddr, daddr->port))
153145
return true;
154146
}
155147

@@ -269,7 +261,7 @@ mptcp_lookup_anno_list_by_saddr(const struct mptcp_sock *msk,
269261
lockdep_assert_held(&msk->pm.lock);
270262

271263
list_for_each_entry(entry, &msk->pm.anno_list, list) {
272-
if (addresses_equal(&entry->addr, addr, true))
264+
if (mptcp_addresses_equal(&entry->addr, addr, true))
273265
return entry;
274266
}
275267

@@ -286,7 +278,7 @@ bool mptcp_pm_sport_in_anno_list(struct mptcp_sock *msk, const struct sock *sk)
286278

287279
spin_lock_bh(&msk->pm.lock);
288280
list_for_each_entry(entry, &msk->pm.anno_list, list) {
289-
if (addresses_equal(&entry->addr, &saddr, true)) {
281+
if (mptcp_addresses_equal(&entry->addr, &saddr, true)) {
290282
ret = true;
291283
goto out;
292284
}
@@ -421,7 +413,7 @@ static bool lookup_address_in_vec(const struct mptcp_addr_info *addrs, unsigned
421413
int i;
422414

423415
for (i = 0; i < nr; i++) {
424-
if (addresses_equal(&addrs[i], addr, addr->port))
416+
if (mptcp_addresses_equal(&addrs[i], addr, addr->port))
425417
return true;
426418
}
427419

@@ -457,7 +449,7 @@ static unsigned int fill_remote_addresses_vec(struct mptcp_sock *msk, bool fullm
457449
mptcp_for_each_subflow(msk, subflow) {
458450
ssk = mptcp_subflow_tcp_sock(subflow);
459451
remote_address((struct sock_common *)ssk, &addrs[i]);
460-
if (deny_id0 && addresses_equal(&addrs[i], &remote, false))
452+
if (deny_id0 && mptcp_addresses_equal(&addrs[i], &remote, false))
461453
continue;
462454

463455
if (!lookup_address_in_vec(addrs, i, &addrs[i]) &&
@@ -490,7 +482,7 @@ __lookup_addr(struct pm_nl_pernet *pernet, const struct mptcp_addr_info *info,
490482
struct mptcp_pm_addr_entry *entry;
491483

492484
list_for_each_entry(entry, &pernet->local_addr_list, list) {
493-
if ((!lookup_by_id && addresses_equal(&entry->addr, info, true)) ||
485+
if ((!lookup_by_id && mptcp_addresses_equal(&entry->addr, info, true)) ||
494486
(lookup_by_id && entry->addr.id == info->id))
495487
return entry;
496488
}
@@ -505,7 +497,7 @@ lookup_id_by_addr(const struct pm_nl_pernet *pernet, const struct mptcp_addr_inf
505497

506498
rcu_read_lock();
507499
list_for_each_entry(entry, &pernet->local_addr_list, list) {
508-
if (addresses_equal(&entry->addr, addr, entry->addr.port)) {
500+
if (mptcp_addresses_equal(&entry->addr, addr, entry->addr.port)) {
509501
ret = entry->addr.id;
510502
break;
511503
}
@@ -739,7 +731,7 @@ static int mptcp_pm_nl_mp_prio_send_ack(struct mptcp_sock *msk,
739731
struct mptcp_addr_info local;
740732

741733
local_address((struct sock_common *)ssk, &local);
742-
if (!addresses_equal(&local, addr, addr->port))
734+
if (!mptcp_addresses_equal(&local, addr, addr->port))
743735
continue;
744736

745737
if (subflow->backup != bkup)
@@ -909,9 +901,9 @@ static int mptcp_pm_nl_append_new_local_addr(struct pm_nl_pernet *pernet,
909901
* singled addresses
910902
*/
911903
list_for_each_entry(cur, &pernet->local_addr_list, list) {
912-
if (addresses_equal(&cur->addr, &entry->addr,
913-
address_use_port(entry) &&
914-
address_use_port(cur))) {
904+
if (mptcp_addresses_equal(&cur->addr, &entry->addr,
905+
address_use_port(entry) &&
906+
address_use_port(cur))) {
915907
/* allow replacing the exiting endpoint only if such
916908
* endpoint is an implicit one and the user-space
917909
* did not provide an endpoint id
@@ -1038,14 +1030,14 @@ int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct sock_common *skc)
10381030
*/
10391031
local_address((struct sock_common *)msk, &msk_local);
10401032
local_address((struct sock_common *)skc, &skc_local);
1041-
if (addresses_equal(&msk_local, &skc_local, false))
1033+
if (mptcp_addresses_equal(&msk_local, &skc_local, false))
10421034
return 0;
10431035

10441036
pernet = pm_nl_get_pernet_from_msk(msk);
10451037

10461038
rcu_read_lock();
10471039
list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
1048-
if (addresses_equal(&entry->addr, &skc_local, entry->addr.port)) {
1040+
if (mptcp_addresses_equal(&entry->addr, &skc_local, entry->addr.port)) {
10491041
ret = entry->addr.id;
10501042
break;
10511043
}
@@ -1416,7 +1408,7 @@ static int mptcp_nl_remove_id_zero_address(struct net *net,
14161408
goto next;
14171409

14181410
local_address((struct sock_common *)msk, &msk_local);
1419-
if (!addresses_equal(&msk_local, addr, addr->port))
1411+
if (!mptcp_addresses_equal(&msk_local, addr, addr->port))
14201412
goto next;
14211413

14221414
lock_sock(sk);

net/mptcp/pm_userspace.c

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
// SPDX-License-Identifier: GPL-2.0
2+
/* Multipath TCP
3+
*
4+
* Copyright (c) 2022, Intel Corporation.
5+
*/
6+
7+
#include "protocol.h"
8+
9+
void mptcp_free_local_addr_list(struct mptcp_sock *msk)
10+
{
11+
struct mptcp_pm_addr_entry *entry, *tmp;
12+
struct sock *sk = (struct sock *)msk;
13+
LIST_HEAD(free_list);
14+
15+
if (!mptcp_pm_is_userspace(msk))
16+
return;
17+
18+
spin_lock_bh(&msk->pm.lock);
19+
list_splice_init(&msk->pm.userspace_pm_local_addr_list, &free_list);
20+
spin_unlock_bh(&msk->pm.lock);
21+
22+
list_for_each_entry_safe(entry, tmp, &free_list, list) {
23+
sock_kfree_s(sk, entry, sizeof(*entry));
24+
}
25+
}
26+
27+
int mptcp_userspace_pm_append_new_local_addr(struct mptcp_sock *msk,
28+
struct mptcp_pm_addr_entry *entry)
29+
{
30+
DECLARE_BITMAP(id_bitmap, MPTCP_PM_MAX_ADDR_ID + 1);
31+
struct mptcp_pm_addr_entry *match = NULL;
32+
struct sock *sk = (struct sock *)msk;
33+
struct mptcp_pm_addr_entry *e;
34+
bool addr_match = false;
35+
bool id_match = false;
36+
int ret = -EINVAL;
37+
38+
bitmap_zero(id_bitmap, MPTCP_PM_MAX_ADDR_ID + 1);
39+
40+
spin_lock_bh(&msk->pm.lock);
41+
list_for_each_entry(e, &msk->pm.userspace_pm_local_addr_list, list) {
42+
addr_match = mptcp_addresses_equal(&e->addr, &entry->addr, true);
43+
if (addr_match && entry->addr.id == 0)
44+
entry->addr.id = e->addr.id;
45+
id_match = (e->addr.id == entry->addr.id);
46+
if (addr_match && id_match) {
47+
match = e;
48+
break;
49+
} else if (addr_match || id_match) {
50+
break;
51+
}
52+
__set_bit(e->addr.id, id_bitmap);
53+
}
54+
55+
if (!match && !addr_match && !id_match) {
56+
/* Memory for the entry is allocated from the
57+
* sock option buffer.
58+
*/
59+
e = sock_kmalloc(sk, sizeof(*e), GFP_ATOMIC);
60+
if (!e) {
61+
spin_unlock_bh(&msk->pm.lock);
62+
return -ENOMEM;
63+
}
64+
65+
*e = *entry;
66+
if (!e->addr.id)
67+
e->addr.id = find_next_zero_bit(id_bitmap,
68+
MPTCP_PM_MAX_ADDR_ID + 1,
69+
1);
70+
list_add_tail_rcu(&e->list, &msk->pm.userspace_pm_local_addr_list);
71+
ret = e->addr.id;
72+
} else if (match) {
73+
ret = entry->addr.id;
74+
}
75+
76+
spin_unlock_bh(&msk->pm.lock);
77+
return ret;
78+
}

net/mptcp/protocol.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3097,6 +3097,7 @@ void mptcp_destroy_common(struct mptcp_sock *msk)
30973097
msk->rmem_fwd_alloc = 0;
30983098
mptcp_token_destroy(msk);
30993099
mptcp_pm_free_anno_list(msk);
3100+
mptcp_free_local_addr_list(msk);
31003101
}
31013102

31023103
static void mptcp_destroy(struct sock *sk)

net/mptcp/protocol.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ struct mptcp_pm_data {
208208
struct mptcp_addr_info local;
209209
struct mptcp_addr_info remote;
210210
struct list_head anno_list;
211+
struct list_head userspace_pm_local_addr_list;
211212

212213
spinlock_t lock; /*protects the whole PM data */
213214

@@ -228,6 +229,14 @@ struct mptcp_pm_data {
228229
struct mptcp_rm_list rm_list_rx;
229230
};
230231

232+
struct mptcp_pm_addr_entry {
233+
struct list_head list;
234+
struct mptcp_addr_info addr;
235+
u8 flags;
236+
int ifindex;
237+
struct socket *lsk;
238+
};
239+
231240
struct mptcp_data_frag {
232241
struct list_head list;
233242
u64 data_seq;
@@ -601,6 +610,9 @@ void mptcp_subflow_reset(struct sock *ssk);
601610
void mptcp_sock_graft(struct sock *sk, struct socket *parent);
602611
struct socket *__mptcp_nmpc_socket(const struct mptcp_sock *msk);
603612

613+
bool mptcp_addresses_equal(const struct mptcp_addr_info *a,
614+
const struct mptcp_addr_info *b, bool use_port);
615+
604616
/* called with sk socket lock held */
605617
int __mptcp_subflow_connect(struct sock *sk, const struct mptcp_addr_info *loc,
606618
const struct mptcp_addr_info *remote);
@@ -779,6 +791,9 @@ int mptcp_pm_announce_addr(struct mptcp_sock *msk,
779791
bool echo);
780792
int mptcp_pm_remove_addr(struct mptcp_sock *msk, const struct mptcp_rm_list *rm_list);
781793
int mptcp_pm_remove_subflow(struct mptcp_sock *msk, const struct mptcp_rm_list *rm_list);
794+
int mptcp_userspace_pm_append_new_local_addr(struct mptcp_sock *msk,
795+
struct mptcp_pm_addr_entry *entry);
796+
void mptcp_free_local_addr_list(struct mptcp_sock *msk);
782797

783798
void mptcp_event(enum mptcp_event_type type, const struct mptcp_sock *msk,
784799
const struct sock *ssk, gfp_t gfp);

0 commit comments

Comments
 (0)