Skip to content

Fix static proxy serialization #1711

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 66 additions & 1 deletion src/NHibernate.Test/StaticProxyTest/StaticProxyFactoryFixture.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
using System.Collections.Generic;
using System;
using System.Collections.Generic;
using System.IO;
using System.Runtime.Serialization.Formatters.Binary;
using NHibernate.Proxy;
using NUnit.Framework;

Expand All @@ -16,6 +19,12 @@ public class TestClass : ISomething
public virtual int Id { get; set; }
}

[Serializable]
public class SimpleTestClass
{
public virtual int Id { get; set; }
}

[Test]
public void CanCreateProxyForClassWithInternalInterface()
{
Expand All @@ -24,5 +33,61 @@ public void CanCreateProxyForClassWithInternalInterface()
var proxy = factory.GetProxy(1, null);
Assert.That(proxy, Is.Not.Null);
}

[Test]
public void InitializedProxyStaysInitializedAfterDeserialization()
{
TestsContext.AssumeSystemTypeIsSerializable();

var factory = new StaticProxyFactory();
factory.PostInstantiate(typeof(SimpleTestClass).FullName, typeof(SimpleTestClass), new HashSet<System.Type> {typeof(INHibernateProxy)}, null, null, null);
var proxy = factory.GetProxy(2, null);
Assert.That(proxy, Is.Not.Null, "proxy");
Assert.That(NHibernateUtil.IsInitialized(proxy), Is.False, "proxy already initialized after creation");
Assert.That(proxy.HibernateLazyInitializer, Is.Not.Null, "HibernateLazyInitializer");

var impl = new SimpleTestClass { Id = 2 };
proxy.HibernateLazyInitializer.SetImplementation(impl);
Assert.That(NHibernateUtil.IsInitialized(proxy), Is.True, "proxy not initialized after setting implementation");

var serializer = new BinaryFormatter();
object deserialized;
using (var memoryStream = new MemoryStream())
{
serializer.Serialize(memoryStream, proxy);
memoryStream.Seek(0L, SeekOrigin.Begin);
deserialized = serializer.Deserialize(memoryStream);
}
Assert.That(deserialized, Is.Not.Null, "deserialized");
Assert.That(deserialized, Is.InstanceOf<INHibernateProxy>());
Assert.That(NHibernateUtil.IsInitialized(deserialized), Is.True, "proxy no more initialized after deserialization");
Assert.That(deserialized, Is.InstanceOf<SimpleTestClass>());
Assert.That(((SimpleTestClass) deserialized).Id, Is.EqualTo(2));
}

[Test]
public void NonInitializedProxyStaysNonInitializedAfterSerialization()
{
TestsContext.AssumeSystemTypeIsSerializable();

var factory = new StaticProxyFactory();
factory.PostInstantiate(typeof(SimpleTestClass).FullName, typeof(SimpleTestClass), new HashSet<System.Type> {typeof(INHibernateProxy)}, null, null, null);
var proxy = factory.GetProxy(2, null);
Assert.That(proxy, Is.Not.Null, "proxy");
Assert.That(NHibernateUtil.IsInitialized(proxy), Is.False, "proxy already initialized after creation");

var serializer = new BinaryFormatter();
object deserialized;
using (var memoryStream = new MemoryStream())
{
serializer.Serialize(memoryStream, proxy);
Assert.That(NHibernateUtil.IsInitialized(proxy), Is.False, "proxy initialized after serialization");
memoryStream.Seek(0L, SeekOrigin.Begin);
deserialized = serializer.Deserialize(memoryStream);
}
Assert.That(deserialized, Is.Not.Null, "deserialized");
Assert.That(deserialized, Is.InstanceOf<INHibernateProxy>());
Assert.That(NHibernateUtil.IsInitialized(deserialized), Is.False, "proxy initialized after deserialization");
}
}
}
39 changes: 36 additions & 3 deletions src/NHibernate/Proxy/NHibernateProxyBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ internal class NHibernateProxyBuilder
private static readonly PropertyInfo LazyInitializerIdentifierProperty = LazyInitializerType.GetProperty(nameof(ILazyInitializer.Identifier));
private static readonly MethodInfo LazyInitializerInitializeMethod = LazyInitializerType.GetMethod(nameof(ILazyInitializer.Initialize));
private static readonly MethodInfo LazyInitializerGetImplementationMethod = LazyInitializerType.GetMethod(nameof(ILazyInitializer.GetImplementation), System.Type.EmptyTypes);
private static readonly PropertyInfo LazyInitializerIsUninitializedProperty = LazyInitializerType.GetProperty(nameof(ILazyInitializer.IsUninitialized));
private static readonly IProxyAssemblyBuilder ProxyAssemblyBuilder = new DefaultProxyAssemblyBuilder();

