@@ -162,57 +162,69 @@ pub fn test_socketpair() {
162
162
}
163
163
164
164
mod recvfrom {
165
+ use nix:: Result ;
165
166
use nix:: sys:: socket:: * ;
166
167
use std:: thread;
168
+ use super :: * ;
169
+
170
+ const MSG : & ' static [ u8 ] = b"Hello, World!" ;
171
+
172
+ fn sendrecv < F > ( rsock : RawFd , ssock : RawFd , f : F ) -> Option < SockAddr >
173
+ where F : Fn ( RawFd , & [ u8 ] , MsgFlags ) -> Result < usize > + Send + ' static
174
+ {
175
+ let mut buf: [ u8 ; 13 ] = [ 0u8 ; 13 ] ;
176
+ let mut l = 0 ;
177
+ let mut from = None ;
167
178
168
- #[ test]
169
- pub fn datagram ( ) {
170
- let msg = b"Hello, World!" ;
171
- let ( fd1, fd2) = socketpair ( AddressFamily :: Unix , SockType :: Datagram ,
172
- None , SockFlag :: empty ( ) ) . unwrap ( ) ;
173
179
let send_thread = thread:: spawn ( move || {
174
180
let mut l = 0 ;
175
- while l < std:: mem:: size_of_val ( msg) {
176
- let flags = MsgFlags :: empty ( ) ;
177
- l += send ( fd1, & msg[ l..] , flags) . unwrap ( ) ;
181
+ while l < std:: mem:: size_of_val ( MSG ) {
182
+ l += f ( ssock, & MSG [ l..] , MsgFlags :: empty ( ) ) . unwrap ( ) ;
178
183
}
179
184
} ) ;
180
185
181
- let mut buf: [ u8 ; 13 ] = [ 0u8 ; 13 ] ;
182
- let mut l = 0 ;
183
-
184
- while l < std:: mem:: size_of_val ( msg) {
185
- let ( len, from) = recvfrom ( fd2, & mut buf[ l..] ) . unwrap ( ) ;
186
+ while l < std:: mem:: size_of_val ( MSG ) {
187
+ let ( len, from_) = recvfrom ( rsock, & mut buf[ l..] ) . unwrap ( ) ;
188
+ from = from_;
186
189
l += len;
187
- assert_eq ! ( AddressFamily :: Unix , from. unwrap( ) . family( ) ) ;
188
190
}
189
- assert_eq ! ( & buf, msg ) ;
191
+ assert_eq ! ( & buf, MSG ) ;
190
192
send_thread. join ( ) . unwrap ( ) ;
193
+ from
191
194
}
192
195
193
196
#[ test]
194
197
pub fn stream ( ) {
195
- let msg = b"Hello, World!" ;
196
- let ( fd1, fd2) = socketpair ( AddressFamily :: Unix , SockType :: Stream ,
198
+ let ( fd2, fd1) = socketpair ( AddressFamily :: Unix , SockType :: Stream ,
197
199
None , SockFlag :: empty ( ) ) . unwrap ( ) ;
198
- let send_thread = thread:: spawn ( move || {
199
- let mut l = 0 ;
200
- while l < std:: mem:: size_of_val ( msg) {
201
- let flags = MsgFlags :: empty ( ) ;
202
- l += send ( fd1, & msg[ l..] , flags) . unwrap ( ) ;
203
- }
200
+ // Ignore from for stream sockets
201
+ let _ = sendrecv ( fd1, fd2, |s, m, flags| {
202
+ send ( s, m, flags)
204
203
} ) ;
204
+ }
205
205
206
- let mut buf: [ u8 ; 13 ] = [ 0u8 ; 13 ] ;
207
- let mut l = 0 ;
208
-
209
- while l < std:: mem:: size_of_val ( msg) {
210
- let ( len, _from) = recvfrom ( fd2, & mut buf[ l..] ) . unwrap ( ) ;
211
- l += len;
212
- // Ignore _from for stream sockets
213
- }
214
- assert_eq ! ( & buf, msg) ;
215
- send_thread. join ( ) . unwrap ( ) ;
206
+ #[ test]
207
+ pub fn udp ( ) {
208
+ let std_sa = SocketAddr :: from_str ( "127.0.0.1:6789" ) . unwrap ( ) ;
209
+ let inet_addr = InetAddr :: from_std ( & std_sa) ;
210
+ let sock_addr = SockAddr :: new_inet ( inet_addr) ;
211
+ let rsock = socket ( AddressFamily :: Inet ,
212
+ SockType :: Datagram ,
213
+ SockFlag :: empty ( ) ,
214
+ None
215
+ ) . unwrap ( ) ;
216
+ bind ( rsock, & sock_addr) . unwrap ( ) ;
217
+ let ssock = socket (
218
+ AddressFamily :: Inet ,
219
+ SockType :: Datagram ,
220
+ SockFlag :: empty ( ) ,
221
+ None ,
222
+ ) . expect ( "send socket failed" ) ;
223
+ let from = sendrecv ( rsock, ssock, move |s, m, flags| {
224
+ sendto ( s, m, & sock_addr, flags)
225
+ } ) ;
226
+ // UDP sockets should set the from address
227
+ assert_eq ! ( AddressFamily :: Inet , from. unwrap( ) . family( ) ) ;
216
228
}
217
229
}
218
230
0 commit comments