Skip to content

Commit 60a708e

Browse files
author
Tomas Lukac
committed
move enum equals implementation to IExpressionTransformer because special expression expression behavior
1 parent e08d046 commit 60a708e

File tree

5 files changed

+105
-89
lines changed

5 files changed

+105
-89
lines changed

src/NHibernate.Test/Linq/FunctionTests.cs

Lines changed: 52 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ public void LikeFunctionWithEscapeCharacter()
3636
session.Flush();
3737

3838
var query = (from e in db.Employees
39-
where NHibernate.Linq.SqlMethods.Like(e.FirstName, employeeNameEscaped, escapeChar)
40-
select e).ToList();
39+
where NHibernate.Linq.SqlMethods.Like(e.FirstName, employeeNameEscaped, escapeChar)
40+
select e).ToList();
4141

4242
Assert.That(query.Count, Is.EqualTo(1));
4343
Assert.That(query[0].FirstName, Is.EqualTo(employeeName));
@@ -81,8 +81,8 @@ where NHibernate.Test.Linq.FunctionTests.SqlMethods.Like(e.FirstName, "Ma%et")
8181
public void SubstringFunction2()
8282
{
8383
var query = (from e in db.Employees
84-
where e.FirstName.Substring(0, 2) == "An"
85-
select e).ToList();
84+
where e.FirstName.Substring(0, 2) == "An"
85+
select e).ToList();
8686

8787
Assert.That(query.Count, Is.EqualTo(2));
8888
}
@@ -91,8 +91,8 @@ where e.FirstName.Substring(0, 2) == "An"
9191
public void SubstringFunction1()
9292
{
9393
var query = (from e in db.Employees
94-
where e.FirstName.Substring(3) == "rew"
95-
select e).ToList();
94+
where e.FirstName.Substring(3) == "rew"
95+
select e).ToList();
9696

9797
Assert.That(query.Count, Is.EqualTo(1));
9898
Assert.That(query[0].FirstName, Is.EqualTo("Andrew"));
@@ -130,21 +130,21 @@ public void ReplaceFunction()
130130
var query = from e in db.Employees
131131
where e.FirstName.StartsWith("An")
132132
select new
133-
{
134-
Before = e.FirstName,
135-
// This one call the standard string.Replace, not the extension. The linq registry handles it.
136-
AfterMethod = e.FirstName.Replace("An", "Zan"),
137-
AfterExtension = ExtensionMethods.Replace(e.FirstName, "An", "Zan"),
138-
AfterNamedExtension = e.FirstName.ReplaceExtension("An", "Zan"),
139-
AfterEvaluableExtension = e.FirstName.ReplaceWithEvaluation("An", "Zan"),
140-
AfterEvaluable2Extension = e.FirstName.ReplaceWithEvaluation2("An", "Zan"),
133+
{
134+
Before = e.FirstName,
135+
// This one call the standard string.Replace, not the extension. The linq registry handles it.
136+
AfterMethod = e.FirstName.Replace("An", "Zan"),
137+
AfterExtension = ExtensionMethods.Replace(e.FirstName, "An", "Zan"),
138+
AfterNamedExtension = e.FirstName.ReplaceExtension("An", "Zan"),
139+
AfterEvaluableExtension = e.FirstName.ReplaceWithEvaluation("An", "Zan"),
140+
AfterEvaluable2Extension = e.FirstName.ReplaceWithEvaluation2("An", "Zan"),
141141
BeforeConst = suppliedName,
142-
// This one call the standard string.Replace, not the extension. The linq registry handles it.
143-
AfterMethodConst = suppliedName.Replace("An", "Zan"),
144-
AfterExtensionConst = ExtensionMethods.Replace(suppliedName, "An", "Zan"),
145-
AfterNamedExtensionConst = suppliedName.ReplaceExtension("An", "Zan"),
146-
AfterEvaluableExtensionConst = suppliedName.ReplaceWithEvaluation("An", "Zan"),
147-
AfterEvaluable2ExtensionConst = suppliedName.ReplaceWithEvaluation2("An", "Zan")
142+
// This one call the standard string.Replace, not the extension. The linq registry handles it.
143+
AfterMethodConst = suppliedName.Replace("An", "Zan"),
144+
AfterExtensionConst = ExtensionMethods.Replace(suppliedName, "An", "Zan"),
145+
AfterNamedExtensionConst = suppliedName.ReplaceExtension("An", "Zan"),
146+
AfterEvaluableExtensionConst = suppliedName.ReplaceWithEvaluation("An", "Zan"),
147+
AfterEvaluable2ExtensionConst = suppliedName.ReplaceWithEvaluation2("An", "Zan")
148148
};
149149
var results = query.ToList();
150150
var s = ObjectDumper.Write(results);
@@ -171,12 +171,12 @@ where e.FirstName.StartsWith("An")
171171
// Should cause ReplaceWithEvaluation to fail
172172
suppliedName = null;
173173
var failingQuery = from e in db.Employees
174-
where e.FirstName.StartsWith("An")
175-
select new
176-
{
177-
Before = e.FirstName,
178-
AfterEvaluableExtensionConst = suppliedName.ReplaceWithEvaluation("An", "Zan")
179-
};
174+
where e.FirstName.StartsWith("An")
175+
select new
176+
{
177+
Before = e.FirstName,
178+
AfterEvaluableExtensionConst = suppliedName.ReplaceWithEvaluation("An", "Zan")
179+
};
180180
Assert.That(() => failingQuery.ToList(), Throws.InstanceOf<HibernateException>().And.InnerException.InstanceOf<ArgumentNullException>());
181181
}
182182

