Skip to content

Commit 6a31261

Browse files
committed
Fix regressions introduced by Conditional and Coalesce expansion (nhibernate#1880)
1 parent fa978cf commit 6a31261

File tree

4 files changed

+128
-24
lines changed

4 files changed

+128
-24
lines changed

src/NHibernate.Test/NHSpecificTest/GH1879/Entity.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,5 +58,8 @@ public class Invoice
5858
public virtual int InvoiceNumber { get; set; }
5959
public virtual Project Project { get; set; }
6060
public virtual Issue Issue { get; set; }
61+
public virtual int Amount { get; set; }
62+
public virtual int? SpecialAmount { get; set; }
63+
public virtual bool Paid { get; set; }
6164
}
6265
}
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Collections.ObjectModel;
4+
using System.Linq;
5+
using System.Linq.Expressions;
6+
using System.Reflection;
7+
using System.Text;
8+
using System.Threading.Tasks;
9+
using NHibernate.Cfg;
10+
using NHibernate.DomainModel.Northwind.Entities;
11+
using NHibernate.Hql.Ast;
12+
using NHibernate.Linq.Functions;
13+
using NHibernate.Linq.Visitors;
14+
using NHibernate.Util;
15+
using NUnit.Framework;
16+
17+
namespace NHibernate.Test.NHSpecificTest.GH1879
18+
{
19+
[TestFixture]
20+
public class ExpansionRegressionTests : GH1879BaseFixture<Invoice>
21+
{
22+
protected override void OnSetUp()
23+
{
24+
using (var session = OpenSession())
25+
using (var transaction = session.BeginTransaction())
26+
{
27+
session.Save(new Invoice { InvoiceNumber = 1, Amount = 10, SpecialAmount = 100, Paid = false });
28+
session.Save(new Invoice { InvoiceNumber = 2, Amount = 10, SpecialAmount = 100, Paid = true });
29+
session.Save(new Invoice { InvoiceNumber = 2, Amount = 10, SpecialAmount = 110, Paid = false });
30+
session.Save(new Invoice { InvoiceNumber = 2, Amount = 10, SpecialAmount = 110, Paid = true });
31+
32+
session.Flush();
33+
transaction.Commit();
34+
}
35+
}
36+
37+
protected override void Configure(Configuration configuration)
38+
{
39+
configuration.LinqToHqlGeneratorsRegistry<TestLinqToHqlGeneratorsRegistry>();
40+
}
41+
42+
private class TestLinqToHqlGeneratorsRegistry : DefaultLinqToHqlGeneratorsRegistry
43+
{
44+
public TestLinqToHqlGeneratorsRegistry()
45+
{
46+
this.Merge(new ObjectEquality());
47+
}
48+
}
49+
50+
private class ObjectEquality : IHqlGeneratorForMethod
51+
{
52+
public HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor)
53+
{
54+
return treeBuilder.Equality(visitor.Visit(targetObject).AsExpression(), visitor.Visit(arguments[0]).AsExpression());
55+
}
56+
57+
public IEnumerable<MethodInfo> SupportedMethods
58+
{
59+
get
60+
{
61+
yield return ReflectHelper.GetMethodDefinition<object>(x => x.Equals(x));
62+
}
63+
}
64+
}
65+
66+
[Test]
67+
public void MethodShouldNotExpandForNonConditionalOrCoalesce()
68+
{
69+
using (var session = OpenSession())
70+
{
71+
Assert.That(session.Query<Invoice>().Count(e => ((object)(e.Amount + e.SpecialAmount)).Equals(110)), Is.EqualTo(2));
72+
}
73+
}
74+
75+
[Test]
76+
public void MethodShouldNotExpandForConditionalWithPropertyAccessor()
77+
{
78+
using (var session = OpenSession())
79+
{
80+
Assert.That(session.Query<Invoice>().Count(e => ((object)(e.Paid ? e.Amount : e.SpecialAmount)).Equals(10)), Is.EqualTo(2));
81+
}
82+
}
83+
84+
[Test]
85+
public void MethodShouldNotExpandForCoalesceWithPropertyAccessor()
86+
{
87+
using (var session = OpenSession())
88+
{
89+
Assert.That(session.Query<Invoice>().Count(e => ((object)(e.SpecialAmount ?? e.Amount)).Equals(100)), Is.EqualTo(2));
90+
}
91+
}
92+
}
93+
}

