Skip to content

Commit 71b413a

Browse files
authored
Merge pull request #579 from bgrainger/server_spn
Support MariaDB auth_gssapi authentication.
2 parents 96ac5d2 + de210cd commit 71b413a

File tree

9 files changed

+350
-2
lines changed

9 files changed

+350
-2
lines changed

src/MySqlConnector/Core/ConnectionSettings.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ public ConnectionSettings(MySqlConnectionStringBuilder csb)
7979
Keepalive = csb.Keepalive;
8080
PersistSecurityInfo = csb.PersistSecurityInfo;
8181
ServerRsaPublicKeyFile = csb.ServerRsaPublicKeyFile;
82+
ServerSPN = csb.ServerSPN;
8283
TreatTinyAsBoolean = csb.TreatTinyAsBoolean;
8384
UseAffectedRows = csb.UseAffectedRows;
8485
UseCompression = csb.UseCompression;
@@ -157,6 +158,7 @@ private static MySqlGuidFormat GetEffectiveGuidFormat(MySqlGuidFormat guidFormat
157158
public uint Keepalive { get; }
158159
public bool PersistSecurityInfo { get; }
159160
public string ServerRsaPublicKeyFile { get; }
161+
public string ServerSPN { get; }
160162
public bool TreatTinyAsBoolean { get; }
161163
public bool UseAffectedRows { get; }
162164
public bool UseCompression { get; }

src/MySqlConnector/Core/ServerSession.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,9 @@ private async Task<PayloadData> SwitchAuthenticationAsync(ConnectionSettings cs,
500500
return await SendClearPasswordAsync(cs, ioBehavior, cancellationToken).ConfigureAwait(false);
501501
}
502502

503+
case "auth_gssapi_client":
504+
return await AuthGSSAPI.AuthenticateAsync(cs, switchRequest.Data, this, ioBehavior, cancellationToken).ConfigureAwait(false);
505+
503506
case "mysql_old_password":
504507
Log.Error("Session{0} is requesting AuthenticationMethod '{1}' which is not supported", m_logArguments);
505508
throw new NotSupportedException("'MySQL Server is requesting the insecure pre-4.1 auth mechanism (mysql_old_password). The user password must be upgraded; see https://dev.mysql.com/doc/refman/5.7/en/account-upgrades.html.");

src/MySqlConnector/MySql.Data.MySqlClient/MySqlConnectionStringBuilder.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,12 @@ public string ServerRsaPublicKeyFile
261261
set => MySqlConnectionStringOption.ServerRsaPublicKeyFile.SetValue(this, value);
262262
}
263263

264+
public string ServerSPN
265+
{
266+
get => MySqlConnectionStringOption.ServerSPN.GetValue(this);
267+
set => MySqlConnectionStringOption.ServerSPN.SetValue(this, value);
268+
}
269+
264270
public bool TreatTinyAsBoolean
265271
{
266272
get => MySqlConnectionStringOption.TreatTinyAsBoolean.GetValue(this);
@@ -371,6 +377,7 @@ internal abstract class MySqlConnectionStringOption
371377
public static readonly MySqlConnectionStringOption<bool> OldGuids;
372378
public static readonly MySqlConnectionStringOption<bool> PersistSecurityInfo;
373379
public static readonly MySqlConnectionStringOption<string> ServerRsaPublicKeyFile;
380+
public static readonly MySqlConnectionStringOption<string> ServerSPN;
374381
public static readonly MySqlConnectionStringOption<bool> TreatTinyAsBoolean;
375382
public static readonly MySqlConnectionStringOption<bool> UseAffectedRows;
376383
public static readonly MySqlConnectionStringOption<bool> UseCompression;
@@ -565,6 +572,10 @@ static MySqlConnectionStringOption()
565572
keys: new[] { "ServerRSAPublicKeyFile", "Server RSA Public Key File" },
566573
defaultValue: null));
567574

