Skip to content

Commit 52c37f1

Browse files
committed
Support MariaDB's auth_gssapi authentication plugin
See https://mariadb.com/kb/en/library/authentication-plugin-gssapi/ for plugin details. Signed-off-by: Vladislav Vaintroub <[email protected]>
1 parent 2a3948f commit 52c37f1

File tree

6 files changed

+274
-0
lines changed

6 files changed

+274
-0
lines changed

src/MySqlConnector/Core/ServerSession.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,9 @@ private async Task<PayloadData> SwitchAuthenticationAsync(ConnectionSettings cs,
500500
return await SendClearPasswordAsync(cs, ioBehavior, cancellationToken).ConfigureAwait(false);
501501
}
502502

503+
case "auth_gssapi_client":
504+
return await AuthGSSAPI.AuthenticateAsync(switchRequest.Data, this, ioBehavior, cancellationToken).ConfigureAwait(false);
505+
503506
case "mysql_old_password":
504507
Log.Error("Session{0} is requesting AuthenticationMethod '{1}' which is not supported", m_logArguments);
505508
throw new NotSupportedException("'MySQL Server is requesting the insecure pre-4.1 auth mechanism (mysql_old_password). The user password must be upgraded; see https://dev.mysql.com/doc/refman/5.7/en/account-upgrades.html.");
Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
using System;
2+
using System.IO;
3+
using System.Net;
4+
using System.Net.Security;
5+
using System.Text;
6+
using System.Threading;
7+
using System.Threading.Tasks;
8+
using MySqlConnector.Core;
9+
using MySqlConnector.Utilities;
10+
11+
namespace MySqlConnector.Protocol.Serialization
12+
{
13+
internal class NegotiateStreamConstants
14+
{
15+
public const int HeaderLength = 5;
16+
public const byte MajorVersion = 1;
17+
public const byte MinorVersion = 0;
18+
public const byte HandshakeDone = 0x14;
19+
public const byte HandshakeError = 0x15;
20+
public const byte HandshakeInProgress = 0x16;
21+
public const ushort MaxPayloadLength = ushort.MaxValue;
22+
}
23+
24+
/// <summary>
25+
/// Helper class to translate NegotiateStream framing for SPNEGO token
26+
/// into MySQL protocol packets.
27+
///
28+
/// Serves as underlying stream for System.Net.NegotiateStream
29+
/// to perform MariaDB's auth_gssapi_client authentication.
30+
///
31+
/// NegotiateStream protocol is described in e.g here
32+
/// https://winprotocoldoc.blob.core.windows.net/productionwindowsarchives/MS-NNS/[MS-NNS].pdf
33+
/// We only use Handshake Messages for authentication.
34+
/// </summary>
35+
36+
internal class NegotiateToMySqlConverterStream : Stream
37+
{
38+
bool m_clientHandshakeDone;
39+
40+
MemoryStream m_readBuffer;
41+
MemoryStream m_writeBuffer;
42+
int m_writePayloadLength;
43+
ServerSession m_serverSession;
44+
IOBehavior m_ioBehavior;
45+
CancellationToken m_cancellationToken;
46+
47+
public PayloadData MySQLProtocolPayload { get; private set; }
48+
public NegotiateToMySqlConverterStream(ServerSession serverSession, IOBehavior ioBehavior, CancellationToken cancellationToken)
49+
{
50+
m_serverSession = serverSession;
51+
m_readBuffer = new MemoryStream();
52+
m_writeBuffer = new MemoryStream();
53+
m_ioBehavior = ioBehavior;
54+
m_cancellationToken = cancellationToken;
55+
}
56+
57+
static void CreateNegotiateStreamMessageHeader(byte[] buffer, int offset, byte messageId, long payloadLength)
58+
{
59+
buffer[offset] = messageId;
60+
buffer[offset+1] = NegotiateStreamConstants.MajorVersion;
61+
buffer[offset+2] = NegotiateStreamConstants.MinorVersion;
62+
buffer[offset+3] = (byte) (payloadLength >> 8);
63+
buffer[offset+4] = (byte) (payloadLength & 0xff);
64+
}
65+
public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
66+
{
67+
var bytesRead = 0;
68+
69+
if (m_readBuffer.Length == m_readBuffer.Position)
70+
{
71+
if (count < NegotiateStreamConstants.HeaderLength)
72+
throw new InvalidDataException("Unexpected call to read less then NegotiateStream header");
73+
74+
if (m_clientHandshakeDone)
75+
{
76+
// NegotiateStream protocol expects server to send "handshake done"
77+
// empty message at the end of handshake.
78+
CreateNegotiateStreamMessageHeader(buffer, offset, NegotiateStreamConstants.HandshakeDone, 0);
79+
return NegotiateStreamConstants.HeaderLength;
80+
}
81+
// Read and cache packet from server.
82+
var payload = await m_serverSession.ReceiveReplyAsync(m_ioBehavior, cancellationToken);
83+
var segment = payload.ArraySegment;
84+
85+
if (segment.Count > NegotiateStreamConstants.MaxPayloadLength)
86+
throw new InvalidDataException(String.Format("Payload too big for NegotiateStream - {0} bytes", segment.Count));
87+
88+
// Check the first byte of the incoming packet.
89+
// It can be an OK packet indicating end of server processing,
90+
// or it can be 0x01 prefix we must strip off - 0x01 server masks special bytes, e.g 0xff, 0xfe in the payload
91+
// during pluggable authentication packet exchanges.
92+
var segmentOffset = segment.Offset;
93+
var segmentCount = segment.Count;
94+
95+
switch (segment.Array[segment.Offset])
96+
{
97+
case 0x0:
98+
MySQLProtocolPayload = payload;
99+
CreateNegotiateStreamMessageHeader(buffer, offset, NegotiateStreamConstants.HandshakeDone, 0);
100+
return NegotiateStreamConstants.HeaderLength;
101+
case 0x1:
102+
segmentOffset++;
103+
segmentCount--;
104+
break;
105+
}
106+
107+
m_readBuffer = new MemoryStream(segment.Array, segmentOffset, segmentCount);
108+
CreateNegotiateStreamMessageHeader(buffer, offset, NegotiateStreamConstants.HandshakeInProgress, m_readBuffer.Length);
109+
bytesRead = NegotiateStreamConstants.HeaderLength;
110+
offset += bytesRead;
111+
count -= bytesRead;
112+
}
113+
if (count > 0)
114+
{
115+
// Return cached data.
116+
bytesRead += m_readBuffer.Read(buffer, offset, count);
117+
}
118+
return bytesRead;
119+
}
120+
121+
public override int Read(byte[] buffer, int offset, int count) => ReadAsync(buffer, offset, count, m_cancellationToken).Result;
122+
123+
public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
124+
{
125+
if (m_writePayloadLength == 0)
126+
{
127+
// The message header was not read yet.
128+
if (count < NegotiateStreamConstants.HeaderLength)
129+
// For simplicity, we expect header to be written in one go
130+
throw new InvalidDataException("Cannot parse NegotiateStream handshake message header");
131+
132+
// Parse NegotiateStream handshake header
133+
var messageId = buffer[offset+0];
134+
var majorProtocolVersion = buffer[offset+1];
135+
var minorProtocolVersion = buffer[offset+2];
136+
var payloadSizeLow = buffer[offset+4];
137+
var payloadSizeHigh = buffer[offset+3];
138+
139+
140+
if (majorProtocolVersion != NegotiateStreamConstants.MajorVersion ||
141+
minorProtocolVersion != NegotiateStreamConstants.MinorVersion)
142+
{
143+
throw new FormatException(
144+
String.Format("Unknown version of NegotiateStream protocol {0}.{1}, expected {2}.{3}",
145+
majorProtocolVersion, minorProtocolVersion,
146+
NegotiateStreamConstants.MajorVersion, NegotiateStreamConstants.MinorVersion));
147+
}
148+
if (messageId != NegotiateStreamConstants.HandshakeDone &&
149+
messageId != NegotiateStreamConstants.HandshakeError &&
150+
messageId != NegotiateStreamConstants.HandshakeInProgress)
151+
{
152+
throw new FormatException(
153+
String.Format("Invalid NegotiateStream MessageId 0x{0:X2}", messageId));
154+
}
155+
156+
m_writePayloadLength = (int) payloadSizeLow + ((int) payloadSizeHigh << 8);
157+
if (messageId == NegotiateStreamConstants.HandshakeDone)
158+
m_clientHandshakeDone = true;
159+
160+
count -= NegotiateStreamConstants.HeaderLength;
161+
}
162+
163+
if (count == 0)
164+
return;
165+
166+
if (count + m_writeBuffer.Length > m_writePayloadLength)
167+
throw new InvalidDataException("Attempt to write more than a single message");
168+
169+
PayloadData payload;
170+
if (count < m_writePayloadLength)
171+
{
172+
m_writeBuffer.Write(buffer, offset, count);
173+
if (m_writeBuffer.Length < m_writePayloadLength)
174+
// The message is only partially written
175+
return;
176+
177+
var payloadBytes = m_writeBuffer.ToArray();
178+
payload = new PayloadData(new ArraySegment<byte>(payloadBytes, 0, (int) m_writeBuffer.Length));
179+
m_writeBuffer.SetLength(0);
180+
}
181+
else
182+
{
183+
// full payload provided
184+
payload = new PayloadData(new ArraySegment<byte>(buffer, offset, m_writePayloadLength));
185+
}
186+
await m_serverSession.SendReplyAsync(payload, m_ioBehavior, cancellationToken);
187+
// Need to parse NegotiateStream header next time
188+
m_writePayloadLength = 0;
189+
}
190+
191+
public override void Write(byte[] buffer, int offset, int count) => WriteAsync(buffer, offset, count, m_cancellationToken).Wait();
192+
193+
public override bool CanRead => true;
194+
195+
public override bool CanSeek => false;
196+
197+
public override bool CanWrite => true;
198+
199+
public override long Length => throw new NotImplementedException();
200+
201+
public override long Position { get => throw new NotImplementedException(); set => throw new NotImplementedException(); }
202+
203+
public override void Flush()
204+
{
205+
}
206+
207+
public override long Seek(long offset, SeekOrigin origin) => throw new NotImplementedException();
208+
209+
public override void SetLength(long value) => throw new NotImplementedException();
210+
211+
}
212+
internal class AuthGSSAPI
213+
{
214+
private static string GetServicePrincipalName(byte[] switchRequest)
215+
{
216+
var reader = new ByteArrayReader(switchRequest.AsSpan());
217+
return Encoding.UTF8.GetString(reader.ReadNullOrEofTerminatedByteString());
218+
}
219+
public static async Task<PayloadData> AuthenticateAsync(byte[] switchRequestPayloadData,
220+
ServerSession session, IOBehavior ioBehavior, CancellationToken cancellationToken)
221+
{
222+
using (var innerStream = new NegotiateToMySqlConverterStream(session, ioBehavior, cancellationToken))
223+
using (var negotiateStream = new NegotiateStream(innerStream))
224+
{
225+
var targetName = GetServicePrincipalName(switchRequestPayloadData);
226+
await negotiateStream.AuthenticateAsClientAsync(CredentialCache.DefaultNetworkCredentials, targetName);
227+
228+
if (innerStream.MySQLProtocolPayload.ArraySegment.Array != null)
229+
// return already pre-read OK packet.
230+
return innerStream.MySQLProtocolPayload;
231+
232+
// Read final OK packet from server
233+
return await session.ReceiveReplyAsync(ioBehavior, cancellationToken);
234+
}
235+
}
236+
}
237+
}