@@ -248,7 +248,7 @@ where lowerName.Contains("a")
248248
public void TwoFunctionExpression()
249249
{
250250
var query = from e in db.Employees
251-
where e.FirstName.IndexOf("A") == e.BirthDate.Value.Month
251+
where e.FirstName.IndexOf("A") == e.BirthDate.Value.Month
252252
select e.FirstName;
253253

254254
ObjectDumper.Write(query);
@@ -285,9 +285,9 @@ public void Trim()
285285
{
286286
using (session.BeginTransaction())
287287
{
288-
AnotherEntity ae1 = new AnotherEntity {Input = " hi "};
289-
AnotherEntity ae2 = new AnotherEntity {Input = "hi"};
290-
AnotherEntity ae3 = new AnotherEntity {Input = "heh"};
288+
AnotherEntity ae1 = new AnotherEntity { Input = " hi " };
289+
AnotherEntity ae2 = new AnotherEntity { Input = "hi" };
290+
AnotherEntity ae3 = new AnotherEntity { Input = "heh" };
291291
session.Save(ae1);
292292
session.Save(ae2);
293293
session.Save(ae3);
@@ -303,7 +303,7 @@ public void Trim()
303303

304304
// Check when passed as array
305305
// (the single character parameter is a new overload in .netcoreapp2.0, but not net461 or .netstandard2.0).
306-
Assert.AreEqual(1, session.Query<AnotherEntity>().Count(e => e.Input.Trim(new [] { 'h' }) == "e"));
306+
Assert.AreEqual(1, session.Query<AnotherEntity>().Count(e => e.Input.Trim(new[] { 'h' }) == "e"));
307307
Assert.AreEqual(1, session.Query<AnotherEntity>().Count(e => e.Input.TrimStart(new[] { 'h' }) == "eh"));
308308
Assert.AreEqual(1, session.Query<AnotherEntity>().Count(e => e.Input.TrimEnd(new[] { 'h' }) == "he"));
309309

@@ -316,9 +316,9 @@ public void TrimInitialWhitespace()
316316
{
317317
using (session.BeginTransaction())
318318
{
319-
session.Save(new AnotherEntity {Input = " hi"});
320-
session.Save(new AnotherEntity {Input = "hi"});
321-
session.Save(new AnotherEntity {Input = "heh"});
319+
session.Save(new AnotherEntity { Input = " hi" });
320+
session.Save(new AnotherEntity { Input = "hi" });
321+
session.Save(new AnotherEntity { Input = "heh" });
322322
session.Flush();
323323

324324
Assert.That(session.Query<AnotherEntity>().Count(e => e.Input.TrimStart() == "hi"), Is.EqualTo(2));
@@ -372,7 +372,7 @@ public void WhereBoolConstantEqual()
372372
var query = from item in db.Role
373373
where item.IsActive.Equals(true)
374374
select item;
375-
375+
376376
ObjectDumper.Write(query);
377377
}
378378

@@ -382,7 +382,7 @@ public void WhereBoolConditionEquals()
382382
var query = from item in db.Role
383383
where item.IsActive.Equals(item.Name != null)
384384
select item;
385-
385+
386386
ObjectDumper.Write(query);
387387
}
388388

@@ -392,7 +392,7 @@ public void WhereBoolParameterEqual()
392392
var query = from item in db.Role
393393
where item.IsActive.Equals(1 == 1)
394394
select item;
395-
395+
396396
ObjectDumper.Write(query);
397397
}
398398

@@ -412,8 +412,8 @@ where item.IsActive.Equals(f())
412412
public void WhereLongEqual()
413413
{
414414
var query = from item in db.PatientRecords
415-
where item.Id.Equals(-1)
416-
select item;
415+
where item.Id.Equals(-1)
416+
select item;
417417

418418
ObjectDumper.Write(query);
419419
}
@@ -427,7 +427,7 @@ where item.RegisteredAt.Equals(DateTime.Today)
427427

428428
ObjectDumper.Write(query);
429429
}
430-
430+
431431
[Test]
432432
public void WhereGuidEqual()
433433
{
@@ -436,7 +436,7 @@ where item.Reference.Equals(Guid.Empty)
436436
select item;
437437

438438
ObjectDumper.Write(query);
439-
}
439+
}
440440

