Skip to content

Improve usage of Type.GetType when activating types in data protection #54256

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 13, 2024
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Diagnostics.CodeAnalysis;

namespace Microsoft.AspNetCore.DataProtection.Internal;

internal sealed class DefaultTypeNameResolver : ITypeNameResolver
{
public static readonly DefaultTypeNameResolver Instance = new();

private DefaultTypeNameResolver()
{
}

[UnconditionalSuppressMessage("Trimmer", "IL2057", Justification = "Type.GetType is only used to resolve statically known types that are referenced by DataProtection assembly.")]
public bool TryResolveType(string typeName, [NotNullWhen(true)] out Type? type)
{
try
{
// Some exceptions are thrown regardless of the value of throwOnError.
// For example, if the type is found but cannot be loaded,
// a System.TypeLoadException is thrown even if throwOnError is false.
type = Type.GetType(typeName, throwOnError: false);
return type != null;
}
catch
{
type = null;
return false;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Diagnostics.CodeAnalysis;

namespace Microsoft.AspNetCore.DataProtection.Internal;

internal interface ITypeNameResolver
{
bool TryResolveType(string typeName, [NotNullWhen(true)] out Type? type);
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ public sealed class XmlKeyManager : IKeyManager, IInternalXmlKeyManager
private const string RevokeAllKeysValue = "*";

private readonly IActivator _activator;
private readonly ITypeNameResolver _typeNameResolver;
private readonly AlgorithmConfiguration _authenticatedEncryptorConfiguration;
private readonly IKeyEscrowSink? _keyEscrowSink;
private readonly IInternalXmlKeyManager _internalKeyManager;
Expand Down Expand Up @@ -112,6 +113,8 @@ internal XmlKeyManager(
var escrowSinks = keyManagementOptions.Value.KeyEscrowSinks;
_keyEscrowSink = escrowSinks.Count > 0 ? new AggregateKeyEscrowSink(escrowSinks) : null;
_activator = activator;
// Note: ITypeNameResolver is only implemented on the activator in tests. In production, it's always DefaultTypeNameResolver.
_typeNameResolver = activator as ITypeNameResolver ?? DefaultTypeNameResolver.Instance;
TriggerAndResetCacheExpirationToken(suppressLogging: true);
_internalKeyManager = _internalKeyManager ?? this;
_encryptorFactories = keyManagementOptions.Value.AuthenticatedEncryptorFactories;
Expand Down Expand Up @@ -463,27 +466,27 @@ IAuthenticatedEncryptorDescriptor IInternalXmlKeyManager.DeserializeDescriptorFr
}
}

[UnconditionalSuppressMessage("Trimmer", "IL2057", Justification = "Type.GetType result is only useful with types that are referenced by DataProtection assembly.")]
private IAuthenticatedEncryptorDescriptorDeserializer CreateDeserializer(string descriptorDeserializerTypeName)
{
var resolvedTypeName = TypeForwardingActivator.TryForwardTypeName(descriptorDeserializerTypeName, out var forwardedTypeName)
// typeNameToMatch will be used for matching against known types but not passed to the activator.
// The activator will do its own forwarding.
var typeNameToMatch = TypeForwardingActivator.TryForwardTypeName(descriptorDeserializerTypeName, out var forwardedTypeName)
? forwardedTypeName
: descriptorDeserializerTypeName;
var type = Type.GetType(resolvedTypeName, throwOnError: false);

if (type == typeof(AuthenticatedEncryptorDescriptorDeserializer))
if (typeof(AuthenticatedEncryptorDescriptorDeserializer).MatchName(typeNameToMatch, _typeNameResolver))
{
return _activator.CreateInstance<AuthenticatedEncryptorDescriptorDeserializer>(descriptorDeserializerTypeName);
}
else if (type == typeof(CngCbcAuthenticatedEncryptorDescriptorDeserializer) && RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
else if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows) && typeof(CngCbcAuthenticatedEncryptorDescriptorDeserializer).MatchName(typeNameToMatch, _typeNameResolver))
{
return _activator.CreateInstance<CngCbcAuthenticatedEncryptorDescriptorDeserializer>(descriptorDeserializerTypeName);
}
else if (type == typeof(CngGcmAuthenticatedEncryptorDescriptorDeserializer) && RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
else if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows) && typeof(CngGcmAuthenticatedEncryptorDescriptorDeserializer).MatchName(typeNameToMatch, _typeNameResolver))
{
return _activator.CreateInstance<CngGcmAuthenticatedEncryptorDescriptorDeserializer>(descriptorDeserializerTypeName);
}
else if (type == typeof(ManagedAuthenticatedEncryptorDescriptorDeserializer))
else if (typeof(ManagedAuthenticatedEncryptorDescriptorDeserializer).MatchName(typeNameToMatch, _typeNameResolver))
{
return _activator.CreateInstance<ManagedAuthenticatedEncryptorDescriptorDeserializer>(descriptorDeserializerTypeName);
}
Expand Down
13 changes: 13 additions & 0 deletions src/DataProtection/DataProtection/src/TypeExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Diagnostics.CodeAnalysis;
using Microsoft.AspNetCore.DataProtection.Internal;

namespace Microsoft.AspNetCore.DataProtection;

Expand Down Expand Up @@ -39,4 +40,16 @@ public static Type GetTypeWithTrimFriendlyErrorMessage(string typeName)
throw new InvalidOperationException($"Unable to load type '{typeName}'. If the app is published with trimming then this type may have been trimmed. Ensure the type's assembly is excluded from trimming.", ex);
}
}