575+
AddOption(ServerSPN = new MySqlConnectionStringOption<string>(
576+
keys: new[] { "Server SPN", "ServerSPN" },
577+
defaultValue: null));
578+
568579
AddOption(TreatTinyAsBoolean = new MySqlConnectionStringOption<bool>(
569580
keys: new[] { "Treat Tiny As Boolean", "TreatTinyAsBoolean" },
570581
defaultValue: true));
Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
using System;
2+
using System.IO;
3+
using System.Net;
4+
using System.Net.Security;
5+
using System.Security;
6+
using System.Security.Authentication;
7+
using System.Text;
8+
using System.Threading;
9+
using System.Threading.Tasks;
10+
using MySqlConnector.Core;
11+
using MySqlConnector.Utilities;
12+
13+
namespace MySqlConnector.Protocol.Serialization
14+
{
15+
internal class NegotiateStreamConstants
16+
{
17+
public const int HeaderLength = 5;
18+
public const byte MajorVersion = 1;
19+
public const byte MinorVersion = 0;
20+
public const byte HandshakeDone = 0x14;
21+
public const byte HandshakeError = 0x15;
22+
public const byte HandshakeInProgress = 0x16;
23+
public const ushort MaxPayloadLength = ushort.MaxValue;
24+
}
25+
26+
/// <summary>
27+
/// Helper class to translate NegotiateStream framing for SPNEGO token
28+
/// into MySQL protocol packets.
29+
///
30+
/// Serves as underlying stream for System.Net.NegotiateStream
31+
/// to perform MariaDB's auth_gssapi_client authentication.
32+
///
33+
/// NegotiateStream protocol is described in e.g here
34+
/// https://winprotocoldoc.blob.core.windows.net/productionwindowsarchives/MS-NNS/[MS-NNS].pdf
35+
/// We only use Handshake Messages for authentication.
36+
/// </summary>
37+
38+
internal class NegotiateToMySqlConverterStream : Stream
39+
{
40+
bool m_clientHandshakeDone;
41+
42+
MemoryStream m_readBuffer;
43+
MemoryStream m_writeBuffer;
44+
int m_writePayloadLength;
45+
ServerSession m_serverSession;
46+
IOBehavior m_ioBehavior;
47+
CancellationToken m_cancellationToken;
48+
49+
public PayloadData MySQLProtocolPayload { get; private set; }
50+
public NegotiateToMySqlConverterStream(ServerSession serverSession, IOBehavior ioBehavior, CancellationToken cancellationToken)
51+
{
52+
m_serverSession = serverSession;
53+
m_readBuffer = new MemoryStream();
54+
m_writeBuffer = new MemoryStream();
55+
m_ioBehavior = ioBehavior;
56+
m_cancellationToken = cancellationToken;
57+
}
58+
59+
static void CreateNegotiateStreamMessageHeader(byte[] buffer, int offset, byte messageId, long payloadLength)
60+
{
61+
buffer[offset] = messageId;
62+
buffer[offset+1] = NegotiateStreamConstants.MajorVersion;
63+
buffer[offset+2] = NegotiateStreamConstants.MinorVersion;
64+
buffer[offset+3] = (byte) (payloadLength >> 8);
65+
buffer[offset+4] = (byte) (payloadLength & 0xff);
66+
}
67+
public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
68+
{
69+
var bytesRead = 0;
70+
71+
if (m_readBuffer.Length == m_readBuffer.Position)
72+
{
73+
if (count < NegotiateStreamConstants.HeaderLength)
74+
throw new InvalidDataException("Unexpected call to read less then NegotiateStream header");
75+
76+
if (m_clientHandshakeDone)
77+
{
78+
// NegotiateStream protocol expects server to send "handshake done"
79+
// empty message at the end of handshake.
80+
CreateNegotiateStreamMessageHeader(buffer, offset, NegotiateStreamConstants.HandshakeDone, 0);
81+
return NegotiateStreamConstants.HeaderLength;
82+
}
83+
// Read and cache packet from server.
84+
var payload = await m_serverSession.ReceiveReplyAsync(m_ioBehavior, cancellationToken).ConfigureAwait(false);
85+
var segment = payload.ArraySegment;
86+
87+
if (segment.Count > NegotiateStreamConstants.MaxPayloadLength)
88+
throw new InvalidDataException(String.Format("Payload too big for NegotiateStream - {0} bytes", segment.Count));
89+
90+
// Check the first byte of the incoming packet.
91+
// It can be an OK packet indicating end of server processing,
92+
// or it can be 0x01 prefix we must strip off - 0x01 server masks special bytes, e.g 0xff, 0xfe in the payload
93+
// during pluggable authentication packet exchanges.
94+
var segmentOffset = segment.Offset;
95+
var segmentCount = segment.Count;
96+
97+
switch (segment.Array[segment.Offset])
98+
{
99+
case 0x0:
100+
MySQLProtocolPayload = payload;
101+
CreateNegotiateStreamMessageHeader(buffer, offset, NegotiateStreamConstants.HandshakeDone, 0);
102+
return NegotiateStreamConstants.HeaderLength;
103+
case 0x1:
104+
segmentOffset++;
105+
segmentCount--;
106+
break;
107+
}
108+
109+
m_readBuffer = new MemoryStream(segment.Array, segmentOffset, segmentCount);
110+
CreateNegotiateStreamMessageHeader(buffer, offset, NegotiateStreamConstants.HandshakeInProgress, m_readBuffer.Length);
111+
bytesRead = NegotiateStreamConstants.HeaderLength;
112+
offset += bytesRead;
113+
count -= bytesRead;
114+
}
115+
if (count > 0)
116+
{
117+
// Return cached data.
118+
bytesRead += m_readBuffer.Read(buffer, offset, count);
119+
}
120+
return bytesRead;
121+
}
122+
123+
public override int Read(byte[] buffer, int offset, int count) => ReadAsync(buffer, offset, count, m_cancellationToken).GetAwaiter().GetResult();
124+
125+
public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
126+
{
127+
if (m_writePayloadLength == 0)
128+
{
129+
// The message header was not read yet.
130+
if (count < NegotiateStreamConstants.HeaderLength)
131+
// For simplicity, we expect header to be written in one go
132+
throw new InvalidDataException("Cannot parse NegotiateStream handshake message header");
133+
134+
// Parse NegotiateStream handshake header
135+
var messageId = buffer[offset+0];
136+
var majorProtocolVersion = buffer[offset+1];
137+
var minorProtocolVersion = buffer[offset+2];
138+
var payloadSizeLow = buffer[offset+4];
139+
var payloadSizeHigh = buffer[offset+3];
140+
141+
142+
if (majorProtocolVersion != NegotiateStreamConstants.MajorVersion ||
143+
minorProtocolVersion != NegotiateStreamConstants.MinorVersion)
144+
{
145+
throw new FormatException(
146+
String.Format("Unknown version of NegotiateStream protocol {0}.{1}, expected {2}.{3}",
147+
majorProtocolVersion, minorProtocolVersion,
148+
NegotiateStreamConstants.MajorVersion, NegotiateStreamConstants.MinorVersion));
149+
}
150+
if (messageId != NegotiateStreamConstants.HandshakeDone &&
151+
messageId != NegotiateStreamConstants.HandshakeError &&
152+
messageId != NegotiateStreamConstants.HandshakeInProgress)
153+
{
154+
throw new FormatException(
155+
String.Format("Invalid NegotiateStream MessageId 0x{0:X2}", messageId));
156+
}
157+
158+
m_writePayloadLength = (int) payloadSizeLow + ((int) payloadSizeHigh << 8);
159+
if (messageId == NegotiateStreamConstants.HandshakeDone)
160+
m_clientHandshakeDone = true;
161+
162+
count -= NegotiateStreamConstants.HeaderLength;
163+
}
164+
165+
if (count == 0)
166+
return;
167+
168+
if (count + m_writeBuffer.Length > m_writePayloadLength)
169+
throw new InvalidDataException("Attempt to write more than a single message");
170+
171+
PayloadData payload;
172+
if (count < m_writePayloadLength)
173+
{
174+
m_writeBuffer.Write(buffer, offset, count);
175+
if (m_writeBuffer.Length < m_writePayloadLength)
176+
// The message is only partially written
177+
return;
178+
179+
var payloadBytes = m_writeBuffer.ToArray();
180+
payload = new PayloadData(new ArraySegment<byte>(payloadBytes, 0, (int) m_writeBuffer.Length));
181+
m_writeBuffer.SetLength(0);
182+
}
183+
else
184+
{
185+
// full payload provided
186+
payload = new PayloadData(new ArraySegment<byte>(buffer, offset, m_writePayloadLength));
187+
}
188+
await m_serverSession.SendReplyAsync(payload, m_ioBehavior, cancellationToken).ConfigureAwait(false);
189+
// Need to parse NegotiateStream header next time
190+
m_writePayloadLength = 0;
191+
}
192+
193+
public override void Write(byte[] buffer, int offset, int count) => WriteAsync(buffer, offset, count, m_cancellationToken).GetAwaiter().GetResult();
194+
195+
public override bool CanRead => true;
196+
197+
public override bool CanSeek => false;
198+
199+
public override bool CanWrite => true;
200+
201+
public override long Length => throw new NotImplementedException();
202+
203+
public override long Position { get => throw new NotImplementedException(); set => throw new NotImplementedException(); }
204+
205+
public override void Flush()
206+
{
207+
}
208+
209+
public override long Seek(long offset, SeekOrigin origin) => throw new NotImplementedException();
210+
211+
public override void SetLength(long value) => throw new NotImplementedException();
212+
213+
}
214+
internal class AuthGSSAPI
215+
{
216+
private static string GetServicePrincipalName(byte[] switchRequest)
217+
{
218+
var reader = new ByteArrayReader(switchRequest.AsSpan());
219+
return Encoding.UTF8.GetString(reader.ReadNullOrEofTerminatedByteString());
220+
}
221+
public static async Task<PayloadData> AuthenticateAsync(ConnectionSettings cs, byte[] switchRequestPayloadData,
222+
ServerSession session, IOBehavior ioBehavior, CancellationToken cancellationToken)
223+
{
224+
using (var innerStream = new NegotiateToMySqlConverterStream(session, ioBehavior, cancellationToken))
225+
using (var negotiateStream = new NegotiateStream(innerStream))
226+
{
227+
var targetName =cs.ServerSPN ?? GetServicePrincipalName(switchRequestPayloadData);
228+
#if NETSTANDARD1_3
229+
await negotiateStream.AuthenticateAsClientAsync(CredentialCache.DefaultNetworkCredentials, targetName).ConfigureAwait(false);
230+
#else
231+
if (ioBehavior == IOBehavior.Synchronous)
232+
{
233+
negotiateStream.AuthenticateAsClient(CredentialCache.DefaultNetworkCredentials, targetName);
234+
}
235+
else
236+
{
237+
await negotiateStream.AuthenticateAsClientAsync(CredentialCache.DefaultNetworkCredentials, targetName).ConfigureAwait(false);
238+
}
239+
#endif
240+
if (cs.ServerSPN != null && !negotiateStream.IsMutuallyAuthenticated)
241+
{
242+
// Negotiate used NTLM fallback, server name cannot be verified.
243+
throw new AuthenticationException(String.Format(
244+
"GSSAPI : Unable to verify server principal name using authentication type {0}",
245+
negotiateStream.RemoteIdentity?.AuthenticationType));
246+
}
247+
if (innerStream.MySQLProtocolPayload.ArraySegment.Array != null)
248+
// return already pre-read OK packet.
249+
return innerStream.MySQLProtocolPayload;
250+
251+
// Read final OK packet from server
252+
return await session.ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false);
253+
}
254+
}
255+
}
256+
}

