Skip to content

Commit 03550fd

Browse files
authored
Linq: add enum HasFlag method support (#3238)
Fixes #829
1 parent 9c4befc commit 03550fd

File tree

8 files changed

+220
-0
lines changed

8 files changed

+220
-0
lines changed
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
//------------------------------------------------------------------------------
2+
// <auto-generated>
3+
// This code was generated by AsyncGenerator.
4+
//
5+
// Changes to this file may cause incorrect behavior and will be lost if
6+
// the code is regenerated.
7+
// </auto-generated>
8+
//------------------------------------------------------------------------------
9+
10+
11+
using System.Linq;
12+
using NUnit.Framework;
13+
using NHibernate.Linq;
14+
15+
namespace NHibernate.Test.NHSpecificTest.GH0829
16+
{
17+
using System.Threading.Tasks;
18+
[TestFixture]
19+
public class FixtureAsync : BugTestCase
20+
{
21+
protected override void OnSetUp()
22+
{
23+
using var session = OpenSession();
24+
using var transaction = session.BeginTransaction();
25+
26+
var e1 = new Parent { Type = TestEnum.A | TestEnum.C };
27+
session.Save(e1);
28+
29+
var e2 = new Child { Type = TestEnum.D, Parent = e1 };
30+
session.Save(e2);
31+
32+
var e3 = new Child { Type = TestEnum.C, Parent = e1 };
33+
session.Save(e3);
34+
35+
transaction.Commit();
36+
}
37+
38+
[Test]
39+
public async Task SelectClassAsync()
40+
{
41+
using var session = OpenSession();
42+
43+
var resultFound = await (session.Query<Parent>().Where(x => x.Type.HasFlag(TestEnum.A)).FirstOrDefaultAsync());
44+
45+
var resultNotFound = await (session.Query<Parent>().Where(x => x.Type.HasFlag(TestEnum.D)).FirstOrDefaultAsync());
46+
47+
Assert.That(resultFound, Is.Not.Null);
48+
Assert.That(resultNotFound, Is.Null);
49+
}
50+
51+
[Test]
52+
public async Task SelectChildClassContainedInParentAsync()
53+
{
54+
using var session = OpenSession();
55+
56+
var result = await (session.Query<Child>().Where(x => x.Parent.Type.HasFlag(x.Type)).FirstOrDefaultAsync());
57+
58+
Assert.That(result, Is.Not.Null);
59+
}
60+
61+
protected override void OnTearDown()
62+
{
63+
using var session = OpenSession();
64+
using var transaction = session.BeginTransaction();
65+
foreach (var entity in new[] { nameof(Child), nameof(Parent) })
66+
{
67+
session.CreateQuery($"delete from {entity}").ExecuteUpdate();
68+
}
69+
70+
transaction.Commit();
71+
}
72+
}
73+
}
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
using System.Linq;
2+
using NUnit.Framework;
3+
4+
namespace NHibernate.Test.NHSpecificTest.GH0829
5+
{
6+
[TestFixture]
7+
public class Fixture : BugTestCase
8+
{
9+
protected override void OnSetUp()
10+
{
11+
using var session = OpenSession();
12+
using var transaction = session.BeginTransaction();
13+
14+
var e1 = new Parent { Type = TestEnum.A | TestEnum.C };
15+
session.Save(e1);
16+
17+
var e2 = new Child { Type = TestEnum.D, Parent = e1 };
18+
session.Save(e2);
19+
20+
var e3 = new Child { Type = TestEnum.C, Parent = e1 };
21+
session.Save(e3);
22+
23+
transaction.Commit();
24+
}
25+
26+
[Test]
27+
public void SelectClass()
28+
{
29+
using var session = OpenSession();
30+
31+
var resultFound = session.Query<Parent>().Where(x => x.Type.HasFlag(TestEnum.A)).FirstOrDefault();
32+
33+
var resultNotFound = session.Query<Parent>().Where(x => x.Type.HasFlag(TestEnum.D)).FirstOrDefault();
34+
35+
Assert.That(resultFound, Is.Not.Null);
36+
Assert.That(resultNotFound, Is.Null);
37+
}
38+
39+
[Test]
40+
public void SelectChildClassContainedInParent()
41+
{
42+
using var session = OpenSession();
43+
44+
var result = session.Query<Child>().Where(x => x.Parent.Type.HasFlag(x.Type)).FirstOrDefault();
45+
46+
Assert.That(result, Is.Not.Null);
47+
}
48+
49+
protected override void OnTearDown()
50+
{
51+
using var session = OpenSession();
52+
using var transaction = session.BeginTransaction();
53+
foreach (var entity in new[] { nameof(Child), nameof(Parent) })
54+
{
55+
session.CreateQuery($"delete from {entity}").ExecuteUpdate();
56+
}
57+
58+
transaction.Commit();
59+
}
60+
}
61+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
<?xml version="1.0" encoding="utf-8" ?>
2+
<hibernate-mapping xmlns="urn:nhibernate-mapping-2.2" assembly="NHibernate.Test" namespace="NHibernate.Test.NHSpecificTest.GH0829">
3+
<class name="Parent" table="parent">
4+
<id name="Id" generator="guid.comb" />
5+
<property name="Type" type="NHibernate.Type.EnumType`1[[NHibernate.Test.NHSpecificTest.GH0829.TestEnum, NHibernate.Test]], NHibernate"/>
6+
<bag name="Children" lazy="true">
7+
<key column="ParentId" />
8+
<one-to-many class="Child" />
9+
</bag>
10+
</class>
11+
12+
<class name="Child" table="Child">
13+
<id name="Id" generator="guid.comb" />
14+
<property name="Type" type="NHibernate.Type.EnumType`1[[NHibernate.Test.NHSpecificTest.GH0829.TestEnum, NHibernate.Test]], NHibernate"/>
15+
<many-to-one name="Parent" column="ParentId" not-null="true" />
16+
</class>
17+
</hibernate-mapping>
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
using System;
2+
using System.Collections.Generic;
3+
4+
namespace NHibernate.Test.NHSpecificTest.GH0829
5+
{
6+
public class Parent
7+
{
8+
public virtual Guid Id { get; set; }
9+
10+
public virtual TestEnum Type { get; set; }
11+
12+
public virtual IList<Child> Children { get; set; } = new List<Child>();
13+
}
14+
15+
public class Child
16+
{
17+
public virtual Guid Id { get; set; }
18+
19+
public virtual TestEnum Type { get; set; }
20+
21+
public virtual Parent Parent { get; set; }
22+
}
23+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
using System;
2+
3+
namespace NHibernate.Test.NHSpecificTest.GH0829
4+
{
5+
[Flags]
6+
public enum TestEnum
7+
{
8+
A = 1 << 0,
9+
B = 1 << 1,
10+
C = 1 << 2,
11+
D = 1 << 3
12+
}
13+
}

src/NHibernate/Linq/ExpressionTransformers/RemoveRedundantCast.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using System;
12
using System.Linq.Expressions;
23
using NHibernate.Util;
34
using Remotion.Linq.Parsing.ExpressionVisitors.Transformation;
@@ -20,6 +21,7 @@ public class RemoveRedundantCast : IExpressionTransformer<UnaryExpression>
2021
public Expression Transform(UnaryExpression expression)
2122
{
2223
if (expression.Type != typeof(object) &&
24+
expression.Type != typeof(Enum) &&
2325
expression.Type.IsAssignableFrom(expression.Operand.Type) &&
2426
expression.Method == null &&
2527
!expression.IsLiftedToNull)

src/NHibernate/Linq/Functions/DefaultLinqToHqlGeneratorsRegistry.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ public DefaultLinqToHqlGeneratorsRegistry()
6969
this.Merge(new DecimalNegateGenerator());
7070
this.Merge(new RoundGenerator());
7171
this.Merge(new TruncateGenerator());
72+
this.Merge(new HasFlagGenerator());
7273

7374
var indexerGenerator = new ListIndexerGenerator();
7475
RegisterGenerator(indexerGenerator);
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
using System;
2+
using System.Collections.ObjectModel;
3+
using System.Linq.Expressions;
4+
using System.Reflection;
5+
using NHibernate.Hql.Ast;
6+
using NHibernate.Linq.Visitors;
7+
using NHibernate.Util;
8+
9+
namespace NHibernate.Linq.Functions
10+
{
11+
internal class HasFlagGenerator : BaseHqlGeneratorForMethod
12+
{
13+
private const string _bitAndFunctionName = "band";
14+
15+
public HasFlagGenerator()
16+
{
17+
SupportedMethods = new[] { ReflectHelper.GetMethodDefinition<Enum>(x => x.HasFlag(default)) };
18+
}
19+
20+
public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor)
21+
{
22+
return treeBuilder.Equality(
23+
treeBuilder.MethodCall(
24+
_bitAndFunctionName,
25+
visitor.Visit(targetObject).AsExpression(),
26+
visitor.Visit(arguments[0]).AsExpression()),
27+
visitor.Visit(arguments[0]).AsExpression());
28+
}
29+
}
30+
}

0 commit comments

Comments
 (0)