tests/SideBySide/AppConfig.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ public static class AppConfig
3636

3737
public static string PasswordlessUser => Config.GetValue<string>("Data:PasswordlessUser");
3838

39+
public static string GSSAPIUser => Config.GetValue<string>("Data:GSSAPIUser");
40+
3941
public static string SecondaryDatabase => Config.GetValue<string>("Data:SecondaryDatabase");
4042

4143
private static ServerFeatures UnsupportedFeatures => (ServerFeatures) Enum.Parse(typeof(ServerFeatures), Config.GetValue<string>("Data:UnsupportedFeatures"));
@@ -69,6 +71,14 @@ public static MySqlConnectionStringBuilder CreateCachingSha2ConnectionStringBuil
6971
return csb;
7072
}
7173

74+
public static MySqlConnectionStringBuilder CreateGSSAPIConnectionStringBuilder()
75+
{
76+
var csb = CreateConnectionStringBuilder();
77+
csb.UserID = GSSAPIUser;
78+
csb.Database = null;
79+
return csb;
80+
}
81+
7282
// tests can run much slower in CI environments
7383
public static int TimeoutDelayFactor { get; } = Environment.GetEnvironmentVariable("APPVEYOR") == "True" || Environment.GetEnvironmentVariable("TRAVIS") == "true" ? 6 : 1;
7484
}

