16
16
#include <linux/inet_diag.h>
17
17
#include <linux/sock_diag.h>
18
18
19
- static const struct sock_diag_handler * sock_diag_handlers [AF_MAX ];
20
- static int (* inet_rcv_compat )(struct sk_buff * skb , struct nlmsghdr * nlh );
21
- static DEFINE_MUTEX (sock_diag_table_mutex );
19
+ static const struct sock_diag_handler __rcu * sock_diag_handlers [AF_MAX ];
20
+
21
+ static struct sock_diag_inet_compat __rcu * inet_rcv_compat ;
22
+
22
23
static struct workqueue_struct * broadcast_wq ;
23
24
24
25
DEFINE_COOKIE (sock_cookie );
@@ -122,6 +123,24 @@ static size_t sock_diag_nlmsg_size(void)
122
123
+ nla_total_size_64bit (sizeof (struct tcp_info ))); /* INET_DIAG_INFO */
123
124
}
124
125
126
+ static const struct sock_diag_handler * sock_diag_lock_handler (int family )
127
+ {
128
+ const struct sock_diag_handler * handler ;
129
+
130
+ rcu_read_lock ();
131
+ handler = rcu_dereference (sock_diag_handlers [family ]);
132
+ if (handler && !try_module_get (handler -> owner ))
133
+ handler = NULL ;
134
+ rcu_read_unlock ();
135
+
136
+ return handler ;
137
+ }
138
+
139
+ static void sock_diag_unlock_handler (const struct sock_diag_handler * handler )
140
+ {
141
+ module_put (handler -> owner );
142
+ }
143
+
125
144
static void sock_diag_broadcast_destroy_work (struct work_struct * work )
126
145
{
127
146
struct broadcast_sk * bsk =
@@ -138,12 +157,12 @@ static void sock_diag_broadcast_destroy_work(struct work_struct *work)
138
157
if (!skb )
139
158
goto out ;
140
159
141
- mutex_lock ( & sock_diag_table_mutex );
142
- hndl = sock_diag_handlers [ sk -> sk_family ];
143
- if (hndl && hndl -> get_info )
144
- err = hndl -> get_info (skb , sk );
145
- mutex_unlock ( & sock_diag_table_mutex );
146
-
160
+ hndl = sock_diag_lock_handler ( sk -> sk_family );
161
+ if ( hndl ) {
162
+ if (hndl -> get_info )
163
+ err = hndl -> get_info (skb , sk );
164
+ sock_diag_unlock_handler ( hndl );
165
+ }
147
166
if (!err )
148
167
nlmsg_multicast (sock_net (sk )-> diag_nlsk , skb , 0 , group ,
149
168
GFP_KERNEL );
@@ -166,51 +185,45 @@ void sock_diag_broadcast_destroy(struct sock *sk)
166
185
queue_work (broadcast_wq , & bsk -> work );
167
186
}
168
187
169
- void sock_diag_register_inet_compat (int ( * fn )( struct sk_buff * skb , struct nlmsghdr * nlh ) )
188
+ void sock_diag_register_inet_compat (const struct sock_diag_inet_compat * ptr )
170
189
{
171
- mutex_lock (& sock_diag_table_mutex );
172
- inet_rcv_compat = fn ;
173
- mutex_unlock (& sock_diag_table_mutex );
190
+ xchg ((__force const struct sock_diag_inet_compat * * )& inet_rcv_compat ,
191
+ ptr );
174
192
}
175
193
EXPORT_SYMBOL_GPL (sock_diag_register_inet_compat );
176
194
177
- void sock_diag_unregister_inet_compat (int ( * fn )( struct sk_buff * skb , struct nlmsghdr * nlh ) )
195
+ void sock_diag_unregister_inet_compat (const struct sock_diag_inet_compat * ptr )
178
196
{
179
- mutex_lock (& sock_diag_table_mutex );
180
- inet_rcv_compat = NULL ;
181
- mutex_unlock (& sock_diag_table_mutex );
197
+ const struct sock_diag_inet_compat * old ;
198
+
199
+ old = xchg ((__force const struct sock_diag_inet_compat * * )& inet_rcv_compat ,
200
+ NULL );
201
+ WARN_ON_ONCE (old != ptr );
182
202
}
183
203
EXPORT_SYMBOL_GPL (sock_diag_unregister_inet_compat );
184
204
185
205
int sock_diag_register (const struct sock_diag_handler * hndl )
186
206
{
187
- int err = 0 ;
207
+ int family = hndl -> family ;
188
208
189
- if (hndl -> family >= AF_MAX )
209
+ if (family >= AF_MAX )
190
210
return - EINVAL ;
191
211
192
- mutex_lock (& sock_diag_table_mutex );
193
- if (sock_diag_handlers [hndl -> family ])
194
- err = - EBUSY ;
195
- else
196
- sock_diag_handlers [hndl -> family ] = hndl ;
197
- mutex_unlock (& sock_diag_table_mutex );
198
-
199
- return err ;
212
+ return !cmpxchg ((const struct sock_diag_handler * * )
213
+ & sock_diag_handlers [family ],
214
+ NULL , hndl ) ? 0 : - EBUSY ;
200
215
}
201
216
EXPORT_SYMBOL_GPL (sock_diag_register );
202
217
203
- void sock_diag_unregister (const struct sock_diag_handler * hnld )
218
+ void sock_diag_unregister (const struct sock_diag_handler * hndl )
204
219
{
205
- int family = hnld -> family ;
220
+ int family = hndl -> family ;
206
221
207
222
if (family >= AF_MAX )
208
223
return ;
209
224
210
- mutex_lock (& sock_diag_table_mutex );
211
- BUG_ON (sock_diag_handlers [family ] != hnld );
212
- sock_diag_handlers [family ] = NULL ;
213
- mutex_unlock (& sock_diag_table_mutex );
225
+ xchg ((const struct sock_diag_handler * * )& sock_diag_handlers [family ],
226
+ NULL );
214
227
}
215
228
EXPORT_SYMBOL_GPL (sock_diag_unregister );
216
229
@@ -227,41 +240,48 @@ static int __sock_diag_cmd(struct sk_buff *skb, struct nlmsghdr *nlh)
227
240
return - EINVAL ;
228
241
req -> sdiag_family = array_index_nospec (req -> sdiag_family , AF_MAX );
229
242
230
- if (sock_diag_handlers [req -> sdiag_family ] == NULL )
243
+ if (! rcu_access_pointer ( sock_diag_handlers [req -> sdiag_family ]) )
231
244
sock_load_diag_module (req -> sdiag_family , 0 );
232
245
233
- mutex_lock (& sock_diag_table_mutex );
234
- hndl = sock_diag_handlers [req -> sdiag_family ];
246
+ hndl = sock_diag_lock_handler (req -> sdiag_family );
235
247
if (hndl == NULL )
236
- err = - ENOENT ;
237
- else if (nlh -> nlmsg_type == SOCK_DIAG_BY_FAMILY )
248
+ return - ENOENT ;
249
+
250
+ if (nlh -> nlmsg_type == SOCK_DIAG_BY_FAMILY )
238
251
err = hndl -> dump (skb , nlh );
239
252
else if (nlh -> nlmsg_type == SOCK_DESTROY && hndl -> destroy )
240
253
err = hndl -> destroy (skb , nlh );
241
254
else
242
255
err = - EOPNOTSUPP ;
243
- mutex_unlock ( & sock_diag_table_mutex );
256
+ sock_diag_unlock_handler ( hndl );
244
257
245
258
return err ;
246
259
}
247
260
248
261
static int sock_diag_rcv_msg (struct sk_buff * skb , struct nlmsghdr * nlh ,
249
262
struct netlink_ext_ack * extack )
250
263
{
264
+ const struct sock_diag_inet_compat * ptr ;
251
265
int ret ;
252
266
253
267
switch (nlh -> nlmsg_type ) {
254
268
case TCPDIAG_GETSOCK :
255
269
case DCCPDIAG_GETSOCK :
256
- if (inet_rcv_compat == NULL )
270
+
271
+ if (!rcu_access_pointer (inet_rcv_compat ))
257
272
sock_load_diag_module (AF_INET , 0 );
258
273
259
- mutex_lock (& sock_diag_table_mutex );
260
- if (inet_rcv_compat != NULL )
261
- ret = inet_rcv_compat (skb , nlh );
262
- else
263
- ret = - EOPNOTSUPP ;
264
- mutex_unlock (& sock_diag_table_mutex );
274
+ rcu_read_lock ();
275
+ ptr = rcu_dereference (inet_rcv_compat );
276
+ if (ptr && !try_module_get (ptr -> owner ))
277
+ ptr = NULL ;
278
+ rcu_read_unlock ();
279
+
280
+ ret = - EOPNOTSUPP ;
281
+ if (ptr ) {
282
+ ret = ptr -> fn (skb , nlh );
283
+ module_put (ptr -> owner );
284
+ }
265
285
266
286
return ret ;
267
287
case SOCK_DIAG_BY_FAMILY :
@@ -272,26 +292,22 @@ static int sock_diag_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh,
272
292
}
273
293
}
274
294
275
- static DEFINE_MUTEX (sock_diag_mutex );
276
-
277
295
static void sock_diag_rcv (struct sk_buff * skb )
278
296
{
279
- mutex_lock (& sock_diag_mutex );
280
297
netlink_rcv_skb (skb , & sock_diag_rcv_msg );
281
- mutex_unlock (& sock_diag_mutex );
282
298
}
283
299
284
300
static int sock_diag_bind (struct net * net , int group )
285
301
{
286
302
switch (group ) {
287
303
case SKNLGRP_INET_TCP_DESTROY :
288
304
case SKNLGRP_INET_UDP_DESTROY :
289
- if (!sock_diag_handlers [AF_INET ])
305
+ if (!rcu_access_pointer ( sock_diag_handlers [AF_INET ]) )
290
306
sock_load_diag_module (AF_INET , 0 );
291
307
break ;
292
308
case SKNLGRP_INET6_TCP_DESTROY :
293
309
case SKNLGRP_INET6_UDP_DESTROY :
294
- if (!sock_diag_handlers [AF_INET6 ])
310
+ if (!rcu_access_pointer ( sock_diag_handlers [AF_INET6 ]) )
295
311
sock_load_diag_module (AF_INET6 , 0 );
296
312
break ;
297
313
}
0 commit comments