@@ -361,25 +361,15 @@ static void smc_destruct(struct sock *sk)
361
361
return ;
362
362
}
363
363
364
- static struct sock * smc_sock_alloc (struct net * net , struct socket * sock ,
365
- int protocol )
364
+ void smc_sk_init (struct net * net , struct sock * sk , int protocol )
366
365
{
367
- struct smc_sock * smc ;
368
- struct proto * prot ;
369
- struct sock * sk ;
370
-
371
- prot = (protocol == SMCPROTO_SMC6 ) ? & smc_proto6 : & smc_proto ;
372
- sk = sk_alloc (net , PF_SMC , GFP_KERNEL , prot , 0 );
373
- if (!sk )
374
- return NULL ;
366
+ struct smc_sock * smc = smc_sk (sk );
375
367
376
- sock_init_data (sock , sk ); /* sets sk_refcnt to 1 */
377
368
sk -> sk_state = SMC_INIT ;
378
369
sk -> sk_destruct = smc_destruct ;
379
370
sk -> sk_protocol = protocol ;
380
371
WRITE_ONCE (sk -> sk_sndbuf , 2 * READ_ONCE (net -> smc .sysctl_wmem ));
381
372
WRITE_ONCE (sk -> sk_rcvbuf , 2 * READ_ONCE (net -> smc .sysctl_rmem ));
382
- smc = smc_sk (sk );
383
373
INIT_WORK (& smc -> tcp_listen_work , smc_tcp_listen_work );
384
374
INIT_WORK (& smc -> connect_work , smc_connect_work );
385
375
INIT_DELAYED_WORK (& smc -> conn .tx_work , smc_tx_work );
@@ -389,6 +379,24 @@ static struct sock *smc_sock_alloc(struct net *net, struct socket *sock,
389
379
sk -> sk_prot -> hash (sk );
390
380
mutex_init (& smc -> clcsock_release_lock );
391
381
smc_init_saved_callbacks (smc );
382
+ smc -> limit_smc_hs = net -> smc .limit_smc_hs ;
383
+ smc -> use_fallback = false; /* assume rdma capability first */
384
+ smc -> fallback_rsn = 0 ;
385
+ }
386
+
387
+ static struct sock * smc_sock_alloc (struct net * net , struct socket * sock ,
388
+ int protocol )
389
+ {
390
+ struct proto * prot ;
391
+ struct sock * sk ;
392
+
393
+ prot = (protocol == SMCPROTO_SMC6 ) ? & smc_proto6 : & smc_proto ;
394
+ sk = sk_alloc (net , PF_SMC , GFP_KERNEL , prot , 0 );
395
+ if (!sk )
396
+ return NULL ;
397
+
398
+ sock_init_data (sock , sk ); /* sets sk_refcnt to 1 */
399
+ smc_sk_init (net , sk , protocol );
392
400
393
401
return sk ;
394
402
}
@@ -3303,6 +3311,31 @@ static const struct proto_ops smc_sock_ops = {
3303
3311
.splice_read = smc_splice_read ,
3304
3312
};
3305
3313
3314
+ int smc_create_clcsk (struct net * net , struct sock * sk , int family )
3315
+ {
3316
+ struct smc_sock * smc = smc_sk (sk );
3317
+ int rc ;
3318
+
3319
+ rc = sock_create_kern (net , family , SOCK_STREAM , IPPROTO_TCP ,
3320
+ & smc -> clcsock );
3321
+ if (rc ) {
3322
+ sk_common_release (sk );
3323
+ return rc ;
3324
+ }
3325
+
3326
+ /* smc_clcsock_release() does not wait smc->clcsock->sk's
3327
+ * destruction; its sk_state might not be TCP_CLOSE after
3328
+ * smc->sk is close()d, and TCP timers can be fired later,
3329
+ * which need net ref.
3330
+ */
3331
+ sk = smc -> clcsock -> sk ;
3332
+ __netns_tracker_free (net , & sk -> ns_tracker , false);
3333
+ sk -> sk_net_refcnt = 1 ;
3334
+ get_net_track (net , & sk -> ns_tracker , GFP_KERNEL );
3335
+ sock_inuse_add (net , 1 );
3336
+ return 0 ;
3337
+ }
3338
+
3306
3339
static int __smc_create (struct net * net , struct socket * sock , int protocol ,
3307
3340
int kern , struct socket * clcsock )
3308
3341
{
@@ -3328,35 +3361,12 @@ static int __smc_create(struct net *net, struct socket *sock, int protocol,
3328
3361
3329
3362
/* create internal TCP socket for CLC handshake and fallback */
3330
3363
smc = smc_sk (sk );
3331
- smc -> use_fallback = false; /* assume rdma capability first */
3332
- smc -> fallback_rsn = 0 ;
3333
-
3334
- /* default behavior from limit_smc_hs in every net namespace */
3335
- smc -> limit_smc_hs = net -> smc .limit_smc_hs ;
3336
3364
3337
3365
rc = 0 ;
3338
- if (!clcsock ) {
3339
- rc = sock_create_kern (net , family , SOCK_STREAM , IPPROTO_TCP ,
3340
- & smc -> clcsock );
3341
- if (rc ) {
3342
- sk_common_release (sk );
3343
- goto out ;
3344
- }
3345
-
3346
- /* smc_clcsock_release() does not wait smc->clcsock->sk's
3347
- * destruction; its sk_state might not be TCP_CLOSE after
3348
- * smc->sk is close()d, and TCP timers can be fired later,
3349
- * which need net ref.
3350
- */
3351
- sk = smc -> clcsock -> sk ;
3352
- __netns_tracker_free (net , & sk -> ns_tracker , false);
3353
- sk -> sk_net_refcnt = 1 ;
3354
- get_net_track (net , & sk -> ns_tracker , GFP_KERNEL );
3355
- sock_inuse_add (net , 1 );
3356
- } else {
3366
+ if (clcsock )
3357
3367
smc -> clcsock = clcsock ;
3358
- }
3359
-
3368
+ else
3369
+ rc = smc_create_clcsk ( net , sk , family );
3360
3370
out :
3361
3371
return rc ;
3362
3372
}
0 commit comments