441441
[Test]
442442
public void WhereDoubleEqual()
@@ -446,8 +446,8 @@ where item.BodyWeight.Equals(-1)
446446
select item;
447447

448448
ObjectDumper.Write(query);
449-
}
450-
449+
}
450+
451451
[Test]
452452
[Ignore("Not mapped entity")]
453453
public void WhereFloatEqual()
@@ -457,7 +457,7 @@ where item.Float.Equals(-1)
457457
select item;
458458

459459
ObjectDumper.Write(query);
460-
}
460+
}
461461

462462
[Test]
463463
[Ignore("Not mapped entity")]
@@ -499,14 +499,20 @@ where item.Gender.Equals(Gender.Female)
499499
select item;
500500

501501
ObjectDumper.Write(query);
502+
503+
query = from item in db.PatientRecords
504+
where item.Gender.Equals(item.Gender)
505+
select item;
506+
507+
ObjectDumper.Write(query);
502508
}
503509

504510
[Test]
505511
public void WhereEquatableEqual()
506512
{
507513
var query = from item in db.Shippers
508-
where ((IEquatable<Guid>) item.Reference).Equals(Guid.Empty)
509-
select item;
514+
where ((IEquatable<Guid>) item.Reference).Equals(Guid.Empty)
515+
select item;
510516

511517
ObjectDumper.Write(query);
512518
}
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
using System;
2+
using System.Linq.Expressions;
3+
using Remotion.Linq.Parsing.ExpressionVisitors.Transformation;
4+
5+
namespace NHibernate.Linq.ExpressionTransformers
6+
{
7+
/// <summary>
8+
/// Transforms <see cref="Enum"/>.Equals method to equality operator.
9+
/// This cannot be done easily in <see cref="Functions.IHqlGeneratorForMethod"/> as Equals operator
10+
/// is boxed to ((object)Enum).Equals((object)EnumValue) expression.
11+
/// </summary>
12+
public class EnumEqualsTransformer : IExpressionTransformer<MethodCallExpression>
13+
{
14+
public ExpressionType[] SupportedExpressionTypes => _supportedExpressionTypes;
15+
16+
private static readonly ExpressionType[] _supportedExpressionTypes = new[]
17+
{
18+
ExpressionType.Call
19+
};
20+
21+
public Expression Transform(MethodCallExpression expression)
22+
{
23+
if (expression.Object?.Type.IsEnum == true &&
24+
expression.Method.Name == nameof(Enum.Equals) &&
25+
expression.Arguments.Count == 1)
26+
{
27+
return Expression.Equal(expression.Object, Unwrap(expression.Arguments[0], expression.Object.Type));
28+
}
29+
30+
return expression;
31+
}
32+
33+
private Expression Unwrap(Expression expression, System.Type type)
34+
{
35+
// 1) unwrap convert operand as convert is converting from enum type to object
36+
if (expression is UnaryExpression u && u.NodeType == ExpressionType.Convert)
37+
{
38+
return u.Operand;
39+
}
40+
41+
// 2) convert constant expression which is of type object
42+
if (expression is ConstantExpression c)
43+
{
44+
return Expression.Convert(c, type);
45+
}
46+
47+
return expression;
48+
}
49+
}
50+
}

