Skip to content

Commit 98ea1c9

Browse files
NH-3850 - Fix for polymorphic Linq query results aggregators failures.
1 parent 2006aee commit 98ea1c9

12 files changed

+277
-61
lines changed

src/NHibernate/Linq/EnumerableHelper.cs

Lines changed: 56 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ public static MethodInfo GetMethodDefinition<TSource>(Expression<Action<TSource>
2929
public static MethodInfo GetMethod<TSource>(Expression<Action<TSource>> method)
3030
{
3131
if (method == null)
32-
throw new ArgumentNullException("method");
32+
throw new ArgumentNullException(nameof(method));
3333

3434
return ((MethodCallExpression)method.Body).Method;
3535
}
@@ -54,11 +54,40 @@ public static MethodInfo GetMethodDefinition(Expression<System.Action> method)
5454
public static MethodInfo GetMethod(Expression<System.Action> method)
5555
{
5656
if (method == null)
57-
throw new ArgumentNullException("method");
57+
throw new ArgumentNullException(nameof(method));
5858

5959
return ((MethodCallExpression)method.Body).Method;
6060
}
6161

62+
/// <summary>
63+
/// Get the <see cref="MethodInfo"/> for a public overload of a given method if the method does not match
64+
/// given parameter types, otherwise directly yield the given method.
65+
/// </summary>
66+
/// <param name="method">The method for which finding an overload.</param>
67+
/// <param name="parameterTypes">The arguments types of the overload to get.</param>
68+
/// <returns>The <see cref="MethodInfo"/> of the method.</returns>
69+
/// <remarks>Whenever possible, use GetMethod() instead for performance reasons.</remarks>
70+
public static MethodInfo GetMethodOverload(MethodInfo method, System.Type[] parameterTypes)
71+
{
72+
if (method == null)
73+
throw new ArgumentNullException(nameof(method));
74+
if (parameterTypes == null)
75+
throw new ArgumentNullException(nameof(parameterTypes));
76+
77+
if (ParameterTypesMatch(method.GetParameters(), parameterTypes))
78+
return method;
79+
80+
var overload = method.DeclaringType.GetMethod(method.Name,
81+
(method.IsStatic ? BindingFlags.Static : BindingFlags.Instance) | BindingFlags.Public,
82+
null, parameterTypes, null);
83+
84+
if (overload == null)
85+
throw new InvalidOperationException(
86+
$"No overload found for method '{method.DeclaringType.Name}.{method.Name}' and parameter types '{string.Join(", ", parameterTypes.Select(t => t.Name))}'");
87+
88+
return overload;
89+
}
90+
6291
/// <summary>
6392
/// Gets the field or property to be accessed.
6493
/// </summary>
@@ -70,7 +99,7 @@ public static MemberInfo GetProperty<TSource, TResult>(Expression<Func<TSource,
7099
{
71100
if (property == null)
72101
{
73-
throw new ArgumentNullException("property");
102+
throw new ArgumentNullException(nameof(property));
74103
}
75104
return ((MemberExpression)property.Body).Member;
76105
}
@@ -91,31 +120,8 @@ internal static System.Type GetPropertyOrFieldType(this MemberInfo memberInfo)
91120

92121
return null;
93122
}
94-
}
95-
96-
[Obsolete("Please use ReflectionHelper instead")]
97-
public static class EnumerableHelper
98-
{
99-
public static MethodInfo GetMethod(string name, System.Type[] parameterTypes)
100-
{
101-
return typeof(Enumerable).GetMethods(BindingFlags.Static | BindingFlags.Public)
102-
.Where(m => m.Name == name &&
103-
ParameterTypesMatch(m.GetParameters(), parameterTypes))
104-
.Single();
105-
}
106-
107-
public static MethodInfo GetMethod(string name, System.Type[] parameterTypes, System.Type[] genericTypeParameters)
108-
{
109-
return typeof(Enumerable).GetMethods(BindingFlags.Static | BindingFlags.Public)
110-
.Where(m => m.Name == name &&
111-
m.ContainsGenericParameters &&
112-
m.GetGenericArguments().Count() == genericTypeParameters.Length &&
113-
ParameterTypesMatch(m.GetParameters(), parameterTypes))
114-
.Single()
115-
.MakeGenericMethod(genericTypeParameters);
116-
}
117123

118-
private static bool ParameterTypesMatch(ParameterInfo[] parameters, System.Type[] types)
124+
internal static bool ParameterTypesMatch(ParameterInfo[] parameters, System.Type[] types)
119125
{
120126
if (parameters.Length != types.Length)
121127
{
@@ -141,4 +147,27 @@ private static bool ParameterTypesMatch(ParameterInfo[] parameters, System.Type[
141147
return true;
142148
}
143149
}
150+
151+
[Obsolete("Please use ReflectionHelper instead")]
152+
public static class EnumerableHelper
153+
{
154+
public static MethodInfo GetMethod(string name, System.Type[] parameterTypes)
155+
{
156+
return typeof(Enumerable).GetMethods(BindingFlags.Static | BindingFlags.Public)
157+
.Where(m => m.Name == name &&
158+
ReflectionHelper.ParameterTypesMatch(m.GetParameters(), parameterTypes))
159+
.Single();
160+
}
161+
162+
public static MethodInfo GetMethod(string name, System.Type[] parameterTypes, System.Type[] genericTypeParameters)
163+
{
164+
return typeof(Enumerable).GetMethods(BindingFlags.Static | BindingFlags.Public)
165+
.Where(m => m.Name == name &&
166+
m.ContainsGenericParameters &&
167+
m.GetGenericArguments().Count() == genericTypeParameters.Length &&
168+
ReflectionHelper.ParameterTypesMatch(m.GetParameters(), parameterTypes))
169+
.Single()
170+
.MakeGenericMethod(genericTypeParameters);
171+
}
172+
}
144173
}

src/NHibernate/Linq/ExpressionToHqlTranslationResults.cs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,18 @@ public class ExpressionToHqlTranslationResults
1616
public Delegate PostExecuteTransformer { get; private set; }
1717
public List<Action<IQuery, IDictionary<string, Tuple<object, IType>>>> AdditionalCriteria { get; private set; }
1818

19+
/// <summary>
20+
/// If execute result type does not match expected final result type (implying a post execute transformer
21+
/// will yield expected result type), the intermediate execute type.
22+
/// </summary>
23+
public System.Type ExecuteResultTypeOverride { get; private set; }
24+
1925
public ExpressionToHqlTranslationResults(HqlTreeNode statement,
2026
IList<LambdaExpression> itemTransformers,
2127
IList<LambdaExpression> listTransformers,
2228
IList<LambdaExpression> postExecuteTransformers,
23-
List<Action<IQuery, IDictionary<string, Tuple<object, IType>>>> additionalCriteria)
29+
List<Action<IQuery, IDictionary<string, Tuple<object, IType>>>> additionalCriteria,
30+
System.Type executeResultTypeOverride)
2431
{
2532
Statement = statement;
2633

@@ -35,6 +42,7 @@ public ExpressionToHqlTranslationResults(HqlTreeNode statement,
3542
}
3643

3744
AdditionalCriteria = additionalCriteria;
45+
ExecuteResultTypeOverride = executeResultTypeOverride;
3846
}
3947

4048
private static TDelegate MergeLambdasAndCompile<TDelegate>(IList<LambdaExpression> itemTransformers)

src/NHibernate/Linq/Functions/QueryableGenerator.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,10 @@ public class AllHqlGenerator : BaseHqlGeneratorForMethod
5353
public AllHqlGenerator()
5454
{
5555
SupportedMethods = new[]
56-
{
57-
ReflectionHelper.GetMethodDefinition(() => Queryable.All<object>(null, null)),
58-
ReflectionHelper.GetMethodDefinition(() => Enumerable.All<object>(null, null))
59-
};
56+
{
57+
ReflectionHelper.GetMethodDefinition(() => Queryable.All<object>(null, null)),
58+
ReflectionCache.EnumerableMethods.AllDefinition
59+
};
6060
}
6161

6262
public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor)

