Skip to content

Commit 94eda20

Browse files
NH-3850 - Fix for polymorphic Linq query results aggregators failures. Causes a possible breaking change of Sum: Sum will no more yields null when there was nothing to sum, but will yield 0.
1 parent 87ec4d8 commit 94eda20

11 files changed

+184
-56
lines changed

src/NHibernate/Linq/EnumerableHelper.cs

Lines changed: 56 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ public static MethodInfo GetMethodDefinition<TSource>(Expression<Action<TSource>
2929
public static MethodInfo GetMethod<TSource>(Expression<Action<TSource>> method)
3030
{
3131
if (method == null)
32-
throw new ArgumentNullException("method");
32+
throw new ArgumentNullException(nameof(method));
3333

3434
return ((MethodCallExpression)method.Body).Method;
3535
}
@@ -54,11 +54,40 @@ public static MethodInfo GetMethodDefinition(Expression<System.Action> method)
5454
public static MethodInfo GetMethod(Expression<System.Action> method)
5555
{
5656
if (method == null)
57-
throw new ArgumentNullException("method");
57+
throw new ArgumentNullException(nameof(method));
5858

5959
return ((MethodCallExpression)method.Body).Method;
6060
}
6161

62+
/// <summary>
63+
/// Get the <see cref="MethodInfo"/> for a public overload of a given method if the method does not match
64+
/// given parameter types, otherwise directly yield the given method.
65+
/// </summary>
66+
/// <param name="method">The method for which finding an overload.</param>
67+
/// <param name="parameterTypes">The arguments types of the overload to get.</param>
68+
/// <returns>The <see cref="MethodInfo"/> of the method.</returns>
69+
/// <remarks>Whenever possible, use GetMethod() instead for performance reasons.</remarks>
70+
public static MethodInfo GetMethodOverload(MethodInfo method, System.Type[] parameterTypes)
71+
{
72+
if (method == null)
73+
throw new ArgumentNullException(nameof(method));
74+
if (parameterTypes == null)
75+
throw new ArgumentNullException(nameof(parameterTypes));
76+
77+
if (ParameterTypesMatch(method.GetParameters(), parameterTypes))
78+
return method;
79+
80+
var overload = method.DeclaringType.GetMethod(method.Name,
81+
(method.IsStatic ? BindingFlags.Static : BindingFlags.Instance) | BindingFlags.Public,
82+
null, parameterTypes, null);
83+
84+
if (overload == null)
85+
throw new InvalidOperationException(
86+
$"No overload found for method '{method.DeclaringType.Name}.{method.Name}' and parameter types '{string.Join(", ", parameterTypes.Select(t => t.Name))}'");
87+
88+
return overload;
89+
}
90+
6291
/// <summary>
6392
/// Gets the field or property to be accessed.
6493
/// </summary>
@@ -70,7 +99,7 @@ public static MemberInfo GetProperty<TSource, TResult>(Expression<Func<TSource,
7099
{
71100
if (property == null)
72101
{
73-
throw new ArgumentNullException("property");
102+
throw new ArgumentNullException(nameof(property));
74103
}
75104
return ((MemberExpression)property.Body).Member;
76105
}
@@ -91,31 +120,8 @@ internal static System.Type GetPropertyOrFieldType(this MemberInfo memberInfo)
91120

92121
return null;
93122
}
94-
}
95-
96-
[Obsolete("Please use ReflectionHelper instead")]
97-
public static class EnumerableHelper
98-
{
99-
public static MethodInfo GetMethod(string name, System.Type[] parameterTypes)
100-
{
101-
return typeof(Enumerable).GetMethods(BindingFlags.Static | BindingFlags.Public)
102-
.Where(m => m.Name == name &&
103-
ParameterTypesMatch(m.GetParameters(), parameterTypes))
104-
.Single();
105-
}
106-
107-
public static MethodInfo GetMethod(string name, System.Type[] parameterTypes, System.Type[] genericTypeParameters)
108-
{
109-
return typeof(Enumerable).GetMethods(BindingFlags.Static | BindingFlags.Public)
110-
.Where(m => m.Name == name &&
111-
m.ContainsGenericParameters &&
112-
m.GetGenericArguments().Count() == genericTypeParameters.Length &&
113-
ParameterTypesMatch(m.GetParameters(), parameterTypes))
114-
.Single()
115-
.MakeGenericMethod(genericTypeParameters);
116-
}
117123

118-
private static bool ParameterTypesMatch(ParameterInfo[] parameters, System.Type[] types)
124+
internal static bool ParameterTypesMatch(ParameterInfo[] parameters, System.Type[] types)
119125
{
120126
if (parameters.Length != types.Length)
121127
{
@@ -141,4 +147,27 @@ private static bool ParameterTypesMatch(ParameterInfo[] parameters, System.Type[
141147
return true;
142148
}
143149
}
150+
151+
[Obsolete("Please use ReflectionHelper instead")]
152+
public static class EnumerableHelper
153+
{
154+
public static MethodInfo GetMethod(string name, System.Type[] parameterTypes)
155+
{
156+
return typeof(Enumerable).GetMethods(BindingFlags.Static | BindingFlags.Public)
157+
.Where(m => m.Name == name &&
158+
ReflectionHelper.ParameterTypesMatch(m.GetParameters(), parameterTypes))
159+
.Single();
160+
}
161+
162+
public static MethodInfo GetMethod(string name, System.Type[] parameterTypes, System.Type[] genericTypeParameters)
163+
{
164+
return typeof(Enumerable).GetMethods(BindingFlags.Static | BindingFlags.Public)
165+
.Where(m => m.Name == name &&
166+
m.ContainsGenericParameters &&
167+
m.GetGenericArguments().Count() == genericTypeParameters.Length &&
168+
ReflectionHelper.ParameterTypesMatch(m.GetParameters(), parameterTypes))
169+
.Single()
170+
.MakeGenericMethod(genericTypeParameters);
171+
}
172+
}
144173
}

src/NHibernate/Linq/ExpressionToHqlTranslationResults.cs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,18 @@ public class ExpressionToHqlTranslationResults
1616
public Delegate PostExecuteTransformer { get; private set; }
1717
public List<Action<IQuery, IDictionary<string, Tuple<object, IType>>>> AdditionalCriteria { get; private set; }
1818

19+
/// <summary>
20+
/// If execute result type does not match expected final result type (implying a post execute transformer
21+
/// will yield expected result type), the intermediate execute type.
22+
/// </summary>
23+
public System.Type ExecuteResultTypeOverride { get; private set; }
24+
1925
public ExpressionToHqlTranslationResults(HqlTreeNode statement,
2026
IList<LambdaExpression> itemTransformers,
2127
IList<LambdaExpression> listTransformers,
2228
IList<LambdaExpression> postExecuteTransformers,
23-
List<Action<IQuery, IDictionary<string, Tuple<object, IType>>>> additionalCriteria)
29+
List<Action<IQuery, IDictionary<string, Tuple<object, IType>>>> additionalCriteria,
30+
System.Type executeResultTypeOverride)
2431
{
2532
Statement = statement;
2633

@@ -35,6 +42,7 @@ public ExpressionToHqlTranslationResults(HqlTreeNode statement,
3542
}
3643

3744
AdditionalCriteria = additionalCriteria;
45+
ExecuteResultTypeOverride = executeResultTypeOverride;
3846
}
3947

4048
private static TDelegate MergeLambdasAndCompile<TDelegate>(IList<LambdaExpression> itemTransformers)

src/NHibernate/Linq/IntermediateHqlTree.cs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@ public HqlTreeNode Root
5050
}
5151
}
5252