private static readonly ConstructorInfo SecurityCriticalAttributeConstructor = typeof(SecurityCriticalAttribute).GetConstructor(System.Type.EmptyTypes);
Expand Down Expand Up @@ -95,7 +96,7 @@ public TypeInfo CreateProxyType(System.Type baseType, IReadOnlyCollection<System
var customAttributeBuilder = new CustomAttributeBuilder(serializableConstructor, Array.Empty<object>());
typeBuilder.SetCustomAttribute(customAttributeBuilder);

ImplementDeserializationConstructor(typeBuilder);
ImplementDeserializationConstructor(typeBuilder, parentType);
ImplementGetObjectData(typeBuilder, proxyInfoField, lazyInitializerField);

var proxyType = typeBuilder.CreateTypeInfo();
Expand Down Expand Up @@ -168,13 +169,24 @@ private static void ImplementConstructor(TypeBuilder typeBuilder, System.Type pa
IL.Emit(OpCodes.Ret);
}

private static void ImplementDeserializationConstructor(TypeBuilder typeBuilder)
private static void ImplementDeserializationConstructor(TypeBuilder typeBuilder, System.Type parentType)
{
var parameterTypes = new[] {typeof (SerializationInfo), typeof (StreamingContext)};
var constructor = typeBuilder.DefineConstructor(constructorAttributes, CallingConventions.Standard, parameterTypes);
constructor.SetImplementationFlags(MethodImplAttributes.IL | MethodImplAttributes.Managed);

var IL = constructor.GetILGenerator();

constructor.SetImplementationFlags(MethodImplAttributes.IL | MethodImplAttributes.Managed);

var baseConstructor = parentType.GetConstructor(BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public, null, System.Type.EmptyTypes, null);
// if there is no default constructor, or the default constructor is private/internal, call System.Object constructor
// this works, but the generated assembly will fail PeVerify (cannot use in medium trust for example)
if (baseConstructor == null || baseConstructor.IsPrivate || baseConstructor.IsAssembly)
baseConstructor = ObjectConstructor;
IL.Emit(OpCodes.Ldarg_0);
IL.Emit(OpCodes.Call, baseConstructor);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have checked PeVerify with following test added in PeVerifyFixture:

[Test]
public void VerifyStaticProxy()
{
	var proxyBuilderType = typeof(StaticProxyFactory).Assembly.GetType("NHibernate.Proxy.NHibernateProxyBuilder", true);
	var proxyBuilder = proxyBuilderType.GetMethod("CreateProxyType");
	Assert.That(proxyBuilder, Is.Not.Null, "Failed to find method CreateProxyType");
	var proxyBuilderAssemblyBuilder = proxyBuilderType.GetField("ProxyAssemblyBuilder", BindingFlags.NonPublic | BindingFlags.Static);
	Assert.That(proxyBuilderAssemblyBuilder, Is.Not.Null, "Failed to find assembly builder field");
	var builderCtor = proxyBuilderType.GetConstructor(
		new[] { typeof(MethodInfo), typeof(MethodInfo), typeof(IAbstractComponentType), typeof(bool) });
	Assert.That(builderCtor, Is.Not.Null, "Failed to find builder ctor");
	var builder = builderCtor.Invoke(new object[] { null, null, null, false });

	var assemblyBuilder = new SavingProxyAssemblyBuilder(assemblyName);
	var backupAssemblyBuilder = proxyBuilderAssemblyBuilder.GetValue(null);
	proxyBuilderAssemblyBuilder.SetValue(null, assemblyBuilder);
	try
	{
		proxyBuilder.Invoke(builder, new object[] { typeof(ClassWithPublicDefaultConstructor), new[] { typeof(INHibernateProxy) } });
	}
	finally
	{
		proxyBuilderAssemblyBuilder.SetValue(null, backupAssemblyBuilder);
	}

	new PeVerifier($"{assemblyName}.dll").AssertIsValid();
}

I have not added it in the PR as testing this would be much easier with #1709, without needs to depend on dynamic proxies test classes.


//Everything is done in NHibernateProxyObjectReference, so just return data.
IL.Emit(OpCodes.Ret);
}
Expand All @@ -199,7 +211,12 @@ private static void ImplementGetObjectData(TypeBuilder typeBuilder, FieldInfo pr
IL.Emit(OpCodes.Call, ReflectionCache.TypeMethods.GetTypeFromHandle);
IL.Emit(OpCodes.Callvirt, SerializationInfoSetTypeMethod);

// (new NHibernateProxyObjectReference(this.__proxyInfo, this.__lazyInitializer.Identifier)).GetObjectData(info, context);
// return
// (new NHibernateProxyObjectReference(
// this.__proxyInfo,
// this.__lazyInitializer.Identifier),
// this.__lazyInitializer.IsUninitialized ? null : this.__lazyInitializer.GetImplementation())
// .GetObjectData(info, context);
//this.__proxyInfo
IL.Emit(OpCodes.Ldarg_0);
IL.Emit(OpCodes.Ldfld, proxyInfoField);
Expand All @@ -209,11 +226,27 @@ private static void ImplementGetObjectData(TypeBuilder typeBuilder, FieldInfo pr
IL.Emit(OpCodes.Ldfld, lazyInitializerField);
IL.Emit(OpCodes.Callvirt, LazyInitializerIdentifierProperty.GetMethod);

// this.__lazyInitializer.IsUninitialized ? null : this.__lazyInitializer.GetImplementation()
var isUnitialized = IL.DefineLabel();
var endIsUnitializedTernary = IL.DefineLabel();
IL.Emit(OpCodes.Ldarg_0);
IL.Emit(OpCodes.Ldfld, lazyInitializerField);
IL.Emit(OpCodes.Callvirt, LazyInitializerIsUninitializedProperty.GetMethod);
IL.Emit(OpCodes.Brtrue, isUnitialized);
IL.Emit(OpCodes.Ldarg_0);
IL.Emit(OpCodes.Ldfld, lazyInitializerField);
IL.Emit(OpCodes.Callvirt, LazyInitializerGetImplementationMethod);
IL.Emit(OpCodes.Br, endIsUnitializedTernary);
IL.MarkLabel(isUnitialized);
IL.Emit(OpCodes.Ldnull);
IL.MarkLabel(endIsUnitializedTernary);

var constructor = typeof(NHibernateProxyObjectReference).GetConstructor(
new[]
{
typeof(NHibernateProxyFactoryInfo),
typeof(object),
typeof(object)
});
IL.Emit(OpCodes.Newobj, constructor);

Expand Down
23 changes: 22 additions & 1 deletion src/NHibernate/Proxy/NHibernateProxyObjectReference.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,30 +9,51 @@ public sealed class NHibernateProxyObjectReference : IObjectReference, ISerializ
{
private readonly NHibernateProxyFactoryInfo _proxyFactoryInfo;
private readonly object _identifier;
private readonly object _implementation;

[Obsolete("Use overload taking an implementation parameter")]
public NHibernateProxyObjectReference(NHibernateProxyFactoryInfo proxyFactoryInfo, object identifier)
: this (proxyFactoryInfo, identifier, null)
{}

public NHibernateProxyObjectReference(NHibernateProxyFactoryInfo proxyFactoryInfo, object identifier, object implementation)
{
_proxyFactoryInfo = proxyFactoryInfo;
_identifier = identifier;
_implementation = implementation;
}

private NHibernateProxyObjectReference(SerializationInfo info, StreamingContext context)
{
_proxyFactoryInfo = (NHibernateProxyFactoryInfo) info.GetValue(nameof(_proxyFactoryInfo), typeof(NHibernateProxyFactoryInfo));
_identifier = info.GetValue(nameof(_identifier), typeof(object));
// 6.0 TODO: simplify with info.GetValue(nameof(_implementation), typeof(object));
foreach (var entry in info)
{
if (entry.Name == nameof(_implementation))
{
_implementation = entry.Value;
}
}
}

[SecurityCritical]
public object GetRealObject(StreamingContext context)
{
return _proxyFactoryInfo.CreateProxyFactory().GetProxy(_identifier, null);
var proxy = _proxyFactoryInfo.CreateProxyFactory().GetProxy(_identifier, null);

if (_implementation != null)
proxy.HibernateLazyInitializer.SetImplementation(_implementation);

return proxy;
}

[SecurityCritical]
public void GetObjectData(SerializationInfo info, StreamingContext context)
{
info.AddValue(nameof(_proxyFactoryInfo), _proxyFactoryInfo);
info.AddValue(nameof(_identifier), _identifier);
info.AddValue(nameof(_implementation), _implementation);
}
}
}