-
Notifications
You must be signed in to change notification settings - Fork 341
Support auth_gssapi authentication. #577
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,255 @@ | ||
using System; | ||
using System.IO; | ||
using System.Net; | ||
using System.Net.Security; | ||
using System.Security.Authentication; | ||
using System.Text; | ||
using System.Threading; | ||
using System.Threading.Tasks; | ||
using MySqlConnector.Core; | ||
using MySqlConnector.Utilities; | ||
|
||
namespace MySqlConnector.Protocol.Serialization | ||
{ | ||
internal class NegotiateStreamConstants | ||
{ | ||
public const int HeaderLength = 5; | ||
public const byte MajorVersion = 1; | ||
public const byte MinorVersion = 0; | ||
public const byte HandshakeDone = 0x14; | ||
public const byte HandshakeError = 0x15; | ||
public const byte HandshakeInProgress = 0x16; | ||
public const ushort MaxPayloadLength = ushort.MaxValue; | ||
} | ||
|
||
/// <summary> | ||
/// Helper class to translate NegotiateStream framing for SPNEGO token | ||
/// into MySQL protocol packets. | ||
/// | ||
/// Serves as underlying stream for System.Net.NegotiateStream | ||
/// to perform MariaDB's auth_gssapi_client authentication. | ||
/// | ||
/// NegotiateStream protocol is described in e.g here | ||
/// https://winprotocoldoc.blob.core.windows.net/productionwindowsarchives/MS-NNS/[MS-NNS].pdf | ||
/// We only use Handshake Messages for authentication. | ||
/// </summary> | ||
|
||
internal class NegotiateToMySqlConverterStream : Stream | ||
{ | ||
bool m_clientHandshakeDone; | ||
|
||
MemoryStream m_readBuffer; | ||
MemoryStream m_writeBuffer; | ||
int m_writePayloadLength; | ||
ServerSession m_serverSession; | ||
IOBehavior m_ioBehavior; | ||
CancellationToken m_cancellationToken; | ||
|
||
public PayloadData MySQLProtocolPayload { get; private set; } | ||
public NegotiateToMySqlConverterStream(ServerSession serverSession, IOBehavior ioBehavior, CancellationToken cancellationToken) | ||
{ | ||
m_serverSession = serverSession; | ||
m_readBuffer = new MemoryStream(); | ||
m_writeBuffer = new MemoryStream(); | ||
m_ioBehavior = ioBehavior; | ||
m_cancellationToken = cancellationToken; | ||
} | ||
|
||
static void CreateNegotiateStreamMessageHeader(byte[] buffer, int offset, byte messageId, long payloadLength) | ||
{ | ||
buffer[offset] = messageId; | ||
buffer[offset+1] = NegotiateStreamConstants.MajorVersion; | ||
buffer[offset+2] = NegotiateStreamConstants.MinorVersion; | ||
buffer[offset+3] = (byte) (payloadLength >> 8); | ||
buffer[offset+4] = (byte) (payloadLength & 0xff); | ||
} | ||
public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) | ||
{ | ||
var bytesRead = 0; | ||
|
||
if (m_readBuffer.Length == m_readBuffer.Position) | ||
{ | ||
if (count < NegotiateStreamConstants.HeaderLength) | ||
throw new InvalidDataException("Unexpected call to read less then NegotiateStream header"); | ||
|
||
if (m_clientHandshakeDone) | ||
{ | ||
// NegotiateStream protocol expects server to send "handshake done" | ||
// empty message at the end of handshake. | ||
CreateNegotiateStreamMessageHeader(buffer, offset, NegotiateStreamConstants.HandshakeDone, 0); | ||
return NegotiateStreamConstants.HeaderLength; | ||
} | ||
// Read and cache packet from server. | ||
var payload = await m_serverSession.ReceiveReplyAsync(m_ioBehavior, cancellationToken).ConfigureAwait(false); | ||
var segment = payload.ArraySegment; | ||
|
||
if (segment.Count > NegotiateStreamConstants.MaxPayloadLength) | ||
throw new InvalidDataException(String.Format("Payload too big for NegotiateStream - {0} bytes", segment.Count)); | ||
|
||
// Check the first byte of the incoming packet. | ||
// It can be an OK packet indicating end of server processing, | ||
// or it can be 0x01 prefix we must strip off - 0x01 server masks special bytes, e.g 0xff, 0xfe in the payload | ||
// during pluggable authentication packet exchanges. | ||
var segmentOffset = segment.Offset; | ||
var segmentCount = segment.Count; | ||
|
||
switch (segment.Array[segment.Offset]) | ||
{ | ||
case 0x0: | ||
MySQLProtocolPayload = payload; | ||
CreateNegotiateStreamMessageHeader(buffer, offset, NegotiateStreamConstants.HandshakeDone, 0); | ||
return NegotiateStreamConstants.HeaderLength; | ||
case 0x1: | ||
segmentOffset++; | ||
segmentCount--; | ||
break; | ||
} | ||
|
||
m_readBuffer = new MemoryStream(segment.Array, segmentOffset, segmentCount); | ||
CreateNegotiateStreamMessageHeader(buffer, offset, NegotiateStreamConstants.HandshakeInProgress, m_readBuffer.Length); | ||
bytesRead = NegotiateStreamConstants.HeaderLength; | ||
offset += bytesRead; | ||
count -= bytesRead; | ||
} | ||
if (count > 0) | ||
{ | ||
// Return cached data. | ||
bytesRead += m_readBuffer.Read(buffer, offset, count); | ||
} | ||
return bytesRead; | ||
} | ||
|
||
public override int Read(byte[] buffer, int offset, int count) => ReadAsync(buffer, offset, count, m_cancellationToken).GetAwaiter().GetResult(); | ||
|
||
public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) | ||
{ | ||
if (m_writePayloadLength == 0) | ||
{ | ||
// The message header was not read yet. | ||
if (count < NegotiateStreamConstants.HeaderLength) | ||
// For simplicity, we expect header to be written in one go | ||
throw new InvalidDataException("Cannot parse NegotiateStream handshake message header"); | ||
|
||
// Parse NegotiateStream handshake header | ||
var messageId = buffer[offset+0]; | ||
var majorProtocolVersion = buffer[offset+1]; | ||
var minorProtocolVersion = buffer[offset+2]; | ||
var payloadSizeLow = buffer[offset+4]; | ||
var payloadSizeHigh = buffer[offset+3]; | ||
|
||
|
||
if (majorProtocolVersion != NegotiateStreamConstants.MajorVersion || | ||
minorProtocolVersion != NegotiateStreamConstants.MinorVersion) | ||
{ | ||
throw new FormatException( | ||
String.Format("Unknown version of NegotiateStream protocol {0}.{1}, expected {2}.{3}", | ||
majorProtocolVersion, minorProtocolVersion, | ||
NegotiateStreamConstants.MajorVersion, NegotiateStreamConstants.MinorVersion)); | ||
} | ||
if (messageId != NegotiateStreamConstants.HandshakeDone && | ||
messageId != NegotiateStreamConstants.HandshakeError && | ||
messageId != NegotiateStreamConstants.HandshakeInProgress) | ||
{ | ||
throw new FormatException( | ||
String.Format("Invalid NegotiateStream MessageId 0x{0:X2}", messageId)); | ||
} | ||
|
||
m_writePayloadLength = (int) payloadSizeLow + ((int) payloadSizeHigh << 8); | ||
if (messageId == NegotiateStreamConstants.HandshakeDone) | ||
m_clientHandshakeDone = true; | ||
|
||
count -= NegotiateStreamConstants.HeaderLength; | ||
} | ||
|
||
if (count == 0) | ||
return; | ||
|
||
if (count + m_writeBuffer.Length > m_writePayloadLength) | ||
throw new InvalidDataException("Attempt to write more than a single message"); | ||
|
||
PayloadData payload; | ||
if (count < m_writePayloadLength) | ||
{ | ||
m_writeBuffer.Write(buffer, offset, count); | ||
if (m_writeBuffer.Length < m_writePayloadLength) | ||
// The message is only partially written | ||
return; | ||
|
||
var payloadBytes = m_writeBuffer.ToArray(); | ||
payload = new PayloadData(new ArraySegment<byte>(payloadBytes, 0, (int) m_writeBuffer.Length)); | ||
m_writeBuffer.SetLength(0); | ||
} | ||
else | ||
{ | ||
// full payload provided | ||
payload = new PayloadData(new ArraySegment<byte>(buffer, offset, m_writePayloadLength)); | ||
} | ||
await m_serverSession.SendReplyAsync(payload, m_ioBehavior, cancellationToken).ConfigureAwait(false); | ||
// Need to parse NegotiateStream header next time | ||
m_writePayloadLength = 0; | ||
} | ||
|
||
public override void Write(byte[] buffer, int offset, int count) => WriteAsync(buffer, offset, count, m_cancellationToken).GetAwaiter().GetResult(); | ||
|
||
public override bool CanRead => true; | ||
|
||
public override bool CanSeek => false; | ||
|
||
public override bool CanWrite => true; | ||
|
||
public override long Length => throw new NotImplementedException(); | ||
|
||
public override long Position { get => throw new NotImplementedException(); set => throw new NotImplementedException(); } | ||
|
||
public override void Flush() | ||
{ | ||
} | ||
|
||
public override long Seek(long offset, SeekOrigin origin) => throw new NotImplementedException(); | ||
|
||
public override void SetLength(long value) => throw new NotImplementedException(); | ||
|
||
} | ||
internal class AuthGSSAPI | ||
{ | ||
private static string GetServicePrincipalName(byte[] switchRequest) | ||
{ | ||
var reader = new ByteArrayReader(switchRequest.AsSpan()); | ||
return Encoding.UTF8.GetString(reader.ReadNullOrEofTerminatedByteString()); | ||
} | ||
public static async Task<PayloadData> AuthenticateAsync(ConnectionSettings cs, byte[] switchRequestPayloadData, | ||
ServerSession session, IOBehavior ioBehavior, CancellationToken cancellationToken) | ||
{ | ||
using (var innerStream = new NegotiateToMySqlConverterStream(session, ioBehavior, cancellationToken)) | ||
using (var negotiateStream = new NegotiateStream(innerStream)) | ||
{ | ||
var targetName =cs.ServerSPN ?? GetServicePrincipalName(switchRequestPayloadData); | ||
#if NETSTANDARD1_3 | ||
await negotiateStream.AuthenticateAsClientAsync(CredentialCache.DefaultNetworkCredentials, targetName).ConfigureAwait(false); | ||
#else | ||
if (ioBehavior == IOBehavior.Synchronous) | ||
{ | ||
negotiateStream.AuthenticateAsClient(CredentialCache.DefaultNetworkCredentials, targetName); | ||
} | ||
else | ||
{ | ||
await negotiateStream.AuthenticateAsClientAsync(CredentialCache.DefaultNetworkCredentials, targetName).ConfigureAwait(false); | ||
} | ||
#endif | ||
if (cs.ServerSPN != null && !negotiateStream.IsMutuallyAuthenticated) | ||
{ | ||
// Negotiate used NTLM fallback, server name cannot be verified. | ||
throw new AuthenticationException(String.Format( | ||
"GSSAPI : Unable to verify server principal name using authentication type {0}", | ||
negotiateStream.RemoteIdentity?.AuthenticationType)); | ||
} | ||
if (innerStream.MySQLProtocolPayload.ArraySegment.Array != null) | ||
// return already pre-read OK packet. | ||
return innerStream.MySQLProtocolPayload; | ||
|
||
// Read final OK packet from server | ||
return await session.ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false); | ||
} | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -58,6 +58,7 @@ public void Defaults() | |
#if !BASELINE | ||
Assert.Null(csb.ServerRsaPublicKeyFile); | ||
#endif | ||
Assert.Null(csb.ServerSPN); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These tests also run against Oracle's Connector/NET (to verify equivalent behaviour) so this line will need to be moved inside the |
||
#if !BASELINE | ||
Assert.Equal(MySqlSslMode.Preferred, csb.SslMode); | ||
#else | ||
|
@@ -121,6 +122,7 @@ public void ParseConnectionString() | |
"Port=1234;" + | ||
"protocol=pipe;" + | ||
"pwd=Pass1234;" + | ||
"server spn=mariadb/[email protected];" + | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Move up to line 111 (to be inside |
||
"Treat Tiny As Boolean=false;" + | ||
"ssl mode=verifyca;" + | ||
"Uid=username;" + | ||
|
@@ -168,6 +170,7 @@ public void ParseConnectionString() | |
Assert.False(csb.Pooling); | ||
Assert.Equal(1234u, csb.Port); | ||
Assert.Equal("db-server", csb.Server); | ||
Assert.Equal("mariadb/[email protected]", csb.ServerSPN); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Move up to line 159 (to be inside |
||
Assert.False(csb.TreatTinyAsBoolean); | ||
Assert.Equal(MySqlSslMode.VerifyCA, csb.SslMode); | ||
Assert.False(csb.UseAffectedRows); | ||
|
Uh oh!
There was an error while loading. Please reload this page.