1
1
/*
2
- * Copyright (c) 2006, 2018 Oracle and/or its affiliates. All rights reserved.
2
+ * Copyright (c) 2006, 2019 Oracle and/or its affiliates. All rights reserved.
3
3
*
4
4
* This software is available to you under a choice of one of two
5
5
* licenses. You may choose to be licensed under the terms of the GNU
@@ -80,6 +80,17 @@ static unsigned long rds_sock_count;
80
80
static LIST_HEAD (rds_sock_list );
81
81
DECLARE_WAIT_QUEUE_HEAD (rds_poll_waitq );
82
82
83
/* kmem cache slab for struct rds_buf_info */
static struct kmem_cache *rds_rs_buf_info_slab;

/* Helper function to be passed to rhashtable_free_and_destroy() to free a
 * struct rs_buf_info.
 *
 * The (void *obj, void *arg) signature matches the free_fn callback expected
 * by rhashtable_free_and_destroy(); arg is unused here, since the slab cache
 * to return the object to is the file-scope rds_rs_buf_info_slab.
 */
static void rds_buf_info_free(void *rsbi, void *arg __attribute__((unused)))
{
	/* rsbi is a struct rs_buf_info allocated from rds_rs_buf_info_slab. */
	kmem_cache_free(rds_rs_buf_info_slab, rsbi);
}
93
+
83
94
/*
84
95
* This is called as the final descriptor referencing this socket is closed.
85
96
* We have to unbind the socket so that another socket can be bound to the
@@ -112,6 +123,9 @@ static int rds_release(struct socket *sock)
112
123
rds_rdma_drop_keys (rs );
113
124
rds_notify_queue_get (rs , NULL );
114
125
126
+ rhashtable_free_and_destroy (& rs -> rs_buf_info_tbl , rds_buf_info_free ,
127
+ NULL );
128
+
115
129
spin_lock_bh (& rds_sock_lock );
116
130
list_del_init (& rs -> rs_item );
117
131
rds_sock_count -- ;
@@ -272,10 +286,18 @@ static unsigned int rds_poll(struct file *file, struct socket *sock,
272
286
if (!list_empty (& rs -> rs_recv_queue )
273
287
|| !list_empty (& rs -> rs_notify_queue ))
274
288
mask |= (POLLIN | POLLRDNORM );
275
- if (rs -> rs_snd_bytes < rds_sk_sndbuf (rs ))
276
- mask |= (POLLOUT | POLLWRNORM );
277
289
read_unlock_irqrestore (& rs -> rs_recv_lock , flags );
278
290
291
+ /* Use the number of destination this socket has to estimate the
292
+ * send buffer size. When there is no peer yet, return the default
293
+ * send buffer size.
294
+ */
295
+ spin_lock_irqsave (& rs -> rs_snd_lock , flags );
296
+ if (rs -> rs_snd_bytes < max_t (u32 , rs -> rs_buf_info_dest_cnt , 1 ) *
297
+ rds_sk_sndbuf (rs ))
298
+ mask |= (POLLOUT | POLLWRNORM );
299
+ spin_unlock_irqrestore (& rs -> rs_snd_lock , flags );
300
+
279
301
/* clear state any time we wake a seen-congested socket */
280
302
if (mask )
281
303
rs -> rs_seen_congestion = 0 ;
@@ -712,6 +734,77 @@ static int rds_getsockopt(struct socket *sock, int level, int optname,
712
734
713
735
}
714
736
737
/* Check if there is a rs_buf_info associated with the given address. If not,
 * add one to the rds_sock. The found or added rs_buf_info is returned. If
 * there is no rs_buf_info found and a new rs_buf_info cannot be allocated,
 * NULL is returned and ret is set to the error. Once an address' rs_buf_info
 * is added, it will not be removed until the rs_sock is closed.
 *
 * @rs:   socket whose per-peer buffer-info table is consulted/updated
 * @addr: peer address used as the rhashtable key
 * @ret:  out-parameter; 0 on success, negative errno on failure
 *        (-ENOMEM on allocation failure, -ENFILE when the socket already
 *        tracks rds_sock_max_peers destinations, -EINVAL on the
 *        should-not-happen lookup failure after a lost insert race)
 * @gfp:  allocation flags for the slab allocation
 *
 * Locking: the fast-path lookup is done locklessly via
 * rhashtable_lookup_fast(); insertion and the dest-count check are
 * serialized under rs->rs_snd_lock (irqsave). The new entry is allocated
 * before taking the lock so the spinlock is never held across allocation.
 */
