Skip to content

Fix join on subclass columns #2680

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions src/NHibernate.Test/Async/NHSpecificTest/NH1747/Fixture.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
//------------------------------------------------------------------------------


using System.Linq;
using NHibernate.Linq;
using NUnit.Framework;

namespace NHibernate.Test.NHSpecificTest.NH1747
Expand Down Expand Up @@ -51,5 +53,25 @@ public async Task TraversingBagToJoinChildElementShouldWorkAsync()
Assert.AreEqual(1, paymentBatch.Payments.Count);
}
}

[Test]
public async Task TraversingBagToJoinChildElementShouldWorkLinqFetchAsync()
{
using (ISession session = OpenSession())
{
var paymentBatch = await (session.Query<PaymentBatch>().Fetch(x => x.Payments).SingleOrDefaultAsync());
Assert.AreEqual(1, paymentBatch.Payments.Count);
}
}

[Test]
public async Task TraversingBagToJoinChildElementShouldWorkQueryOverFetchAsync()
{
using (ISession session = OpenSession())
{
var paymentBatch = await (session.Query<PaymentBatch>().Fetch(x => x.Payments).SingleOrDefaultAsync());
Assert.AreEqual(1, paymentBatch.Payments.Count);
}
}
}
}
2 changes: 0 additions & 2 deletions src/NHibernate.Test/Async/NHSpecificTest/NH2174/Fixture.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ protected override void OnTearDown()
}
}

[KnownBug("Not fixed yet")]
[Test]
public async Task LinqFetchAsync()
{
Expand All @@ -54,7 +53,6 @@ public async Task LinqFetchAsync()
}
}

[KnownBug("Not fixed yet")]
[Test]
public async Task QueryOverFetchAsync()
{
Expand Down
24 changes: 23 additions & 1 deletion src/NHibernate.Test/NHSpecificTest/NH1747/Fixture.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using NUnit.Framework;
using System.Linq;
using NHibernate.Linq;
using NUnit.Framework;

namespace NHibernate.Test.NHSpecificTest.NH1747
{
Expand Down Expand Up @@ -52,5 +54,25 @@ public void TraversingBagToJoinChildElementShouldWork()
Assert.AreEqual(1, paymentBatch.Payments.Count);
}
}

[Test]
public void TraversingBagToJoinChildElementShouldWorkLinqFetch()
{
using (ISession session = OpenSession())
{
var paymentBatch = session.Query<PaymentBatch>().Fetch(x => x.Payments).SingleOrDefault();
Assert.AreEqual(1, paymentBatch.Payments.Count);
}
}

[Test]
public void TraversingBagToJoinChildElementShouldWorkQueryOverFetch()
{
using (ISession session = OpenSession())
{
var paymentBatch = session.Query<PaymentBatch>().Fetch(x => x.Payments).SingleOrDefault();
Assert.AreEqual(1, paymentBatch.Payments.Count);
}
}
}
}
2 changes: 0 additions & 2 deletions src/NHibernate.Test/NHSpecificTest/NH2174/Fixture.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ protected override void OnTearDown()
}
}

[KnownBug("Not fixed yet")]
[Test]
public void LinqFetch()
{
Expand All @@ -43,7 +42,6 @@ public void LinqFetch()
}
}