53+
/// <summary>
54+
/// If execute result type does not match expected final result type (implying a post execute transformer
55+
/// will yield expected result type), the intermediate execute type.
56+
/// </summary>
57+
public System.Type ExecuteResultTypeOverride { get; set; }
58+
5359
public HqlTreeBuilder TreeBuilder { get; private set; }
5460

5561
public IntermediateHqlTree(bool root)
@@ -62,10 +68,11 @@ public IntermediateHqlTree(bool root)
6268
public ExpressionToHqlTranslationResults GetTranslation()
6369
{
6470
return new ExpressionToHqlTranslationResults(Root,
65-
_itemTransformers,
66-
_listTransformers,
67-
_postExecuteTransformers,
68-
_additionalCriteria);
71+
_itemTransformers,
72+
_listTransformers,
73+
_postExecuteTransformers,
74+
_additionalCriteria,
75+
ExecuteResultTypeOverride);
6976
}
7077

7178
public void AddDistinctRootOperator()

src/NHibernate/Linq/NhLinqExpression.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,10 @@ public IASTNode Translate(ISessionFactoryImplementor sessionFactory, bool filter
6464
var queryModel = NhRelinqQueryParser.Parse(_expression);
6565
var visitorParameters = new VisitorParameters(sessionFactory, _constantToParameterMap, requiredHqlParameters, querySourceNamer);
6666

67-
ExpressionToHqlTranslationResults = QueryModelVisitor.GenerateHqlQuery(queryModel, visitorParameters, true);
67+
ExpressionToHqlTranslationResults = QueryModelVisitor.GenerateHqlQuery(queryModel, visitorParameters, true, ReturnType);
68+
69+
if (ExpressionToHqlTranslationResults.ExecuteResultTypeOverride != null)
70+
Type = ExpressionToHqlTranslationResults.ExecuteResultTypeOverride;
6871

6972
ParameterDescriptors = requiredHqlParameters.AsReadOnly();
7073

src/NHibernate/Linq/Visitors/HqlGeneratorExpressionTreeVisitor.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -542,7 +542,7 @@ protected HqlTreeNode VisitConditionalExpression(ConditionalExpression expressio
542542

543543
protected HqlTreeNode VisitSubQueryExpression(SubQueryExpression expression)
544544
{
545-
ExpressionToHqlTranslationResults query = QueryModelVisitor.GenerateHqlQuery(expression.QueryModel, _parameters, false);
545+
ExpressionToHqlTranslationResults query = QueryModelVisitor.GenerateHqlQuery(expression.QueryModel, _parameters, false, null);
546546
return query.Statement;
547547
}
548548

src/NHibernate/Linq/Visitors/QueryModelVisitor.cs

Lines changed: 59 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
using System;
2-
using System.Collections;
3-
using System.Collections.Generic;
2+
using System.Linq;
43
using System.Linq.Expressions;
4+
using System.Reflection;
55
using NHibernate.Hql.Ast;
66
using NHibernate.Linq.Clauses;
7+
using NHibernate.Linq.Expressions;
78
using NHibernate.Linq.GroupBy;
89
using NHibernate.Linq.GroupJoin;
910
using NHibernate.Linq.NestedSelects;
1011
using NHibernate.Linq.ResultOperators;
1112
using NHibernate.Linq.ReWriters;
1213
using NHibernate.Linq.Visitors.ResultOperatorProcessors;
14+
using NHibernate.Util;
1315
using Remotion.Linq;
1416
using Remotion.Linq.Clauses;
1517
using Remotion.Linq.Clauses.ResultOperators;
@@ -20,7 +22,8 @@ namespace NHibernate.Linq.Visitors
2022
{
2123
public class QueryModelVisitor : QueryModelVisitorBase
2224
{
23-
public static ExpressionToHqlTranslationResults GenerateHqlQuery(QueryModel queryModel, VisitorParameters parameters, bool root)
25+
public static ExpressionToHqlTranslationResults GenerateHqlQuery(QueryModel queryModel, VisitorParameters parameters, bool root,
26+
NhLinqExpressionReturnType? rootReturnType)
2427
{
2528
NestedSelectRewriter.ReWrite(queryModel, parameters.SessionFactory);
2629

@@ -29,7 +32,7 @@ public static ExpressionToHqlTranslationResults GenerateHqlQuery(QueryModel quer
2932

3033
// Merge aggregating result operators (distinct, count, sum etc) into the select clause
3134
MergeAggregatingResultsRewriter.ReWrite(queryModel);
32-
35+
3336
// Swap out non-aggregating group-bys
3437
NonAggregatingGroupByRewriter.ReWrite(queryModel);
3538

@@ -79,7 +82,7 @@ public static ExpressionToHqlTranslationResults GenerateHqlQuery(QueryModel quer
7982
// Identify and name query sources
8083
QuerySourceIdentifier.Visit(parameters.QuerySourceNamer, queryModel);
8184

82-
var visitor = new QueryModelVisitor(parameters, root, queryModel)
85+
var visitor = new QueryModelVisitor(parameters, root, queryModel, rootReturnType)
8386
{
8487
RewrittenOperatorResult = result,
8588
};
@@ -89,7 +92,9 @@ public static ExpressionToHqlTranslationResults GenerateHqlQuery(QueryModel quer
8992
}
9093

9194
private readonly IntermediateHqlTree _hqlTree;
95+
private readonly NhLinqExpressionReturnType? _rootReturnType;
9296
private static readonly ResultOperatorMap ResultOperatorMap;
97+
9398
private bool _serverSide = true;
9499

95100
public VisitorParameters VisitorParameters { get; private set; }
@@ -128,16 +133,64 @@ static QueryModelVisitor()
128133
ResultOperatorMap.Add<CastResultOperator, ProcessCast>();
129134
}
130135

131-
private QueryModelVisitor(VisitorParameters visitorParameters, bool root, QueryModel queryModel)
136+
private QueryModelVisitor(VisitorParameters visitorParameters, bool root, QueryModel queryModel,
137+
NhLinqExpressionReturnType? rootReturnType)
132138
{
133139
VisitorParameters = visitorParameters;
134140
Model = queryModel;
141+
_rootReturnType = root ? rootReturnType : null;
135142
_hqlTree = new IntermediateHqlTree(root);
136143
}
137144

138145
private void Visit()
139146
{
140147
VisitQueryModel(Model);
148+
149+
if (_rootReturnType == NhLinqExpressionReturnType.Scalar && Model.ResultTypeOverride != null)
150+
{
151+
// NH-3850: handle polymorphic scalar results aggregation
152+
switch ((NhExpressionType)Model.SelectClause.Selector.NodeType)
153+
{
154+
case NhExpressionType.Average:
155+
// Polymorphic case complex to handle and not implemented. (HQL query must be reshaped for adding
156+
// additional data to allow a meaningful overall average computation.)
157+
// Leaving it untouched for allowing non polymorphic cases to work.
158+
break;
159+
case NhExpressionType.Max:
160+
AddPostExecuteTransformerForResultAggregate(ReflectionCache.QueryableMethods.MaxDefinition);
161+
break;
162+
case NhExpressionType.Min:
163+
AddPostExecuteTransformerForResultAggregate(ReflectionCache.QueryableMethods.MinDefinition);
164+
break;
165+
case NhExpressionType.Count:
166+
// Count results have to be summed.
167+
case NhExpressionType.Sum:
168+
// Sum has no suitable generic overload, throw in one Sum candidate, then the code using it
169+
// will check and adjust it if needed (AddPostExecuteTransformerForResultAggregate does that).
170+
AddPostExecuteTransformerForResultAggregate(ReflectionCache.QueryableMethods.SumOnInt);
171+
break;
172+
}
173+
}
174+
}
175+
176+
private void AddPostExecuteTransformerForResultAggregate(MethodInfo aggregateMethodTemplate)
177+
{
178+
var resultType = Model.ResultTypeOverride;
179+
var inputListType = typeof(IQueryable<>).MakeGenericType(resultType);
180+
MethodInfo aggregateMethod;
181+
if (aggregateMethodTemplate.IsGenericMethodDefinition)
182+
{
183+
aggregateMethod = aggregateMethodTemplate.MakeGenericMethod(new[] { resultType });
184+
}
185+
else
186+
{
187+
// Ensure we use the right overload.
188+
aggregateMethod = ReflectionHelper.GetMethodOverload(aggregateMethodTemplate, new[] { inputListType });
189+
}
190+
191+
var inputList = Expression.Parameter(inputListType, "inputList");
192+
var aggregateCall = Expression.Call(aggregateMethod, inputList);
193+
_hqlTree.AddPostExecuteTransformer(Expression.Lambda(aggregateCall, inputList));
141194
}
142195

143196
public override void VisitMainFromClause(MainFromClause fromClause, QueryModel queryModel)

src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAggregate.cs

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,22 @@ public void Process(AggregateResultOperator resultOperator, QueryModelVisitor qu
1919
resultOperator.Func.Parameters[0],
2020
paramExpr);
2121

22-
var inputList = Expression.Parameter(typeof(IEnumerable<>).MakeGenericType(typeof(object)), "inputList");
23-
24-
var castToItem = ReflectionCache.EnumerableMethods.CastDefinition.MakeGenericMethod(new[] { inputType });
25-
var castToItemExpr = Expression.Call(castToItem, inputList);
26-
22+
// NH-3850: changed from list transformer (working on IEnumerable<object>) to post execute
23+
// transformer (working on IEnumerable<inputType>) for globally aggregating polymorphic results
24+
// instead of aggregating results for each class separately and yielding only the first.
25+
// If the aggregation relies on ordering, final result will still be wrong due to
26+
// polymorphic results being union-ed without re-ordering. (This is a limitation of all polymorphic
27+
// queries, this is not specific to LINQ provider.)
28+
var inputList = Expression.Parameter(typeof(IEnumerable<>).MakeGenericType(inputType), "inputList");
2729
var aggregate = ReflectionCache.EnumerableMethods.AggregateDefinition.MakeGenericMethod(inputType);
28-
2930
MethodCallExpression call = Expression.Call(
3031
aggregate,
31-
castToItemExpr,
32+
inputList,
3233
accumulatorFunc
3334
);
34-
35-
tree.AddListTransformer(Expression.Lambda(call, inputList));
35+
tree.AddPostExecuteTransformer(Expression.Lambda(call, inputList));
36+
// There is no more a list transformer yielding an IList<resultType>, but this aggregate case
37+
// have inputType = resultType, so no further action is required.
3638
}
3739
}
3840
}

0 commit comments

Comments
 (0)