src/NHibernate.Test/NHSpecificTest/GH1879/FixtureByCode.cs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ protected override HbmMapping GetMappings()
4949
{
5050
rc.Id(x => x.Id, m => m.Generator(Generators.GuidComb));
5151
rc.Property(x => x.InvoiceNumber);
52+
rc.Property(x => x.Amount);
53+
rc.Property(x => x.SpecialAmount);
54+
rc.Property(x => x.Paid);
5255
rc.ManyToOne(x => x.Project, m => m.Column("ProjectId"));
5356
rc.ManyToOne(x => x.Issue, m => m.Column("IssueId"));
5457
});
@@ -108,9 +111,9 @@ protected void AreEqual<TResult>(
108111
{
109112
expectedResult = expectedQuery(session.Query<T>()).ToList();
110113
}
111-
catch
114+
catch (Exception e)
112115
{
113-
Assert.Ignore("Not currently supported query");
116+
Assert.Ignore($"Not currently supported query: {e}");
114117
}
115118

116119
var testResult = actualQuery(session.Query<T>()).ToList();

src/NHibernate/Linq/ReWriters/ConditionalQueryReferenceExpander.cs

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ public void Transform(ResultOperatorBase resultOperator)
7777
protected override Expression VisitMember(MemberExpression node)
7878
{
7979
var result = (MemberExpression) base.VisitMember(node);
80-
if (QueryReferenceCounter.CountReferences(result.Expression) > 1)
80+
if (ShouldRewrite(result.Expression))
8181
{
8282
return ConditionalQueryReferenceMemberExpressionRewriter.Rewrite(result.Expression, node);
8383
}
@@ -90,39 +90,44 @@ protected override Expression VisitMethodCall(MethodCallExpression node)
9090
var isExtension = node.Method.GetCustomAttributes<ExtensionAttribute>().Any();
9191
var methodObject = isExtension ? node.Arguments[0] : node.Object;
9292

93-
if (methodObject != null && QueryReferenceCounter.CountReferences(methodObject) > 1)
93+
if (ShouldRewrite(methodObject))
9494
{
9595
return ConditionalQueryReferenceMethodCallExpressionRewriter.Rewrite(methodObject, node);
9696
}
9797
return result;
9898
}
99-
}
100-
101-
private class QueryReferenceCounter : RelinqExpressionVisitor
102-
{
103-
private readonly System.Type _queryType;
104-
private int _queryReferenceCount;
10599

106-
private QueryReferenceCounter(System.Type queryType)
100+
private bool ShouldRewrite(Expression expr, System.Type queryType = null)
107101
{
108-
_queryType = queryType;
109-
}
102+
if (expr == null)
103+
{
104+
return false;
105+
}
106+
107+
// Strip Converts
108+
while (expr.NodeType == ExpressionType.Convert || expr.NodeType == ExpressionType.ConvertChecked)
109+
{
110+
expr = ((UnaryExpression)expr).Operand;
111+
}
110112

111-
protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression expression)
112-
{
113-
if (_queryType.IsAssignableFrom(expression.Type))
113+
if (expr is QuerySourceReferenceExpression && queryType?.IsAssignableFrom(expr.Type) == true)
114114
{
115-
_queryReferenceCount++;
115+
return true;
116116
}
117117

118-
return base.VisitQuerySourceReference(expression);
119-
}
118+
queryType = queryType ?? expr.Type;
120119

121-
public static int CountReferences(Expression node)
122-
{
123-
var visitor = new QueryReferenceCounter(node.Type);
124-
visitor.Visit(node);
125-
return visitor._queryReferenceCount;
120+
if (expr.NodeType == ExpressionType.Coalesce && expr is BinaryExpression coalesce)
121+
{
122+
return ShouldRewrite(coalesce.Left, queryType) && ShouldRewrite(coalesce.Right, queryType);
123+
}
124+
125+
if (expr.NodeType == ExpressionType.Conditional && expr is ConditionalExpression conditional)
126+
{
127+
return ShouldRewrite(conditional.IfFalse, queryType) && ShouldRewrite(conditional.IfTrue, queryType);
128+
}
129+
130+
return false;
126131
}
127132
}
128133

0 commit comments

Comments
 (0)