struct rs_buf_info *rds_add_buf_info(struct rds_sock *rs, struct in6_addr *addr,
				     int *ret, gfp_t gfp)
{
	struct rs_buf_info *info, *tmp_info;
	unsigned long flags;

	/* Normal path, peer is expected to be found most of the time. */
	info = rhashtable_lookup_fast(&rs->rs_buf_info_tbl, addr,
				      rs_buf_info_params);
	if (info) {
		*ret = 0;
		return info;
	}

	/* Allocate the buffer outside of lock first. */
	tmp_info = kmem_cache_alloc(rds_rs_buf_info_slab, gfp);
	if (!tmp_info) {
		*ret = -ENOMEM;
		return NULL;
	}

	spin_lock_irqsave(&rs->rs_snd_lock, flags);

	/* Cannot add more peer. */
	if (rs->rs_buf_info_dest_cnt + 1 > rds_sock_max_peers) {
		spin_unlock_irqrestore(&rs->rs_snd_lock, flags);
		kmem_cache_free(rds_rs_buf_info_slab, tmp_info);
		*ret = -ENFILE;
		return NULL;
	}

	/* Initialize the new entry before publishing it in the table. */
	tmp_info->rsbi_key = *addr;
	tmp_info->rsbi_snd_bytes = 0;
	*ret = rhashtable_insert_fast(&rs->rs_buf_info_tbl,
				      &tmp_info->rsbi_link, rs_buf_info_params);
	if (!*ret) {
		/* Inserted: account for the new destination while still
		 * holding rs_snd_lock so the count stays in sync with the
		 * table contents.
		 */
		rs->rs_buf_info_dest_cnt++;
		spin_unlock_irqrestore(&rs->rs_snd_lock, flags);
		return tmp_info;
	} else if (*ret != -EEXIST) {
		spin_unlock_irqrestore(&rs->rs_snd_lock, flags);
		kmem_cache_free(rds_rs_buf_info_slab, tmp_info);
		/* Very unlikely to happen... */
		pr_err("%s: cannot add rs_buf_info for %pI6c: %d\n", __func__,
		       addr, *ret);
		return NULL;
	}

	/* Another thread beats us in adding the rs_buf_info....
	 * (insert returned -EEXIST); drop our copy and return the winner's
	 * entry, looked up before releasing the lock.
	 */
	info = rhashtable_lookup_fast(&rs->rs_buf_info_tbl, addr,
				      rs_buf_info_params);
	spin_unlock_irqrestore(&rs->rs_snd_lock, flags);
	kmem_cache_free(rds_rs_buf_info_slab, tmp_info);

	if (info) {
		*ret = 0;
		return info;
	}

	/* Should not happen... insert said the key existed but the lookup
	 * found nothing; report and fail.
	 */
	pr_err("%s: cannot find rs_buf_info for %pI6c\n", __func__, addr);
	*ret = -EINVAL;
	return NULL;
}
807
+
715
808
static int rds_connect (struct socket * sock , struct sockaddr * uaddr ,
716
809
int addr_len , int flags )
717
810
{
@@ -800,6 +893,12 @@ static int rds_connect(struct socket *sock, struct sockaddr *uaddr,
800
893
break ;
801
894
}
802
895
896
+ if (!ret &&
897
+ !rds_add_buf_info (rs , & rs -> rs_conn_addr , & ret , GFP_KERNEL )) {
898
+ /* Need to clear the connected info in case of error. */
899
+ rs -> rs_conn_addr = in6addr_any ;
900
+ rs -> rs_conn_port = 0 ;
901
+ }
803
902
release_sock (sk );
804
903
return ret ;
805
904
}
@@ -842,6 +941,7 @@ static void rds_sock_destruct(struct sock *sk)
842
941
static int __rds_create (struct socket * sock , struct sock * sk , int protocol )
843
942
{
844
943
struct rds_sock * rs ;
944
+ int ret ;
845
945
846
946
sock_init_data (sock , sk );
847
947
sock -> ops = & rds_proto_ops ;
@@ -863,6 +963,11 @@ static int __rds_create(struct socket *sock, struct sock *sk, int protocol)
863
963
rs -> rs_netfilter_enabled = 0 ;
864
964
rs -> rs_rx_traces = 0 ;
865
965
966
+ spin_lock_init (& rs -> rs_snd_lock );
967
+ ret = rhashtable_init (& rs -> rs_buf_info_tbl , & rs_buf_info_params );
968
+ if (ret )
969
+ return ret ;
970
+
866
971
if (!ipv6_addr_any (& rs -> rs_bound_addr )) {
867
972
printk (KERN_CRIT "bound addr %pI6c at create\n" ,
868
973
& rs -> rs_bound_addr );
@@ -879,6 +984,7 @@ static int __rds_create(struct socket *sock, struct sock *sk, int protocol)
879
984
static int rds_create (struct net * net , struct socket * sock , int protocol , int kern )
880
985
{
881
986
struct sock * sk ;
987
+ int ret ;
882
988
883
989
if (sock -> type != SOCK_SEQPACKET ||
884
990
(protocol && IPPROTO_OKA != protocol ))
@@ -888,7 +994,10 @@ static int rds_create(struct net *net, struct socket *sock, int protocol, int ke
888
994
if (!sk )
889
995
return - ENOMEM ;
890
996
891
- return __rds_create (sock , sk , protocol );
997
+ ret = __rds_create (sock , sk , protocol );
998
+ if (ret )
999
+ sk_free (sk );
1000
+ return ret ;
892
1001
}
893
1002
894
1003
void debug_sock_hold (struct sock * sk )
@@ -1194,6 +1303,7 @@ static void __exit rds_exit(void)
1194
1303
rds_info_deregister_func (RDS6_INFO_SOCKETS , rds6_sock_info );
1195
1304
rds_info_deregister_func (RDS6_INFO_RECV_MESSAGES , rds6_sock_inc_info );
1196
1305
#endif
1306
+ kmem_cache_destroy (rds_rs_buf_info_slab );
1197
1307
}
1198
1308
1199
1309
module_exit (rds_exit );
@@ -1204,6 +1314,14 @@ static int __init rds_init(void)
1204
1314
{
1205
1315
int ret ;
1206
1316
1317
+ rds_rs_buf_info_slab = kmem_cache_create ("rds_rs_buf_info" ,
1318
+ sizeof (struct rs_buf_info ),
1319
+ 0 , SLAB_HWCACHE_ALIGN , NULL );
1320
+ if (!rds_rs_buf_info_slab ) {
1321
+ ret = - ENOMEM ;
1322
+ goto out ;
1323
+ }
1324
+
1207
1325
net_get_random_once (& rds_gen_num , sizeof (rds_gen_num ));
1208
1326
1209
1327
rds_bind_lock_init ();
0 commit comments