[KnownBug("Not fixed yet")]
[Test]
public void QueryOverFetch()
{
Expand Down
1 change: 1 addition & 0 deletions src/NHibernate/Engine/IJoin.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ internal interface IJoin
{
IJoinable Joinable { get; }
string[] LHSColumns { get; }
string[] RHSColumns { get; }
string Alias { get; }
IAssociationType AssociationType { get; }
JoinType JoinType { get; }
Expand Down
21 changes: 12 additions & 9 deletions src/NHibernate/Engine/JoinHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,20 @@ public static ILhsAssociationTypeSqlInfo GetIdLhsSqlInfo(string alias, IOuterJoi
/// be used in the join
/// </summary>
public static string[] GetRHSColumnNames(IAssociationType type, ISessionFactoryImplementor factory)
{
return GetRHSColumnNames(type.GetAssociatedJoinable(factory), type);
}

/// <summary>
/// Get the columns of the associated table which are to
/// be used in the join
/// </summary>
public static string[] GetRHSColumnNames(IJoinable joinable, IAssociationType type)
{
string uniqueKeyPropertyName = type.RHSUniqueKeyPropertyName;
IJoinable joinable = type.GetAssociatedJoinable(factory);
if (uniqueKeyPropertyName == null)
{
return joinable.KeyColumnNames;
}
else
{
return ((IOuterJoinLoadable)joinable).GetPropertyColumnNames(uniqueKeyPropertyName);
}
return uniqueKeyPropertyName == null
? joinable.KeyColumnNames
: ((IOuterJoinLoadable) joinable).GetPropertyColumnNames(uniqueKeyPropertyName);
}
}

Expand Down
11 changes: 10 additions & 1 deletion src/NHibernate/Engine/JoinSequence.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ private sealed class Join : IJoin
private readonly JoinType joinType;
private readonly string alias;
private readonly string[] lhsColumns;
private readonly string[] rhsColumns;

public Join(ISessionFactoryImplementor factory, IAssociationType associationType, string alias, JoinType joinType,
string[] lhsColumns)
Expand All @@ -56,6 +57,9 @@ public Join(ISessionFactoryImplementor factory, IAssociationType associationType
this.alias = alias;
this.joinType = joinType;
this.lhsColumns = lhsColumns;
this.rhsColumns = lhsColumns.Length > 0
? JoinHelper.GetRHSColumnNames(joinable, associationType)
: Array.Empty<string>();
}

public string Alias
Expand Down Expand Up @@ -83,6 +87,11 @@ public string[] LHSColumns
get { return lhsColumns; }
}

public string[] RHSColumns
{
get { return rhsColumns; }
}

public override string ToString()
{
return joinable.ToString() + '[' + alias + ']';
Expand Down Expand Up @@ -195,7 +204,7 @@ internal JoinFragment ToJoinFragment(
join.Joinable.TableName,
join.Alias,
join.LHSColumns,
JoinHelper.GetRHSColumnNames(join.AssociationType, factory),
join.RHSColumns,
join.JoinType,
withClauses[i]
);
Expand Down
59 changes: 40 additions & 19 deletions src/NHibernate/Engine/TableGroupJoinHelper.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using NHibernate.Persister.Collection;
using NHibernate.Persister.Entity;
using NHibernate.SqlCommand;

Expand Down Expand Up @@ -38,7 +39,7 @@ internal static bool ProcessAsTableGroupJoin(IReadOnlyList<IJoin> tableGroupJoin
join.Joinable.TableName,
join.Alias,
join.LHSColumns,
JoinHelper.GetRHSColumnNames(join.AssociationType, sessionFactoryImplementor),
join.RHSColumns,
join.JoinType,
SqlString.Empty);

Expand All @@ -51,32 +52,37 @@ internal static bool ProcessAsTableGroupJoin(IReadOnlyList<IJoin> tableGroupJoin
join.Joinable.WhereJoinFragment(join.Alias, innerJoin, include));
}

var withClause = GetTableGroupJoinWithClause(withClauseFragments, first, sessionFactoryImplementor);
var withClause = GetTableGroupJoinWithClause(withClauseFragments, first);
joinFragment.AddFromFragmentString(withClause);
return true;
}

// detect cases when withClause is used on multiple tables or when join keys depend on subclass columns
private static bool NeedsTableGroupJoin(IReadOnlyList<IJoin> joins, SqlString[] withClauseFragments, bool includeSubclasses)
{
// If we don't have a with clause, we don't need a table group join
if (withClauseFragments.All(x => SqlStringHelper.IsEmpty(x)))
{
return false;
}
bool hasWithClause = withClauseFragments.Any(x => SqlStringHelper.IsNotEmpty(x));

// If we only have one join, a table group join is only necessary if subclass columns are used in the with clause
if (joins.Count == 1)
//NH Specific: No alias processing (see hibernate JoinSequence.NeedsTableGroupJoin)
if (joins.Count > 1 && hasWithClause)
return true;

foreach (var join in joins)
{
return joins[0].Joinable is AbstractEntityPersister persister && persister.HasSubclassJoins(includeSubclasses);
//NH Specific: No alias processing
//return isSubclassAliasDereferenced( joins[ 0], withClauseFragment );
var entityPersister = GetEntityPersister(join.Joinable);
if (entityPersister?.HasSubclassJoins(includeSubclasses) != true)
continue;

if (hasWithClause)
return true;

if (entityPersister.ColumnsDependOnSubclassJoins(join.RHSColumns))
return true;
}

//NH Specific: No alias processing (see hibernate JoinSequence.NeedsTableGroupJoin)
return true;
return false;
}