tests/SideBySide/ConfigSettings.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,6 @@ public enum ConfigSettings
1717
TcpConnection = 0x200,
1818
SecondaryDatabase = 0x400,
1919
KnownClientCertificate = 0x800,
20+
GSSAPIUser = 0x1000
2021
}
2122
}

tests/SideBySide/ConnectAsync.cs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,26 @@ public async Task CachingSha2WithoutSecureConnection()
300300
}
301301
}
302302

303+
// To create a MariaDB GSSAPI user for a current user
304+
// - install plugin if not already done , e.g mysql -uroot -e "INSTALL SONAME 'auth_gssapi'"
305+
// create MariaDB's gssapi user
306+
// a) Windows, easy way
307+
// mysql -uroot -e "CREATE USER %USERNAME% IDENTIFIED WITH gssapi"
308+
// b) more involved , Windows, outside of domain
309+
// mysql -uroot -e "CREATE USER gssapi_user IDENTIFIED WITH gssapi as '%USERDOMAIN%\%USERNAME%'"
310+
// c) Windows, inside domain
311+
// mysql -uroot -e "CREATE USER gssapi_user IDENTIFIED WITH gssapi as '%USERNAME%@%USERDNSDOMAIN%'"
312+
// d) Linux, domain (or Kerberos Realm) user
313+
// NAME=`klist|grep 'Default principal' |awk '{print $3}'` mysql -uroot -e "CREATE USER gssapi_user IDENTIFIED WITH gssapi AS '$NAME'"
314+
[SkippableFact(ConfigSettings.GSSAPIUser)]
315+
public async Task AuthGSSAPI()
316+
{
317+
var csb = AppConfig.CreateGSSAPIConnectionStringBuilder();
318+
using (var connection = new MySqlConnection(csb.ConnectionString))
319+
{
320+
await connection.OpenAsync();
321+
}
322+
}
303323
#if !BASELINE
304324
[Fact]
305325
public async Task PingNoConnection()

tests/SideBySide/TestUtilities.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@ public static string GetSkipReason(ServerFeatures serverFeatures, ConfigSettings
9191
if (configSettings.HasFlag(ConfigSettings.PasswordlessUser) && string.IsNullOrWhiteSpace(AppConfig.PasswordlessUser))
9292
return "Requires PasswordlessUser in config.json";
9393

94+
if (configSettings.HasFlag(ConfigSettings.GSSAPIUser) && string.IsNullOrWhiteSpace(AppConfig.GSSAPIUser))
95+
return "Requires GSSAPIUser in config.json";
96+
9497
if (configSettings.HasFlag(ConfigSettings.CsvFile) && string.IsNullOrWhiteSpace(AppConfig.MySqlBulkLoaderCsvFile))
9598
return "Requires MySqlBulkLoaderCsvFile in config.json";
9699

0 commit comments

Comments
 (0)