tests/MySqlConnector.Tests/MySqlConnectionStringBuilderTests.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,7 @@ public void Defaults()
5757
Assert.Equal("", csb.Server);
5858
#if !BASELINE
5959
Assert.Null(csb.ServerRsaPublicKeyFile);
60-
#endif
61-
#if !BASELINE
60+
Assert.Null(csb.ServerSPN);
6261
Assert.Equal(MySqlSslMode.Preferred, csb.SslMode);
6362
#else
6463
Assert.Equal(MySqlSslMode.Required, csb.SslMode);
@@ -108,6 +107,7 @@ public void ParseConnectionString()
108107
"server rsa public key file=rsa.pem;" +
109108
"load balance=random;" +
110109
"guidformat=timeswapbinary16;" +
110+
"server spn=mariadb/[email protected];" +
111111
#endif
112112
"ignore prepare=false;" +
113113
"interactive=true;" +
@@ -155,6 +155,7 @@ public void ParseConnectionString()
155155
Assert.Equal("rsa.pem", csb.ServerRsaPublicKeyFile);
156156
Assert.Equal(MySqlLoadBalance.Random, csb.LoadBalance);
157157
Assert.Equal(MySqlGuidFormat.TimeSwapBinary16, csb.GuidFormat);
158+
Assert.Equal("mariadb/[email protected]", csb.ServerSPN);
158159
#endif
159160
Assert.False(csb.IgnorePrepare);
160161
Assert.True(csb.InteractiveSession);