private static SqlString GetTableGroupJoinWithClause(SqlString[] withClauseFragments, IJoin first, ISessionFactoryImplementor factory)
private static SqlString GetTableGroupJoinWithClause(SqlString[] withClauseFragments, IJoin first)
{
SqlStringBuilder fromFragment = new SqlStringBuilder();
fromFragment.Add(")").Add(" on ");
Expand All @@ -85,12 +91,18 @@ private static SqlString GetTableGroupJoinWithClause(SqlString[] withClauseFragm
var isAssociationJoin = lhsColumns.Length > 0;
if (isAssociationJoin)
{
var entityPersister = GetEntityPersister(first.Joinable);
string rhsAlias = first.Alias;
string[] rhsColumns = JoinHelper.GetRHSColumnNames(first.AssociationType, factory);
fromFragment.Add(lhsColumns[0]).Add("=").Add(rhsAlias).Add(".").Add(rhsColumns[0]);
for (int j = 1; j < lhsColumns.Length; j++)
string[] rhsColumns = first.RHSColumns;
for (int j = 0; j < lhsColumns.Length; j++)
{
fromFragment.Add(" and ").Add(lhsColumns[j]).Add("=").Add(rhsAlias).Add(".").Add(rhsColumns[j]);
fromFragment.Add(lhsColumns[j])
.Add("=")
.Add(entityPersister?.GenerateTableAliasForColumn(rhsAlias, rhsColumns[j]) ?? rhsAlias)
.Add(".")
.Add(rhsColumns[j]);
if (j != lhsColumns.Length - 1)
fromFragment.Add(" and ");
}
}

Expand All @@ -99,6 +111,15 @@ private static SqlString GetTableGroupJoinWithClause(SqlString[] withClauseFragm
return fromFragment.ToSqlString();
}

private static AbstractEntityPersister GetEntityPersister(IJoinable joinable)
{
if (!joinable.IsCollection)
return joinable as AbstractEntityPersister;

var collection = (IQueryableCollection) joinable;
return collection.ElementType.IsEntityType ? collection.ElementPersister as AbstractEntityPersister : null;
}

private static void AppendWithClause(SqlStringBuilder fromFragment, bool hasConditions, SqlString[] withClauseFragments)
{
for (var i = 0; i < withClauseFragments.Length; i++)
Expand Down
5 changes: 3 additions & 2 deletions src/NHibernate/Loader/JoinWalker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -751,8 +751,9 @@ protected virtual bool IsDuplicateAssociation(string lhsTable, string[] lhsColum
}
else
{
foreignKeyTable = type.GetAssociatedJoinable(Factory).TableName;
foreignKeyColumns = JoinHelper.GetRHSColumnNames(type, Factory);
var joinable = type.GetAssociatedJoinable(Factory);
foreignKeyTable = joinable.TableName;
foreignKeyColumns = JoinHelper.GetRHSColumnNames(joinable, type);
}

return IsDuplicateAssociation(foreignKeyTable, foreignKeyColumns);
Expand Down
3 changes: 2 additions & 1 deletion src/NHibernate/Loader/OuterJoinableAssociation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public OuterJoinableAssociation(IAssociationType joinableType, String lhsAlias,
this.rhsAlias = rhsAlias;
this.joinType = joinType;
joinable = joinableType.GetAssociatedJoinable(factory);
rhsColumns = JoinHelper.GetRHSColumnNames(joinableType, factory);
rhsColumns = JoinHelper.GetRHSColumnNames(joinable, joinableType);
on = new SqlString(joinableType.GetOnCondition(rhsAlias, factory, enabledFilters));
if (SqlStringHelper.IsNotEmpty(withClause))
on = on.Append(" and ( ", withClause, " )");
Expand Down Expand Up @@ -105,6 +105,7 @@ public SelectMode SelectMode
string[] IJoin.LHSColumns => lhsColumns;
string IJoin.Alias => RHSAlias;
IAssociationType IJoin.AssociationType => JoinableType;
string[] IJoin.RHSColumns => rhsColumns;

public int GetOwner(IList<OuterJoinableAssociation> associations)
{
Expand Down
28 changes: 22 additions & 6 deletions src/NHibernate/Persister/Entity/AbstractEntityPersister.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2138,14 +2138,20 @@ public virtual Declarer GetSubclassPropertyDeclarer(string propertyPath)

public virtual string GenerateTableAliasForColumn(string rootAlias, string column)
{
int propertyIndex = Array.IndexOf(SubclassColumnClosure, column);
return GenerateTableAlias(rootAlias, GetColumnTableNumber(column));
}

private int GetColumnTableNumber(string column)
{
if (SubclassTableSpan == 1)
return 0;

int i = Array.IndexOf(SubclassColumnClosure, column);

// The check for KeyColumnNames was added to fix NH-2491
if (propertyIndex < 0 || Array.IndexOf(KeyColumnNames, column) >= 0)
{
return rootAlias;
}
return GenerateTableAlias(rootAlias, SubclassColumnTableNumberClosure[propertyIndex]);
return i < 0 || Array.IndexOf(KeyColumnNames, column) >= 0
? 0
: SubclassColumnTableNumberClosure[i];
}

public string GenerateTableAlias(string rootAlias, int tableNumber)
Expand Down Expand Up @@ -3796,6 +3802,16 @@ private JoinFragment CreateJoin(string name, bool innerjoin, bool includeSubclas
return join;
}

internal bool ColumnsDependOnSubclassJoins(string[] columns)
{
foreach (var column in columns)
{
if (GetColumnTableNumber(column) > 0)
return true;
}
return false;
}

internal bool HasSubclassJoins(bool includeSubclasses)
{
if (SubclassTableSpan == 1)
Expand Down