src/NHibernate/Linq/Functions/DefaultLinqToHqlGeneratorsRegistry.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ public DefaultLinqToHqlGeneratorsRegistry()
3939
this.Merge(new EndsWithGenerator());
4040
this.Merge(new ContainsGenerator());
4141
this.Merge(new EqualsGenerator());
42-
this.Merge(new EnumEqualsGenerator());
4342
this.Merge(new ToUpperGenerator());
4443
this.Merge(new ToLowerGenerator());
4544
this.Merge(new SubStringGenerator());

src/NHibernate/Linq/Functions/EnumEqualsGenerator.cs

Lines changed: 0 additions & 38 deletions
This file was deleted.

src/NHibernate/Linq/NhRelinqQueryParser.cs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,8 @@
44
using System.Linq;
55
using System.Linq.Expressions;
66
using System.Reflection;
7-
using NHibernate.Engine;
87
using NHibernate.Linq.ExpressionTransformers;
98
using NHibernate.Linq.Visitors;
10-
using NHibernate.Param;
119
using NHibernate.Util;
1210
using Remotion.Linq;
1311
using Remotion.Linq.EagerFetching.Parsing;
@@ -27,6 +25,7 @@ static NhRelinqQueryParser()
2725
var transformerRegistry = ExpressionTransformerRegistry.CreateDefault();
2826
transformerRegistry.Register(new RemoveRedundantCast());
2927
transformerRegistry.Register(new SimplifyCompareTransformer());
28+
transformerRegistry.Register(new EnumEqualsTransformer());
3029

3130
// If needing a compound processor for adding other processing, do not use
3231
// ExpressionTreeParser.CreateDefaultProcessor(transformerRegistry), it would
@@ -114,14 +113,14 @@ public NHibernateNodeTypeProvider()
114113
new[] { ReflectHelper.FastGetMethodDefinition(EagerFetchingExtensionMethods.ThenFetch, default(INhFetchRequest<object, object>), default(Expression<Func<object, object>>)) },
115114
typeof(ThenFetchOneExpressionNode));
116115
methodInfoRegistry.Register(
117-
new[] { ReflectHelper.FastGetMethodDefinition( EagerFetchingExtensionMethods.ThenFetchMany, default(INhFetchRequest<object, object>), default(Expression<Func<object, IEnumerable<object>>>)) },
116+
new[] { ReflectHelper.FastGetMethodDefinition(EagerFetchingExtensionMethods.ThenFetchMany, default(INhFetchRequest<object, object>), default(Expression<Func<object, IEnumerable<object>>>)) },
118117
typeof(ThenFetchManyExpressionNode));
119118
methodInfoRegistry.Register(
120119
new[]
121120
{
122121
ReflectHelper.FastGetMethodDefinition(LinqExtensionMethods.WithLock, default(IQueryable<object>), default(LockMode)),
123122
ReflectHelper.FastGetMethodDefinition(LinqExtensionMethods.WithLock, default(IEnumerable<object>), default(LockMode))
124-
},
123+
},
125124
typeof(LockExpressionNode));
126125

127126
var nodeTypeProvider = ExpressionTreeParser.CreateDefaultNodeTypeProvider();

0 commit comments

Comments
 (0)