Skip to content

Fix possible InvalidCastException in ActionQueue #1753

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 16, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions src/NHibernate/Action/AbstractEntityInsertAction.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
using System;
using NHibernate.Engine;
using NHibernate.Persister.Entity;

namespace NHibernate.Action
{
[Serializable]
public abstract class AbstractEntityInsertAction : EntityAction
{
protected internal AbstractEntityInsertAction(
object id,
object[] state,
object instance,
IEntityPersister persister,
ISessionImplementor session)
: base(session, id, instance, persister)
{
State = state;
}

public object[] State { get; }
}
}
16 changes: 7 additions & 9 deletions src/NHibernate/Action/EntityIdentityInsertAction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,17 @@
namespace NHibernate.Action
{
[Serializable]
public sealed partial class EntityIdentityInsertAction : EntityAction
public sealed partial class EntityIdentityInsertAction : AbstractEntityInsertAction
{
private readonly object lockObject = new object();
private readonly object[] state;
private readonly bool isDelayed;
private readonly EntityKey delayedEntityKey;
//private CacheEntry cacheEntry;
private object generatedId;

public EntityIdentityInsertAction(object[] state, object instance, IEntityPersister persister, ISessionImplementor session, bool isDelayed)
: base(session, null, instance, persister)
: base(null, state, instance, persister, session)
{
this.state = state;
this.isDelayed = isDelayed;
delayedEntityKey = this.isDelayed ? GenerateDelayedEntityKey() : null;
}
Expand Down Expand Up @@ -72,10 +70,10 @@ public override void Execute()

if (!veto)
{
generatedId = persister.Insert(state, instance, Session);
generatedId = persister.Insert(State, instance, Session);
if (persister.HasInsertGeneratedProperties)
{
persister.ProcessInsertGeneratedProperties(generatedId, instance, state, Session);
persister.ProcessInsertGeneratedProperties(generatedId, instance, State, Session);
}
//need to do that here rather than in the save event listener to let
//the post insert events to have a id-filled entity when IDENTITY is used (EJB3)
Expand Down Expand Up @@ -107,7 +105,7 @@ private void PostInsert()
IPostInsertEventListener[] postListeners = Session.Listeners.PostInsertEventListeners;
if (postListeners.Length > 0)
{
PostInsertEvent postEvent = new PostInsertEvent(Instance, generatedId, state, Persister, (IEventSource)Session);
PostInsertEvent postEvent = new PostInsertEvent(Instance, generatedId, State, Persister, (IEventSource)Session);
foreach (IPostInsertEventListener listener in postListeners)
{
listener.OnPostInsert(postEvent);
Expand All @@ -120,7 +118,7 @@ private void PostCommitInsert()
IPostInsertEventListener[] postListeners = Session.Listeners.PostCommitInsertEventListeners;
if (postListeners.Length > 0)
{
var postEvent = new PostInsertEvent(Instance, generatedId, state, Persister, (IEventSource) Session);
var postEvent = new PostInsertEvent(Instance, generatedId, State, Persister, (IEventSource) Session);
foreach (IPostInsertEventListener listener in postListeners)
{
listener.OnPostInsert(postEvent);
Expand All @@ -134,7 +132,7 @@ private bool PreInsert()
bool veto = false;
if (preListeners.Length > 0)
{
var preEvent = new PreInsertEvent(Instance, null, state, Persister, (IEventSource) Session);
var preEvent = new PreInsertEvent(Instance, null, State, Persister, (IEventSource) Session);
foreach (IPreInsertEventListener listener in preListeners)
{
veto |= listener.OnPreInsert(preEvent);
Expand Down
27 changes: 10 additions & 17 deletions src/NHibernate/Action/EntityInsertAction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,17 @@
namespace NHibernate.Action
{
[Serializable]
public sealed partial class EntityInsertAction : EntityAction
public sealed partial class EntityInsertAction : AbstractEntityInsertAction
{
private readonly object[] state;
private object version;
private object cacheEntry;

public EntityInsertAction(object id, object[] state, object instance, object version, IEntityPersister persister, ISessionImplementor session)
: base(session, id, instance, persister)
: base(id, state, instance, persister, session)
{
this.state = state;
this.version = version;
}

public object[] State
{
get { return state; }
}

protected internal override bool HasPostCommitEventListeners
{
get
Expand Down Expand Up @@ -56,7 +49,7 @@ public override void Execute()
if (!veto)
{

persister.Insert(id, state, instance, Session);
persister.Insert(id, State, instance, Session);

EntityEntry entry = Session.PersistenceContext.GetEntry(instance);
if (entry == null)
Expand All @@ -68,20 +61,20 @@ public override void Execute()

if (persister.HasInsertGeneratedProperties)
{
persister.ProcessInsertGeneratedProperties(id, instance, state, Session);
persister.ProcessInsertGeneratedProperties(id, instance, State, Session);
if (persister.IsVersionPropertyGenerated)
{
version = Versioning.GetVersion(state, persister);
version = Versioning.GetVersion(State, persister);
}
entry.PostUpdate(instance, state, version);
entry.PostUpdate(instance, State, version);
}
}

ISessionFactoryImplementor factory = Session.Factory;

if (IsCachePutEnabled(persister))
{
CacheEntry ce = new CacheEntry(state, persister, persister.HasUninitializedLazyProperties(instance), version, session, instance);
CacheEntry ce = new CacheEntry(State, persister, persister.HasUninitializedLazyProperties(instance), version, session, instance);
cacheEntry = persister.CacheEntryStructure.Structure(ce);

CacheKey ck = Session.GenerateCacheKey(id, persister.IdentifierType, persister.RootEntityName);
Expand Down Expand Up @@ -127,7 +120,7 @@ private void PostInsert()
IPostInsertEventListener[] postListeners = Session.Listeners.PostInsertEventListeners;
if (postListeners.Length > 0)
{
PostInsertEvent postEvent = new PostInsertEvent(Instance, Id, state, Persister, (IEventSource)Session);
PostInsertEvent postEvent = new PostInsertEvent(Instance, Id, State, Persister, (IEventSource)Session);
foreach (IPostInsertEventListener listener in postListeners)
{
listener.OnPostInsert(postEvent);
Expand All @@ -140,7 +133,7 @@ private void PostCommitInsert()
IPostInsertEventListener[] postListeners = Session.Listeners.PostCommitInsertEventListeners;
if (postListeners.Length > 0)
{
PostInsertEvent postEvent = new PostInsertEvent(Instance, Id, state, Persister, (IEventSource)Session);
PostInsertEvent postEvent = new PostInsertEvent(Instance, Id, State, Persister, (IEventSource)Session);
foreach (IPostInsertEventListener listener in postListeners)
{
listener.OnPostInsert(postEvent);
Expand All @@ -154,7 +147,7 @@ private bool PreInsert()
bool veto = false;
if (preListeners.Length > 0)
{
var preEvent = new PreInsertEvent(Instance, Id, state, Persister, (IEventSource) Session);
var preEvent = new PreInsertEvent(Instance, Id, State, Persister, (IEventSource) Session);
foreach (IPreInsertEventListener listener in preListeners)
{
veto |= listener.OnPreInsert(preEvent);
Expand Down
12 changes: 6 additions & 6 deletions src/NHibernate/Async/Action/EntityIdentityInsertAction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace NHibernate.Action
{
using System.Threading.Tasks;
using System.Threading;
public sealed partial class EntityIdentityInsertAction : EntityAction
public sealed partial class EntityIdentityInsertAction : AbstractEntityInsertAction
{

public override async Task ExecuteAsync(CancellationToken cancellationToken)
Expand All @@ -41,10 +41,10 @@ public override async Task ExecuteAsync(CancellationToken cancellationToken)

if (!veto)
{
generatedId = await (persister.InsertAsync(state, instance, Session, cancellationToken)).ConfigureAwait(false);
generatedId = await (persister.InsertAsync(State, instance, Session, cancellationToken)).ConfigureAwait(false);
if (persister.HasInsertGeneratedProperties)
{
await (persister.ProcessInsertGeneratedPropertiesAsync(generatedId, instance, state, Session, cancellationToken)).ConfigureAwait(false);
await (persister.ProcessInsertGeneratedPropertiesAsync(generatedId, instance, State, Session, cancellationToken)).ConfigureAwait(false);
}
//need to do that here rather than in the save event listener to let
//the post insert events to have a id-filled entity when IDENTITY is used (EJB3)
Expand Down Expand Up @@ -77,7 +77,7 @@ private async Task PostInsertAsync(CancellationToken cancellationToken)
IPostInsertEventListener[] postListeners = Session.Listeners.PostInsertEventListeners;
if (postListeners.Length > 0)
{
PostInsertEvent postEvent = new PostInsertEvent(Instance, generatedId, state, Persister, (IEventSource)Session);
PostInsertEvent postEvent = new PostInsertEvent(Instance, generatedId, State, Persister, (IEventSource)Session);
foreach (IPostInsertEventListener listener in postListeners)
{
await (listener.OnPostInsertAsync(postEvent, cancellationToken)).ConfigureAwait(false);
Expand All @@ -91,7 +91,7 @@ private async Task PostCommitInsertAsync(CancellationToken cancellationToken)
IPostInsertEventListener[] postListeners = Session.Listeners.PostCommitInsertEventListeners;
if (postListeners.Length > 0)
{
var postEvent = new PostInsertEvent(Instance, generatedId, state, Persister, (IEventSource) Session);
var postEvent = new PostInsertEvent(Instance, generatedId, State, Persister, (IEventSource) Session);
foreach (IPostInsertEventListener listener in postListeners)
{
await (listener.OnPostInsertAsync(postEvent, cancellationToken)).ConfigureAwait(false);
Expand All @@ -106,7 +106,7 @@ private async Task<bool> PreInsertAsync(CancellationToken cancellationToken)
bool veto = false;
if (preListeners.Length > 0)
{
var preEvent = new PreInsertEvent(Instance, null, state, Persister, (IEventSource) Session);
var preEvent = new PreInsertEvent(Instance, null, State, Persister, (IEventSource) Session);
foreach (IPreInsertEventListener listener in preListeners)
{
veto |= await (listener.OnPreInsertAsync(preEvent, cancellationToken)).ConfigureAwait(false);
Expand Down
18 changes: 9 additions & 9 deletions src/NHibernate/Async/Action/EntityInsertAction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ namespace NHibernate.Action
{
using System.Threading.Tasks;
using System.Threading;
public sealed partial class EntityInsertAction : EntityAction
public sealed partial class EntityInsertAction : AbstractEntityInsertAction
{

public override async Task ExecuteAsync(CancellationToken cancellationToken)
Expand All @@ -45,7 +45,7 @@ public override async Task ExecuteAsync(CancellationToken cancellationToken)
if (!veto)
{

await (persister.InsertAsync(id, state, instance, Session, cancellationToken)).ConfigureAwait(false);
await (persister.InsertAsync(id, State, instance, Session, cancellationToken)).ConfigureAwait(false);

EntityEntry entry = Session.PersistenceContext.GetEntry(instance);
if (entry == null)
Expand All @@ -57,20 +57,20 @@ public override async Task ExecuteAsync(CancellationToken cancellationToken)

if (persister.HasInsertGeneratedProperties)
{
await (persister.ProcessInsertGeneratedPropertiesAsync(id, instance, state, Session, cancellationToken)).ConfigureAwait(false);
await (persister.ProcessInsertGeneratedPropertiesAsync(id, instance, State, Session, cancellationToken)).ConfigureAwait(false);
if (persister.IsVersionPropertyGenerated)
{
version = Versioning.GetVersion(state, persister);
version = Versioning.GetVersion(State, persister);
}
entry.PostUpdate(instance, state, version);
entry.PostUpdate(instance, State, version);
}
}

ISessionFactoryImplementor factory = Session.Factory;

if (IsCachePutEnabled(persister))
{
CacheEntry ce = new CacheEntry(state, persister, persister.HasUninitializedLazyProperties(instance), version, session, instance);
CacheEntry ce = new CacheEntry(State, persister, persister.HasUninitializedLazyProperties(instance), version, session, instance);
cacheEntry = persister.CacheEntryStructure.Structure(ce);

CacheKey ck = Session.GenerateCacheKey(id, persister.IdentifierType, persister.RootEntityName);
Expand Down Expand Up @@ -118,7 +118,7 @@ private async Task PostInsertAsync(CancellationToken cancellationToken)
IPostInsertEventListener[] postListeners = Session.Listeners.PostInsertEventListeners;
if (postListeners.Length > 0)
{
PostInsertEvent postEvent = new PostInsertEvent(Instance, Id, state, Persister, (IEventSource)Session);
PostInsertEvent postEvent = new PostInsertEvent(Instance, Id, State, Persister, (IEventSource)Session);
foreach (IPostInsertEventListener listener in postListeners)
{
await (listener.OnPostInsertAsync(postEvent, cancellationToken)).ConfigureAwait(false);
Expand All @@ -132,7 +132,7 @@ private async Task PostCommitInsertAsync(CancellationToken cancellationToken)
IPostInsertEventListener[] postListeners = Session.Listeners.PostCommitInsertEventListeners;
if (postListeners.Length > 0)
{
PostInsertEvent postEvent = new PostInsertEvent(Instance, Id, state, Persister, (IEventSource)Session);
PostInsertEvent postEvent = new PostInsertEvent(Instance, Id, State, Persister, (IEventSource)Session);
foreach (IPostInsertEventListener listener in postListeners)
{
await (listener.OnPostInsertAsync(postEvent, cancellationToken)).ConfigureAwait(false);
Expand All @@ -147,7 +147,7 @@ private async Task<bool> PreInsertAsync(CancellationToken cancellationToken)
bool veto = false;
if (preListeners.Length > 0)
{
var preEvent = new PreInsertEvent(Instance, Id, state, Persister, (IEventSource) Session);
var preEvent = new PreInsertEvent(Instance, Id, State, Persister, (IEventSource) Session);
foreach (IPreInsertEventListener listener in preListeners)
{
veto |= await (listener.OnPreInsertAsync(preEvent, cancellationToken)).ConfigureAwait(false);
Expand Down
20 changes: 10 additions & 10 deletions src/NHibernate/Engine/ActionQueue.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public partial class ActionQueue
// Object insertions, updates, and deletions have list semantics because
// they must happen in the right order so as to respect referential
// integrity
private readonly List<IExecutable> insertions;
private readonly List<AbstractEntityInsertAction> insertions;
private readonly List<EntityDeleteAction> deletions;
private readonly List<EntityUpdateAction> updates;
// Actually the semantics of the next three are really "Bag"
Expand All @@ -48,7 +48,7 @@ public partial class ActionQueue
public ActionQueue(ISessionImplementor session)
{
this.session = session;
insertions = new List<IExecutable>(InitQueueListSize);
insertions = new List<AbstractEntityInsertAction>(InitQueueListSize);
deletions = new List<EntityDeleteAction>(InitQueueListSize);
updates = new List<EntityUpdateAction>(InitQueueListSize);

Expand Down Expand Up @@ -550,7 +550,7 @@ private class InsertActionSorter
private readonly Dictionary<object, int> _entityBatchDependency = new Dictionary<object, int>();

// the map of batch numbers to EntityInsertAction lists
private readonly Dictionary<int, List<EntityInsertAction>> _actionBatches = new Dictionary<int, List<EntityInsertAction>>();
private readonly Dictionary<int, List<AbstractEntityInsertAction>> _actionBatches = new Dictionary<int, List<AbstractEntityInsertAction>>();

/// <summary>
/// A sorter aiming to group inserts as much as possible for optimizing batching.
Expand All @@ -572,7 +572,7 @@ public InsertActionSorter(ActionQueue actionQueue)
public void Sort()
{
// build the map of entity names that indicate the batch number
foreach (EntityInsertAction action in _actionQueue.insertions)
foreach (var action in _actionQueue.insertions)
{
var entityName = action.EntityName;

Expand All @@ -598,7 +598,7 @@ public void Sort()
}
}

private int GetBatchNumber(EntityInsertAction action, string entityName)
private int GetBatchNumber(AbstractEntityInsertAction action, string entityName)
{
int batchNumber;
if (_latestBatches.TryGetValue(entityName, out batchNumber))
Expand All @@ -620,7 +620,7 @@ private int GetBatchNumber(EntityInsertAction action, string entityName)
return batchNumber;
}

private bool RequireNewBatch(EntityInsertAction action, int latestBatchNumberForType)
private bool RequireNewBatch(AbstractEntityInsertAction action, int latestBatchNumberForType)
{
// This method assumes the original action list is already sorted in order to respect dependencies.
var propertyValues = action.State;
Expand Down Expand Up @@ -658,20 +658,20 @@ private bool RequireNewBatch(EntityInsertAction action, int latestBatchNumberFor
return false;
}

private void AddToBatch(int batchNumber, EntityInsertAction action)
private void AddToBatch(int batchNumber, AbstractEntityInsertAction action)
{
List<EntityInsertAction> actions;
List<AbstractEntityInsertAction> actions;

if (!_actionBatches.TryGetValue(batchNumber, out actions))
{
actions = new List<EntityInsertAction>();
actions = new List<AbstractEntityInsertAction>();
_actionBatches[batchNumber] = actions;
}

actions.Add(action);
}

private void UpdateChildrenDependencies(int batchNumber, EntityInsertAction action)
private void UpdateChildrenDependencies(int batchNumber, AbstractEntityInsertAction action)
{
var propertyValues = action.State;
var propertyTypes = action.Persister.EntityMetamodel?.PropertyTypes;
Expand Down