Skip to content

Commit e08d046

Browse files
author
Tomas Lukac
committed
Add Enum Equals support
1 parent f560f7d commit e08d046

File tree

3 files changed

+49
-0
lines changed

3 files changed

+49
-0
lines changed

src/NHibernate.Test/Linq/FunctionTests.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,16 @@ where item.Discount.Equals(-1)
491491
ObjectDumper.Write(query);
492492
}
493493

494+
[Test]
495+
public void WhereEnumEqual()
496+
{
497+
var query = from item in db.PatientRecords
498+
where item.Gender.Equals(Gender.Female)
499+
select item;
500+
501+
ObjectDumper.Write(query);
502+
}
503+
494504
[Test]
495505
public void WhereEquatableEqual()
496506
{

src/NHibernate/Linq/Functions/DefaultLinqToHqlGeneratorsRegistry.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ public DefaultLinqToHqlGeneratorsRegistry()
3939
this.Merge(new EndsWithGenerator());
4040
this.Merge(new ContainsGenerator());
4141
this.Merge(new EqualsGenerator());
42+
this.Merge(new EnumEqualsGenerator());
4243
this.Merge(new ToUpperGenerator());
4344
this.Merge(new ToLowerGenerator());
4445
this.Merge(new SubStringGenerator());
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Collections.ObjectModel;
4+
using System.Linq.Expressions;
5+
using System.Reflection;
6+
using NHibernate.Hql.Ast;
7+
using NHibernate.Linq.Visitors;
8+
using NHibernate.Util;
9+
10+
namespace NHibernate.Linq.Functions
11+
{
12+
/// <summary>
13+
/// Not using <see cref="EqualsGenerator"/> class as there is very special handling of enums in expressions and equality operator
14+
/// </summary>
15+
public class EnumEqualsGenerator : BaseHqlGeneratorForMethod
16+
{
17+
internal static HashSet<MethodInfo> Methods = new HashSet<MethodInfo>
18+
{
19+
ReflectHelper.GetMethodDefinition<Enum>(x => x.Equals(default(object))),
20+
ReflectHelper.GetMethodDefinition<IEquatable<Enum>>(x => x.Equals(default(Enum)))
21+
};
22+
23+
public EnumEqualsGenerator()
24+
{
25+
SupportedMethods = Methods;
26+
}
27+
28+
public override bool AllowsNullableReturnType(MethodInfo method) => false;
29+
30+
public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor)
31+
{
32+
Expression lhs = arguments.Count == 1 ? targetObject : arguments[0];
33+
Expression rhs = arguments.Count == 1 ? arguments[0] : arguments[1];
34+
35+
return treeBuilder.Equality(visitor.Visit(lhs).AsExpression(), visitor.Visit(rhs).AsExpression());
36+
}
37+
}
38+
}

0 commit comments

Comments
 (0)