Skip to content

Commit b1971a1

Browse files
committed
Fix detecting parameter type for Contains method for Linq provider
1 parent 2011a77 commit b1971a1

File tree

4 files changed

+60
-4
lines changed

4 files changed

+60
-4
lines changed

src/NHibernate.Test/Async/Linq/EnumTests.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,14 @@ public async Task CanQueryOnEnumStoredAsString_Small_1Async()
6262
Assert.AreEqual(expectedCount, query.Count);
6363
}
6464

65+
[Test]
66+
public async Task CanQueryWithContainsOnEnumStoredAsString_Small_1Async()
67+
{
68+
var values = new[] { EnumStoredAsString.Small, EnumStoredAsString.Medium };
69+
var query = await (db.Users.Where(x => values.Contains(x.Enum1)).ToListAsync());
70+
Assert.AreEqual(3, query.Count);
71+
}
72+
6573
[Test]
6674
public async Task ConditionalNavigationPropertyAsync()
6775
{

src/NHibernate.Test/Linq/EnumTests.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,14 @@ public void CanQueryOnEnumStoredAsString(EnumStoredAsString type, int expectedCo
4949
Assert.AreEqual(expectedCount, query.Count);
5050
}
5151

52+
[Test]
53+
public void CanQueryWithContainsOnEnumStoredAsString_Small_1()
54+
{
55+
var values = new[] { EnumStoredAsString.Small, EnumStoredAsString.Medium };
56+
var query = db.Users.Where(x => values.Contains(x.Enum1)).ToList();
57+
Assert.AreEqual(3, query.Count);
58+
}
59+
5260
[Test]
5361
public void ConditionalNavigationProperty()
5462
{

src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,22 @@ public void EqualStringEnumTest()
8484
);
8585
}
8686

87+
[Test]
88+
public void ContainsStringEnumTest()
89+
{
90+
var values = new[] {EnumStoredAsString.Small};
91+
AssertResults(
92+
new Dictionary<string, Predicate<IType>>
93+
{
94+
{"value(NHibernate.DomainModel.Northwind.Entities.EnumStoredAsString[])", o => o is EnumStoredAsStringType}
95+
},
96+
db.Users.Where(o => values.Contains(o.Enum1)),
97+
db.Users.Where(o => values.Contains(o.NullableEnum1.Value)),
98+
db.Users.Where(o => values.Contains(o.Name == o.Name ? o.Enum1 : o.NullableEnum1.Value)),
99+
db.Timesheets.Where(o => o.Users.Any(u => values.Contains(u.Enum1)))
100+
);
101+
}
102+
87103
[Test]
88104
public void EqualStringEnumTestWithFetch()
89105
{

src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
using System.Collections.Generic;
22
using System.Dynamic;
3+
using System.Linq;
34
using System.Linq.Expressions;
45
using NHibernate.Engine;
56
using NHibernate.Param;
67
using NHibernate.Type;
78
using NHibernate.Util;
89
using Remotion.Linq;
10+
using Remotion.Linq.Clauses;
911
using Remotion.Linq.Clauses.Expressions;
12+
using Remotion.Linq.Clauses.ResultOperators;
1013
using Remotion.Linq.Parsing;
1114

1215
namespace NHibernate.Linq.Visitors
@@ -219,14 +222,35 @@ protected override Expression VisitConstant(ConstantExpression node)
219222
return node;
220223
}
221224

222-
public override Expression Visit(Expression node)
225+
protected override Expression VisitSubQuery(SubQueryExpression node)
223226
{
224-
if (node is SubQueryExpression subQueryExpression)
227+
// ReLinq wraps all ResultOperatorExpressionNodeBase into a SubQueryExpression. In case of
228+
// ContainsResultOperator where the constant expression is dislocated from the related expression,
229+
// we have to manually link the related expressions.
230+
var containsOperator = node.QueryModel.ResultOperators.OfType<ContainsResultOperator>().FirstOrDefault();
231+
if (containsOperator != null &&
232+
node.QueryModel.SelectClause.Selector is QuerySourceReferenceExpression querySourceReference &&
233+
querySourceReference.ReferencedQuerySource is MainFromClause mainFromClause &&
234+
mainFromClause.FromExpression is ConstantExpression constantExpression)
225235
{
226-
subQueryExpression.QueryModel.TransformExpressions(Visit);
236+
VisitConstant(constantExpression);
237+
AddRelatedExpression(constantExpression, Unwrap(Visit(containsOperator.Item)));
238+
// Copy all found MemberExpressions to the constant expression
239+
// (e.g. values.Contains(o.Name != o.Name2 ? o.Enum1 : o.Enum2) -> copy o.Enum1 and o.Enum2)
240+
if (RelatedExpressions.TryGetValue(containsOperator.Item, out var set))
241+
{
242+
foreach (var nestedMemberExpression in set)
243+
{
244+
AddRelatedExpression(constantExpression, nestedMemberExpression);
245+
}
246+
}
247+
}
248+
else
249+
{
250+
node.QueryModel.TransformExpressions(Visit);
227251
}
228252

229-
return base.Visit(node);
253+
return node;
230254
}
231255

232256
private void VisitAssign(Expression leftNode, Expression rightNode)

0 commit comments

Comments
 (0)