src/NHibernate/Linq/IntermediateHqlTree.cs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@ public HqlTreeNode Root
5050
}
5151
}
5252

53+
/// <summary>
54+
/// If execute result type does not match expected final result type (implying a post execute transformer
55+
/// will yield expected result type), the intermediate execute type.
56+
/// </summary>
57+
public System.Type ExecuteResultTypeOverride { get; set; }
58+
5359
public HqlTreeBuilder TreeBuilder { get; private set; }
5460

5561
public IntermediateHqlTree(bool root)
@@ -62,10 +68,11 @@ public IntermediateHqlTree(bool root)
6268
public ExpressionToHqlTranslationResults GetTranslation()
6369
{
6470
return new ExpressionToHqlTranslationResults(Root,
65-
_itemTransformers,
66-
_listTransformers,
67-
_postExecuteTransformers,
68-
_additionalCriteria);
71+
_itemTransformers,
72+
_listTransformers,
73+
_postExecuteTransformers,
74+
_additionalCriteria,
75+
ExecuteResultTypeOverride);
6976
}
7077

7178
public void AddDistinctRootOperator()

src/NHibernate/Linq/NhLinqExpression.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,10 @@ public IASTNode Translate(ISessionFactoryImplementor sessionFactory, bool filter
6464
var queryModel = NhRelinqQueryParser.Parse(_expression);
6565
var visitorParameters = new VisitorParameters(sessionFactory, _constantToParameterMap, requiredHqlParameters, querySourceNamer);
6666

67-
ExpressionToHqlTranslationResults = QueryModelVisitor.GenerateHqlQuery(queryModel, visitorParameters, true);
67+
ExpressionToHqlTranslationResults = QueryModelVisitor.GenerateHqlQuery(queryModel, visitorParameters, true, ReturnType);
68+
69+
if (ExpressionToHqlTranslationResults.ExecuteResultTypeOverride != null)
70+
Type = ExpressionToHqlTranslationResults.ExecuteResultTypeOverride;
6871

6972
ParameterDescriptors = requiredHqlParameters.AsReadOnly();
7073

@@ -75,6 +78,8 @@ internal void CopyExpressionTranslation(NhLinqExpression other)
7578
{
7679
ExpressionToHqlTranslationResults = other.ExpressionToHqlTranslationResults;
7780
ParameterDescriptors = other.ParameterDescriptors;
81+
// Type could have been overridden by translation.
82+
Type = other.Type;
7883
}
7984
}
8085
}

src/NHibernate/Linq/Visitors/HqlGeneratorExpressionTreeVisitor.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -542,7 +542,7 @@ protected HqlTreeNode VisitConditionalExpression(ConditionalExpression expressio
542542

543543
protected HqlTreeNode VisitSubQueryExpression(SubQueryExpression expression)
544544
{
545-
ExpressionToHqlTranslationResults query = QueryModelVisitor.GenerateHqlQuery(expression.QueryModel, _parameters, false);
545+
ExpressionToHqlTranslationResults query = QueryModelVisitor.GenerateHqlQuery(expression.QueryModel, _parameters, false, null);
546546
return query.Statement;
547547
}
548548

0 commit comments

Comments
 (0)