Skip to content

Commit 7695e5c

Browse files
author
Tomas Lukac
committed
Add enum HasFlag method support #829
1 parent f560f7d commit 7695e5c

File tree

7 files changed

+162
-3
lines changed

7 files changed

+162
-3
lines changed
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
using System.Linq;
2+
using NUnit.Framework;
3+
4+
namespace NHibernate.Test.NHSpecificTest.GH3365
5+
{
6+
[TestFixture]
7+
public class Fixture : BugTestCase
8+
{
9+
protected override void OnSetUp()
10+
{
11+
using (var session = OpenSession())
12+
using (var transaction = session.BeginTransaction())
13+
{
14+
15+
var e1 = new Parent { Type = TestEnum.A | TestEnum.C };
16+
session.Save(e1);
17+
18+
var e2 = new Child { Type = TestEnum.D, Parent = e1 };
19+
session.Save(e2);
20+
21+
var e3 = new Child { Type = TestEnum.C, Parent = e1 };
22+
session.Save(e3);
23+
24+
session.Flush();
25+
transaction.Commit();
26+
}
27+
}
28+
29+
[Test]
30+
public void SelectClass()
31+
{
32+
using (var session = OpenSession())
33+
using (session.BeginTransaction())
34+
{
35+
var resultFound = session.Query<Parent>().Where(x => x.Type.HasFlag(TestEnum.A)).FirstOrDefault();
36+
37+
var resultNotFound = session.Query<Parent>().Where(x => x.Type.HasFlag(TestEnum.D)).FirstOrDefault();
38+
39+
Assert.That(resultFound, Is.Not.Null);
40+
Assert.That(resultNotFound, Is.Null);
41+
}
42+
}
43+
44+
[Test]
45+
public void SelectChildClassContainedInParent()
46+
{
47+
using (var session = OpenSession())
48+
using (session.BeginTransaction())
49+
{
50+
var result = session.Query<Child>().Where(x => x.Parent.Type.HasFlag(x.Type)).FirstOrDefault();
51+
52+
Assert.That(result, Is.Not.Null);
53+
}
54+
}
55+
56+
protected override void OnTearDown()
57+
{
58+
base.OnTearDown();
59+
using (ISession session = this.OpenSession())
60+
{
61+
foreach (var entity in new[] { nameof(Child), nameof(Parent) })
62+
{
63+
session.Delete($"from {entity}");
64+
session.Flush();
65+
}
66+
}
67+
}
68+
}
69+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
<?xml version="1.0" encoding="utf-8" ?>
2+
<hibernate-mapping xmlns="urn:nhibernate-mapping-2.2" assembly="NHibernate.Test" namespace="NHibernate.Test.NHSpecificTest.GH3365">
3+
<class name="Parent" table="parent">
4+
<id name="Id" generator="guid.comb" />
5+
<property name="Type" type="NHibernate.Type.EnumType`1[[NHibernate.Test.NHSpecificTest.GH3365.TestEnum, NHibernate.Test]], NHibernate"/>
6+
<bag name="Children" lazy="true">
7+
<key column="ParentId" />
8+
<one-to-many class="Child" />
9+
</bag>
10+
</class>
11+
12+
<class name="Child" table="Child">
13+
<id name="Id" generator="guid.comb" />
14+
<property name="Type" type="NHibernate.Type.EnumType`1[[NHibernate.Test.NHSpecificTest.GH3365.TestEnum, NHibernate.Test]], NHibernate"/>
15+
<many-to-one name="Parent" column="ParentId" not-null="true" />
16+
</class>
17+
</hibernate-mapping>
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
using System;
2+
using System.Collections.Generic;
3+
4+
namespace NHibernate.Test.NHSpecificTest.GH3365
5+
{
6+
public class Parent
7+
{
8+
public virtual Guid Id { get; set; }
9+
10+
public virtual TestEnum Type { get; set; }
11+
12+
public virtual IList<Child> Children { get; set; } = new List<Child>();
13+
}
14+
15+
public class Child
16+
{
17+
public virtual Guid Id { get; set; }
18+
19+
public virtual TestEnum Type { get; set; }
20+
21+
public virtual Parent Parent { get; set; }
22+
}
23+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
using System;
2+
3+
namespace NHibernate.Test.NHSpecificTest.GH3365
4+
{
5+
[Flags]
6+
public enum TestEnum
7+
{
8+
A = 1 << 0,
9+
B = 1 << 1,
10+
C = 1 << 2,
11+
D = 1 << 3
12+
}
13+
}

src/NHibernate/Linq/ExpressionTransformers/RemoveRedundantCast.cs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1+
using System;
12
using System.Linq.Expressions;
23
using NHibernate.Util;
34
using Remotion.Linq.Parsing.ExpressionVisitors.Transformation;
45

56
namespace NHibernate.Linq.ExpressionTransformers
67
{
78
/// <summary>
8-
/// Remove redundant casts to the same type or to superclass (upcast) in <see cref="ExpressionType.Convert"/>, <see cref=" ExpressionType.ConvertChecked"/>
9-
/// and <see cref="ExpressionType.TypeAs"/> <see cref="UnaryExpression"/>s
9+
/// Remove redundant casts to the same type or to superclass (upcast) in <see cref="ExpressionType.Convert"/>, <see cref=" ExpressionType.ConvertChecked"/>
10+
/// and <see cref="ExpressionType.TypeAs"/> <see cref="UnaryExpression"/>s
1011
/// </summary>
1112
public class RemoveRedundantCast : IExpressionTransformer<UnaryExpression>
1213
{
@@ -20,6 +21,7 @@ public class RemoveRedundantCast : IExpressionTransformer<UnaryExpression>
2021
public Expression Transform(UnaryExpression expression)
2122
{
2223
if (expression.Type != typeof(object) &&
24+
expression.Type != typeof(Enum) &&
2325
expression.Type.IsAssignableFrom(expression.Operand.Type) &&
2426
expression.Method == null &&
2527
!expression.IsLiftedToNull)
@@ -29,7 +31,7 @@ public Expression Transform(UnaryExpression expression)
2931

3032
// Reduce double casting (e.g. (long?)(long)3 => (long?)3)
3133
if (expression.Operand.NodeType == ExpressionType.Convert &&
32-
expression.Type.UnwrapIfNullable() == expression.Operand.Type)
34+
expression.Type.UnwrapIfNullable() == expression.Operand.Type)
3335
{
3436
return Expression.Convert(((UnaryExpression) expression.Operand).Operand, expression.Type);
3537
}

src/NHibernate/Linq/Functions/DefaultLinqToHqlGeneratorsRegistry.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ public DefaultLinqToHqlGeneratorsRegistry()
2525
RegisterGenerator(new ToStringRuntimeMethodHqlGenerator());
2626
RegisterGenerator(new LikeGenerator());
2727
RegisterGenerator(new GetValueOrDefaultGenerator());
28+
RegisterGenerator(new HasFlagGenerator());
2829

2930
RegisterGenerator(new CompareGenerator());
3031
this.Merge(new CompareGenerator());
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
using System;
2+
using System.Collections.ObjectModel;
3+
using System.Linq.Expressions;
4+
using System.Reflection;
5+
using NHibernate.Hql.Ast;
6+
using NHibernate.Linq.Visitors;
7+
8+
namespace NHibernate.Linq.Functions
9+
{
10+
internal class HasFlagGenerator : BaseHqlGeneratorForMethod, IRuntimeMethodHqlGenerator
11+
{
12+
private const string _bitAndFunctionName = "band";
13+
14+
public bool SupportsMethod(MethodInfo method)
15+
{
16+
return method.Name == nameof(Enum.HasFlag) && method.DeclaringType == typeof(Enum);
17+
}
18+
19+
public IHqlGeneratorForMethod GetMethodGenerator(MethodInfo method)
20+
{
21+
return this;
22+
}
23+
24+
public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor)
25+
{
26+
return treeBuilder.Equality(
27+
treeBuilder.MethodCall(
28+
_bitAndFunctionName,
29+
visitor.Visit(targetObject).AsExpression(),
30+
visitor.Visit(arguments[0]).AsExpression()),
31+
visitor.Visit(arguments[0]).AsExpression());
32+
}
33+
}
34+
}

0 commit comments

Comments
 (0)