@@ -95,10 +95,6 @@ type handshakeTransport struct {
95
95
96
96
// The session ID or nil if first kex did not complete yet.
97
97
sessionID []byte
98
-
99
- // True if the first ext info message has been sent immediately following
100
- // SSH_MSG_NEWKEYS, false otherwise.
101
- extInfoSent bool
102
98
}
103
99
104
100
type pendingKex struct {
@@ -625,7 +621,8 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
625
621
return err
626
622
}
627
623
628
- if t .sessionID == nil {
624
+ firstKeyExchange := t .sessionID == nil
625
+ if firstKeyExchange {
629
626
t .sessionID = result .H
630
627
}
631
628
result .SessionID = t .sessionID
@@ -643,29 +640,27 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
643
640
}
644
641
645
642
if ! isClient {
646
- // We're on the server side, see if the client sent the extension signal
647
- if ! t .extInfoSent && contains (clientInit .KexAlgos , extInfoClient ) {
648
- // The other side supports ext info, an ext info message hasn't been sent this session,
649
- // and we have at least one extension enabled, so send an SSH_MSG_EXT_INFO message.
643
+ // We're on the server side, if this is the first key exchange
644
+ // see if the client sent the extension signal
645
+ if firstKeyExchange && contains (clientInit .KexAlgos , extInfoClient ) {
646
+ // The other side supports ext info, and this is the first key exchange,
647
+ // so send an SSH_MSG_EXT_INFO message.
650
648
extensions := map [string ][]byte {}
651
- // We're the server, the client supports SSH_MSG_EXT_INFO and server-sig-algs
652
- // is enabled. Prepare the server-sig-algos extension message to send.
649
+ // Prepare the server-sig-algos extension message to send.
653
650
extensions [extServerSigAlgs ] = []byte (strings .Join (supportedServerSigAlgs , "," ))
654
- var payload []byte
655
- for k , v := range extensions {
656
- payload = appendInt (payload , len (k ))
657
- payload = append (payload , k ... )
658
- payload = appendInt (payload , len (v ))
659
- payload = append (payload , v ... )
660
- }
661
- extInfo := extInfoMsg {
651
+
652
+ extInfo := & extInfoMsg {
662
653
NumExtensions : uint32 (len (extensions )),
663
- Payload : payload ,
664
654
}
665
- if err := t .conn .writePacket (Marshal (& extInfo )); err != nil {
655
+ for k , v := range extensions {
656
+ extInfo .Payload = appendInt (extInfo .Payload , len (k ))
657
+ extInfo .Payload = append (extInfo .Payload , k ... )
658
+ extInfo .Payload = appendInt (extInfo .Payload , len (v ))
659
+ extInfo .Payload = append (extInfo .Payload , v ... )
660
+ }
661
+ if err := t .conn .writePacket (Marshal (extInfo )); err != nil {
666
662
return err
667
663
}
668
- t .extInfoSent = true
669
664
}
670
665
}
671
666
0 commit comments