@@ -9,42 +9,57 @@ package ssh
9
9
import (
10
10
"crypto/rand"
11
11
"reflect"
12
+ "sync"
12
13
"testing"
13
14
)
14
15
16
+ // Runs multiple key exchanges concurrent to detect potential data races with
17
+ // kex obtained from the global kexAlgoMap.
18
+ // This test needs to be executed using the race detector in order to detect
19
+ // race conditions.
15
20
func TestKexes (t * testing.T ) {
16
21
type kexResultErr struct {
17
22
result * kexResult
18
23
err error
19
24
}
20
25
21
26
for name , kex := range kexAlgoMap {
22
- a , b := memPipe ()
27
+ t .Run (name , func (t * testing.T ) {
28
+ wg := sync.WaitGroup {}
29
+ for i := 0 ; i < 3 ; i ++ {
30
+ wg .Add (1 )
31
+ go func () {
32
+ defer wg .Done ()
33
+ a , b := memPipe ()
23
34
24
- s := make (chan kexResultErr , 1 )
25
- c := make (chan kexResultErr , 1 )
26
- var magics handshakeMagics
27
- go func () {
28
- r , e := kex .Client (a , rand .Reader , & magics )
29
- a .Close ()
30
- c <- kexResultErr {r , e }
31
- }()
32
- go func () {
33
- r , e := kex .Server (b , rand .Reader , & magics , testSigners ["ecdsa" ])
34
- b .Close ()
35
- s <- kexResultErr {r , e }
36
- }()
35
+ s := make (chan kexResultErr , 1 )
36
+ c := make (chan kexResultErr , 1 )
37
+ var magics handshakeMagics
38
+ go func () {
39
+ r , e := kex .Client (a , rand .Reader , & magics )
40
+ a .Close ()
41
+ c <- kexResultErr {r , e }
42
+ }()
43
+ go func () {
44
+ r , e := kex .Server (b , rand .Reader , & magics , testSigners ["ecdsa" ])
45
+ b .Close ()
46
+ s <- kexResultErr {r , e }
47
+ }()
37
48
38
- clientRes := <- c
39
- serverRes := <- s
40
- if clientRes .err != nil {
41
- t .Errorf ("client: %v" , clientRes .err )
42
- }
43
- if serverRes .err != nil {
44
- t .Errorf ("server: %v" , serverRes .err )
45
- }
46
- if ! reflect .DeepEqual (clientRes .result , serverRes .result ) {
47
- t .Errorf ("kex %q: mismatch %#v, %#v" , name , clientRes .result , serverRes .result )
48
- }
49
+ clientRes := <- c
50
+ serverRes := <- s
51
+ if clientRes .err != nil {
52
+ t .Errorf ("client: %v" , clientRes .err )
53
+ }
54
+ if serverRes .err != nil {
55
+ t .Errorf ("server: %v" , serverRes .err )
56
+ }
57
+ if ! reflect .DeepEqual (clientRes .result , serverRes .result ) {
58
+ t .Errorf ("kex %q: mismatch %#v, %#v" , name , clientRes .result , serverRes .result )
59
+ }
60
+ }()
61
+ }
62
+ wg .Wait ()
63
+ })
49
64
}
50
65
}
0 commit comments