|
| 1 | +using System; |
| 2 | +using System.IO; |
| 3 | +using System.Net; |
| 4 | +using System.Net.Security; |
| 5 | +using System.Text; |
| 6 | +using System.Threading; |
| 7 | +using System.Threading.Tasks; |
| 8 | +using MySqlConnector.Core; |
| 9 | +using MySqlConnector.Utilities; |
| 10 | + |
| 11 | +namespace MySqlConnector.Protocol.Serialization |
| 12 | +{ |
| 13 | + internal class NegotiateStreamConstants |
| 14 | + { |
| 15 | + public const int HeaderLength = 5; |
| 16 | + public const byte MajorVersion = 1; |
| 17 | + public const byte MinorVersion = 0; |
| 18 | + public const byte HandshakeDone = 0x14; |
| 19 | + public const byte HandshakeError = 0x15; |
| 20 | + public const byte HandshakeInProgress = 0x16; |
| 21 | + public const ushort MaxPayloadLength = ushort.MaxValue; |
| 22 | + } |
| 23 | + |
| 24 | + /// <summary> |
| 25 | + /// Helper class to translate NegotiateStream framing for SPNEGO token |
| 26 | + /// into MySQL protocol packets. |
| 27 | + /// |
| 28 | + /// Serves as underlying stream for System.Net.NegotiateStream |
| 29 | + /// to perform MariaDB's auth_gssapi_client authentication. |
| 30 | + /// |
| 31 | + /// NegotiateStream protocol is described in e.g here |
| 32 | + /// https://winprotocoldoc.blob.core.windows.net/productionwindowsarchives/MS-NNS/[MS-NNS].pdf |
| 33 | + /// We only use Handshake Messages for authentication. |
| 34 | + /// </summary> |
| 35 | + |
| 36 | + internal class NegotiateToMySqlConverterStream : Stream |
| 37 | + { |
| 38 | + bool m_clientHandshakeDone; |
| 39 | + |
| 40 | + MemoryStream m_readBuffer; |
| 41 | + MemoryStream m_writeBuffer; |
| 42 | + int m_writePayloadLength; |
| 43 | + ServerSession m_serverSession; |
| 44 | + IOBehavior m_ioBehavior; |
| 45 | + CancellationToken m_cancellationToken; |
| 46 | + |
| 47 | + public PayloadData MySQLProtocolPayload { get; private set; } |
| 48 | + public NegotiateToMySqlConverterStream(ServerSession serverSession, IOBehavior ioBehavior, CancellationToken cancellationToken) |
| 49 | + { |
| 50 | + m_serverSession = serverSession; |
| 51 | + m_readBuffer = new MemoryStream(); |
| 52 | + m_writeBuffer = new MemoryStream(); |
| 53 | + m_ioBehavior = ioBehavior; |
| 54 | + m_cancellationToken = cancellationToken; |
| 55 | + } |
| 56 | + |
| 57 | + static void CreateNegotiateStreamMessageHeader(byte[] buffer, int offset, byte messageId, long payloadLength) |
| 58 | + { |
| 59 | + buffer[offset] = messageId; |
| 60 | + buffer[offset+1] = NegotiateStreamConstants.MajorVersion; |
| 61 | + buffer[offset+2] = NegotiateStreamConstants.MinorVersion; |
| 62 | + buffer[offset+3] = (byte) (payloadLength >> 8); |
| 63 | + buffer[offset+4] = (byte) (payloadLength & 0xff); |
| 64 | + } |
| 65 | + public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) |
| 66 | + { |
| 67 | + var bytesRead = 0; |
| 68 | + |
| 69 | + if (m_readBuffer.Length == m_readBuffer.Position) |
| 70 | + { |
| 71 | + if (count < NegotiateStreamConstants.HeaderLength) |
| 72 | + throw new InvalidDataException("Unexpected call to read less then NegotiateStream header"); |
| 73 | + |
| 74 | + if (m_clientHandshakeDone) |
| 75 | + { |
| 76 | + // NegotiateStream protocol expects server to send "handshake done" |
| 77 | + // empty message at the end of handshake. |
| 78 | + CreateNegotiateStreamMessageHeader(buffer, offset, NegotiateStreamConstants.HandshakeDone, 0); |
| 79 | + return NegotiateStreamConstants.HeaderLength; |
| 80 | + } |
| 81 | + // Read and cache packet from server. |
| 82 | + var payload = await m_serverSession.ReceiveReplyAsync(m_ioBehavior, cancellationToken); |
| 83 | + var segment = payload.ArraySegment; |
| 84 | + |
| 85 | + if (segment.Count > NegotiateStreamConstants.MaxPayloadLength) |
| 86 | + throw new InvalidDataException(String.Format("Payload too big for NegotiateStream - {0} bytes", segment.Count)); |
| 87 | + |
| 88 | + // Check the first byte of the incoming packet. |
| 89 | + // It can be an OK packet indicating end of server processing, |
| 90 | + // or it can be 0x01 prefix we must strip off - 0x01 server masks special bytes, e.g 0xff, 0xfe in the payload |
| 91 | + // during pluggable authentication packet exchanges. |
| 92 | + var segmentOffset = segment.Offset; |
| 93 | + var segmentCount = segment.Count; |
| 94 | + |
| 95 | + switch (segment.Array[segment.Offset]) |
| 96 | + { |
| 97 | + case 0x0: |
| 98 | + MySQLProtocolPayload = payload; |
| 99 | + CreateNegotiateStreamMessageHeader(buffer, offset, NegotiateStreamConstants.HandshakeDone, 0); |
| 100 | + return NegotiateStreamConstants.HeaderLength; |
| 101 | + case 0x1: |
| 102 | + segmentOffset++; |
| 103 | + segmentCount--; |
| 104 | + break; |
| 105 | + } |
| 106 | + |
| 107 | + m_readBuffer = new MemoryStream(segment.Array, segmentOffset, segmentCount); |
| 108 | + CreateNegotiateStreamMessageHeader(buffer, offset, NegotiateStreamConstants.HandshakeInProgress, m_readBuffer.Length); |
| 109 | + bytesRead = NegotiateStreamConstants.HeaderLength; |
| 110 | + offset += bytesRead; |
| 111 | + count -= bytesRead; |
| 112 | + } |
| 113 | + if (count > 0) |
| 114 | + { |
| 115 | + // Return cached data. |
| 116 | + bytesRead += m_readBuffer.Read(buffer, offset, count); |
| 117 | + } |
| 118 | + return bytesRead; |
| 119 | + } |
| 120 | + |
| 121 | + public override int Read(byte[] buffer, int offset, int count) => ReadAsync(buffer, offset, count, m_cancellationToken).Result; |
| 122 | + |
| 123 | + public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) |
| 124 | + { |
| 125 | + if (m_writePayloadLength == 0) |
| 126 | + { |
| 127 | + // The message header was not read yet. |
| 128 | + if (count < NegotiateStreamConstants.HeaderLength) |
| 129 | + // For simplicity, we expect header to be written in one go |
| 130 | + throw new InvalidDataException("Cannot parse NegotiateStream handshake message header"); |
| 131 | + |
| 132 | + // Parse NegotiateStream handshake header |
| 133 | + var messageId = buffer[offset+0]; |
| 134 | + var majorProtocolVersion = buffer[offset+1]; |
| 135 | + var minorProtocolVersion = buffer[offset+2]; |
| 136 | + var payloadSizeLow = buffer[offset+4]; |
| 137 | + var payloadSizeHigh = buffer[offset+3]; |
| 138 | + |
| 139 | + |
| 140 | + if (majorProtocolVersion != NegotiateStreamConstants.MajorVersion || |
| 141 | + minorProtocolVersion != NegotiateStreamConstants.MinorVersion) |
| 142 | + { |
| 143 | + throw new FormatException( |
| 144 | + String.Format("Unknown version of NegotiateStream protocol {0}.{1}, expected {2}.{3}", |
| 145 | + majorProtocolVersion, minorProtocolVersion, |
| 146 | + NegotiateStreamConstants.MajorVersion, NegotiateStreamConstants.MinorVersion)); |
| 147 | + } |
| 148 | + if (messageId != NegotiateStreamConstants.HandshakeDone && |
| 149 | + messageId != NegotiateStreamConstants.HandshakeError && |
| 150 | + messageId != NegotiateStreamConstants.HandshakeInProgress) |
| 151 | + { |
| 152 | + throw new FormatException( |
| 153 | + String.Format("Invalid NegotiateStream MessageId 0x{0:X2}", messageId)); |
| 154 | + } |
| 155 | + |
| 156 | + m_writePayloadLength = (int) payloadSizeLow + ((int) payloadSizeHigh << 8); |
| 157 | + if (messageId == NegotiateStreamConstants.HandshakeDone) |
| 158 | + m_clientHandshakeDone = true; |
| 159 | + |
| 160 | + count -= NegotiateStreamConstants.HeaderLength; |
| 161 | + } |
| 162 | + |
| 163 | + if (count == 0) |
| 164 | + return; |
| 165 | + |
| 166 | + if (count + m_writeBuffer.Length > m_writePayloadLength) |
| 167 | + throw new InvalidDataException("Attempt to write more than a single message"); |
| 168 | + |
| 169 | + PayloadData payload; |
| 170 | + if (count < m_writePayloadLength) |
| 171 | + { |
| 172 | + m_writeBuffer.Write(buffer, offset, count); |
| 173 | + if (m_writeBuffer.Length < m_writePayloadLength) |
| 174 | + // The message is only partially written |
| 175 | + return; |
| 176 | + |
| 177 | + var payloadBytes = m_writeBuffer.ToArray(); |
| 178 | + payload = new PayloadData(new ArraySegment<byte>(payloadBytes, 0, (int) m_writeBuffer.Length)); |
| 179 | + m_writeBuffer.SetLength(0); |
| 180 | + } |
| 181 | + else |
| 182 | + { |
| 183 | + // full payload provided |
| 184 | + payload = new PayloadData(new ArraySegment<byte>(buffer, offset, m_writePayloadLength)); |
| 185 | + } |
| 186 | + await m_serverSession.SendReplyAsync(payload, m_ioBehavior, cancellationToken); |
| 187 | + // Need to parse NegotiateStream header next time |
| 188 | + m_writePayloadLength = 0; |
| 189 | + } |
| 190 | + |
| 191 | + public override void Write(byte[] buffer, int offset, int count) => WriteAsync(buffer, offset, count, m_cancellationToken).Wait(); |
| 192 | + |
| 193 | + public override bool CanRead => true; |
| 194 | + |
| 195 | + public override bool CanSeek => false; |
| 196 | + |
| 197 | + public override bool CanWrite => true; |
| 198 | + |
| 199 | + public override long Length => throw new NotImplementedException(); |
| 200 | + |
| 201 | + public override long Position { get => throw new NotImplementedException(); set => throw new NotImplementedException(); } |
| 202 | + |
| 203 | + public override void Flush() |
| 204 | + { |
| 205 | + } |
| 206 | + |
| 207 | + public override long Seek(long offset, SeekOrigin origin) => throw new NotImplementedException(); |
| 208 | + |
| 209 | + public override void SetLength(long value) => throw new NotImplementedException(); |
| 210 | + |
| 211 | + } |
| 212 | + internal class AuthGSSAPI |
| 213 | + { |
| 214 | + private static string GetServicePrincipalName(byte[] switchRequest) |
| 215 | + { |
| 216 | + var reader = new ByteArrayReader(switchRequest.AsSpan()); |
| 217 | + return Encoding.UTF8.GetString(reader.ReadNullOrEofTerminatedByteString()); |
| 218 | + } |
| 219 | + public static async Task<PayloadData> AuthenticateAsync(byte[] switchRequestPayloadData, |
| 220 | + ServerSession session, IOBehavior ioBehavior, CancellationToken cancellationToken) |
| 221 | + { |
| 222 | + using (var innerStream = new NegotiateToMySqlConverterStream(session, ioBehavior, cancellationToken)) |
| 223 | + using (var negotiateStream = new NegotiateStream(innerStream)) |
| 224 | + { |
| 225 | + var targetName = GetServicePrincipalName(switchRequestPayloadData); |
| 226 | + await negotiateStream.AuthenticateAsClientAsync(CredentialCache.DefaultNetworkCredentials, targetName); |
| 227 | + |
| 228 | + if (innerStream.MySQLProtocolPayload.ArraySegment.Array != null) |
| 229 | + // return already pre-read OK packet. |
| 230 | + return innerStream.MySQLProtocolPayload; |
| 231 | + |
| 232 | + // Read final OK packet from server |
| 233 | + return await session.ReceiveReplyAsync(ioBehavior, cancellationToken); |
| 234 | + } |
| 235 | + } |
| 236 | + } |
| 237 | +} |
0 commit comments