public static bool MatchName(this Type matchType, string resolvedTypeName, ITypeNameResolver typeNameResolver)
{
// Before attempting to resolve the name to a type, check if it starts with the full name of the type.
// Use StartsWith to ignore potential assembly version differences.
if (matchType.FullName != null && resolvedTypeName.StartsWith(matchType.FullName, StringComparison.Ordinal))
{
return typeNameResolver.TryResolveType(resolvedTypeName, out var resolvedType) && resolvedType == matchType;
}

return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,27 +67,30 @@ public static XElement DecryptElement(this XElement element, IActivator activato
return doc.Root!;
}

[UnconditionalSuppressMessage("Trimmer", "IL2057", Justification = "Type.GetType result is only useful with types that are referenced by DataProtection assembly.")]
private static IXmlDecryptor CreateDecryptor(IActivator activator, string decryptorTypeName)
{
var resolvedTypeName = TypeForwardingActivator.TryForwardTypeName(decryptorTypeName, out var forwardedTypeName)
// typeNameToMatch will be used for matching against known types but not passed to the activator.
// The activator will do its own forwarding.
var typeNameToMatch = TypeForwardingActivator.TryForwardTypeName(decryptorTypeName, out var forwardedTypeName)
? forwardedTypeName
: decryptorTypeName;
var type = Type.GetType(resolvedTypeName, throwOnError: false);

if (type == typeof(DpapiNGXmlDecryptor))
// Note: ITypeNameResolver is only implemented on the activator in tests. In production, it's always DefaultTypeNameResolver.
var typeNameResolver = activator as ITypeNameResolver ?? DefaultTypeNameResolver.Instance;

if (typeof(DpapiNGXmlDecryptor).MatchName(typeNameToMatch, typeNameResolver))
{
return activator.CreateInstance<DpapiNGXmlDecryptor>(decryptorTypeName);
}
else if (type == typeof(DpapiXmlDecryptor))
else if (typeof(DpapiXmlDecryptor).MatchName(typeNameToMatch, typeNameResolver))
{
return activator.CreateInstance<DpapiXmlDecryptor>(decryptorTypeName);
}
else if (type == typeof(EncryptedXmlDecryptor))
else if (typeof(EncryptedXmlDecryptor).MatchName(typeNameToMatch, typeNameResolver))
{
return activator.CreateInstance<EncryptedXmlDecryptor>(decryptorTypeName);
}
else if (type == typeof(NullXmlDecryptor))
else if (typeof(NullXmlDecryptor).MatchName(typeNameToMatch, typeNameResolver))
{
return activator.CreateInstance<NullXmlDecryptor>(decryptorTypeName);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,100 @@ public void DecryptElement_RootNodeRequiresDecryption_Success()
XmlAssert.Equal("<newNode />", retVal);
}

[Fact]
public void DecryptElement_CustomType_TypeNameResolverNotCalled()
{
// Arrange
var decryptorTypeName = typeof(MyXmlDecryptor).AssemblyQualifiedName;

var original = XElement.Parse(@$"
<x:encryptedSecret decryptorType='{decryptorTypeName}' xmlns:x='http://schemas.asp.net/2015/03/dataProtection'>
<node />
</x:encryptedSecret>");

var mockActivator = new Mock<IActivator>();
mockActivator.ReturnDecryptedElementGivenDecryptorTypeNameAndInput(decryptorTypeName, "<node />", "<newNode />");
var mockTypeNameResolver = mockActivator.As<ITypeNameResolver>();

var serviceCollection = new ServiceCollection();
serviceCollection.AddSingleton<IActivator>(mockActivator.Object);
var services = serviceCollection.BuildServiceProvider();
var activator = services.GetActivator();

// Act
var retVal = original.DecryptElement(activator);

// Assert
XmlAssert.Equal("<newNode />", retVal);
Type resolvedType;
mockTypeNameResolver.Verify(o => o.TryResolveType(It.IsAny<string>(), out resolvedType), Times.Never());
}

[Fact]
public void DecryptElement_KnownType_TypeNameResolverCalled()
{
// Arrange
var decryptorTypeName = typeof(NullXmlDecryptor).AssemblyQualifiedName;
TypeForwardingActivator.TryForwardTypeName(decryptorTypeName, out var forwardedTypeName);

var original = XElement.Parse(@$"
<x:encryptedSecret decryptorType='{decryptorTypeName}' xmlns:x='http://schemas.asp.net/2015/03/dataProtection'>
<node>
<value />
</node>
</x:encryptedSecret>");

var mockActivator = new Mock<IActivator>();
mockActivator.Setup(o => o.CreateInstance(typeof(NullXmlDecryptor), decryptorTypeName)).Returns(new NullXmlDecryptor());
var mockTypeNameResolver = mockActivator.As<ITypeNameResolver>();
var resolvedType = typeof(NullXmlDecryptor);
mockTypeNameResolver.Setup(mockTypeNameResolver => mockTypeNameResolver.TryResolveType(forwardedTypeName, out resolvedType)).Returns(true);

var serviceCollection = new ServiceCollection();
serviceCollection.AddSingleton<IActivator>(mockActivator.Object);
var services = serviceCollection.BuildServiceProvider();
var activator = services.GetActivator();

// Act
var retVal = original.DecryptElement(activator);

// Assert
XmlAssert.Equal("<value />", retVal);
mockTypeNameResolver.Verify(o => o.TryResolveType(It.IsAny<string>(), out resolvedType), Times.Once());
}

[Fact]
public void DecryptElement_KnownType_UnableToResolveType_Success()
{
// Arrange
var decryptorTypeName = typeof(NullXmlDecryptor).AssemblyQualifiedName;

var original = XElement.Parse(@$"
<x:encryptedSecret decryptorType='{decryptorTypeName}' xmlns:x='http://schemas.asp.net/2015/03/dataProtection'>
<node>
<value />
</node>
</x:encryptedSecret>");

var mockActivator = new Mock<IActivator>();
mockActivator.Setup(o => o.CreateInstance(typeof(IXmlDecryptor), decryptorTypeName)).Returns(new NullXmlDecryptor());
var mockTypeNameResolver = mockActivator.As<ITypeNameResolver>();
Type resolvedType = null;
mockTypeNameResolver.Setup(mockTypeNameResolver => mockTypeNameResolver.TryResolveType(It.IsAny<string>(), out resolvedType)).Returns(false);

var serviceCollection = new ServiceCollection();
serviceCollection.AddSingleton<IActivator>(mockActivator.Object);
var services = serviceCollection.BuildServiceProvider();
var activator = services.GetActivator();

// Act
var retVal = original.DecryptElement(activator);

// Assert
XmlAssert.Equal("<value />", retVal);
mockTypeNameResolver.Verify(o => o.TryResolveType(It.IsAny<string>(), out resolvedType), Times.Once());
}

[Fact]
public void DecryptElement_MultipleNodesRequireDecryption_AvoidsRecursion_Success()
{
Expand Down