tests/SideBySide/AppConfig.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ public static class AppConfig
3636

3737
public static string PasswordlessUser => Config.GetValue<string>("Data:PasswordlessUser");
3838

39+
public static string GSSAPIUser => Config.GetValue<string>("Data:GSSAPIUser");
40+
41+
public static bool HasKerberos => Config.GetValue<bool>("Data:HasKerberos");
42+
3943
public static string SecondaryDatabase => Config.GetValue<string>("Data:SecondaryDatabase");
4044

4145
private static ServerFeatures UnsupportedFeatures => (ServerFeatures) Enum.Parse(typeof(ServerFeatures), Config.GetValue<string>("Data:UnsupportedFeatures"));
@@ -69,6 +73,14 @@ public static MySqlConnectionStringBuilder CreateCachingSha2ConnectionStringBuil
6973
return csb;
7074
}
7175

76+
public static MySqlConnectionStringBuilder CreateGSSAPIConnectionStringBuilder()
77+
{
78+
var csb = CreateConnectionStringBuilder();
79+
csb.UserID = GSSAPIUser;
80+
csb.Database = null;
81+
return csb;
82+
}
83+
7284
// tests can run much slower in CI environments
7385
public static int TimeoutDelayFactor { get; } = Environment.GetEnvironmentVariable("APPVEYOR") == "True" || Environment.GetEnvironmentVariable("TRAVIS") == "true" ? 6 : 1;
7486
}

tests/SideBySide/ConfigSettings.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,7 @@ public enum ConfigSettings
1717
TcpConnection = 0x200,
1818
SecondaryDatabase = 0x400,
1919
KnownClientCertificate = 0x800,
20+
GSSAPIUser = 0x1000,
21+
HasKerberos = 0x2000
2022
}
2123
}

0 commit comments

Comments
 (0)