Skip to content

Commit 262ac22

Browse files
committed
Merge pull request #392 from gliljas/NH-3747
NH-3747 - Unit tests and fixes for aggregate predicates
2 parents 3140fe4 + 36e86ca commit 262ac22

File tree

9 files changed

+234
-13
lines changed

9 files changed

+234
-13
lines changed

src/NHibernate.Test/Linq/ByMethod/GroupByTests.cs

Lines changed: 100 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -285,19 +285,19 @@ public void SelectTupleKeyCountOfOrderLines()
285285
group o by o.OrderDate
286286
into g
287287
select new
288-
{
289-
g.Key,
290-
Count = g.SelectMany(x => x.OrderLines).Count()
291-
}).ToList();
288+
{
289+
g.Key,
290+
Count = g.SelectMany(x => x.OrderLines).Count()
291+
}).ToList();
292292

293293
var query = (from o in db.Orders
294294
group o by o.OrderDate
295295
into g
296296
select new
297-
{
298-
g.Key,
299-
Count = g.SelectMany(x => x.OrderLines).Count()
300-
}).ToList();
297+
{
298+
g.Key,
299+
Count = g.SelectMany(x => x.OrderLines).Count()
300+
}).ToList();
301301

302302
Assert.That(query.Count, Is.EqualTo(481));
303303
Assert.That(query, Is.EquivalentTo(list));
@@ -333,9 +333,9 @@ public void GroupByAndTake2()
333333
{
334334
//NH-2566
335335
var results = (from o in db.Orders
336-
group o by o.Customer
337-
into g
338-
select g.Key.CustomerId)
336+
group o by o.Customer
337+
into g
338+
select g.Key.CustomerId)
339339
.OrderBy(customerId => customerId)
340340
.Skip(10)
341341
.Take(10)
@@ -418,6 +418,95 @@ public void SelectSingleOrDefaultElementFromProductsGroupedByUnitPrice()
418418
Assert.That(result.Count, Is.EqualTo(1));
419419
}
420420

421+
[Test]
422+
public void ProjectingCountWithPredicate()
423+
{
424+
var result = db.Products
425+
.GroupBy(x => x.Supplier.CompanyName)
426+
.Select(x => new { x.Key, Count = x.Count(y => y.UnitPrice == 9.50M) })
427+
.OrderByDescending(x => x.Key)
428+
.First();
429+
430+
Assert.That(result.Key, Is.EqualTo("Zaanse Snoepfabriek"));
431+
Assert.That(result.Count, Is.EqualTo(1));
432+
}
433+
434+
[Test]
435+
public void FilteredByCountWithPredicate()
436+
{
437+
var result = db.Products
438+
.GroupBy(x => x.Supplier.CompanyName)
439+
.Where(x => x.Count(y => y.UnitPrice == 12.75M) == 1)
440+
.Select(x => new { x.Key, Count = x.Count() })
441+
.First();
442+
443+
Assert.That(result.Key, Is.EqualTo("Zaanse Snoepfabriek"));
444+
Assert.That(result.Count, Is.EqualTo(2));
445+
}
446+
447+
[Test]
448+
public void FilteredByCountFromSubQuery()
449+
{
450+
//Not really an aggregate filter, but included to ensure that this kind of query still works
451+
var result = db.Products
452+
.GroupBy(x => x.Supplier.CompanyName)
453+
.Where(x => db.Products.Count(y => y.Supplier.CompanyName==x.Key && y.UnitPrice == 12.75M) == 1)
454+
.Select(x => new { x.Key, Count = x.Count() })
455+
.First();
456+
457+
Assert.That(result.Key, Is.EqualTo("Zaanse Snoepfabriek"));
458+
Assert.That(result.Count, Is.EqualTo(2));
459+
}
460+
461+
[Test]
462+
public void FilteredByAndProjectingSumWithPredicate()
463+
{
464+
var result = db.Products
465+
.GroupBy(x => x.Supplier.CompanyName)
466+
.Where(x => x.Sum(y => y.UnitPrice == 12.75M ? y.UnitPrice : 0M) == 12.75M)
467+
.Select(x => new { x.Key, Sum = x.Sum(y => y.UnitPrice) })
468+
.First();
469+
470+
Assert.That(result.Key, Is.EqualTo("Zaanse Snoepfabriek"));
471+
Assert.That(result.Sum, Is.EqualTo(12.75M + 9.50M));
472+
}
473+
474+
[Test]
475+
public void FilteredByKeyAndProjectedWithAggregatePredicates()
476+
{
477+
var result = db.Products
478+
.GroupBy(x => x.Supplier.CompanyName)
479+
.Where(x => x.Key == "Zaanse Snoepfabriek")
480+
.Select(x => new { x.Key,
481+
Sum = x.Sum(y => y.UnitPrice == 12.75M ? y.UnitPrice : 0M),
482+
Avg = x.Average(y => y.UnitPrice == 12.75M ? y.UnitPrice : 0M),
483+
Count = x.Count(y => y.UnitPrice == 12.75M),
484+
Max = x.Max(y => y.UnitPrice == 12.75M ? y.UnitPrice : 0M),
485+
Min = x.Min(y => y.UnitPrice == 12.75M ? y.UnitPrice : 0M)
486+
})
487+
.First();
488+
489+
Assert.That(result.Key, Is.EqualTo("Zaanse Snoepfabriek"));
490+
Assert.That(result.Sum, Is.EqualTo(12.75M));
491+
Assert.That(result.Count, Is.EqualTo(1));
492+
Assert.That(result.Avg, Is.EqualTo(12.75M/2));
493+
Assert.That(result.Max, Is.EqualTo(12.75M));
494+
Assert.That(result.Min, Is.EqualTo(0M));
495+
}
496+
497+
[Test]
498+
public void ProjectingWithSubQueriesFilteredByTheAggregateKey()
499+
{
500+
var result=db.Products.GroupBy(x => x.Supplier.Address.Country)
501+
.OrderBy(x=>x.Key)
502+
.Select(x => new { x.Key, MaxFreight = db.Orders.Where(y => y.ShippingAddress.Country == x.Key).Max(y => y.Freight), FirstOrder = db.Orders.Where(o => o.Employee.FirstName.StartsWith("A")).OrderBy(o => o.OrderId).Select(y => y.OrderId).First() })
503+
.ToList();
504+
505+
Assert.That(result.Count,Is.EqualTo(16));
506+
Assert.That(result[15].MaxFreight, Is.EqualTo(830.75M));
507+
Assert.That(result[15].FirstOrder, Is.EqualTo(10255));
508+
}
509+
421510
private static void CheckGrouping<TKey, TElement>(IEnumerable<IGrouping<TKey, TElement>> groupedItems, Func<TElement, TKey> groupBy)
422511
{
423512
var used = new HashSet<object>();
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
using System.Linq;
2+
using System.Linq.Expressions;
3+
using Remotion.Linq.Clauses;
4+
using Remotion.Linq.Clauses.Expressions;
5+
using Remotion.Linq.Clauses.ResultOperators;
6+
7+
namespace NHibernate.Linq
8+
{
9+
public static class ExpressionExtensions
10+
{
11+
public static bool IsGroupingKey(this MemberExpression expression)
12+
{
13+
return expression.Member.Name == "Key" && expression.Member.DeclaringType!=null &&
14+
expression.Member.DeclaringType.IsGenericType && expression.Member.DeclaringType.GetGenericTypeDefinition() == typeof(IGrouping<,>);
15+
}
16+
17+
public static bool IsGroupingKeyOf(this MemberExpression expression,GroupResultOperator groupBy)
18+
{
19+
if (!expression.IsGroupingKey())
20+
{
21+
return false;
22+
}
23+
24+
var querySource = expression.Expression as QuerySourceReferenceExpression;
25+
if (querySource == null) return false;
26+
27+
var fromClause = querySource.ReferencedQuerySource as MainFromClause;
28+
if (fromClause == null) return false;
29+
30+
var query = fromClause.FromExpression as SubQueryExpression;
31+
if (query == null) return false;
32+
33+
return query.QueryModel.ResultOperators.Contains(groupBy);
34+
}
35+
}
36+
}

src/NHibernate/Linq/GroupBy/GroupBySelectClauseRewriter.cs

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
using System;
2+
using System.Linq;
23
using System.Linq.Expressions;
4+
using NHibernate.Linq.Expressions;
35
using Remotion.Linq;
46
using Remotion.Linq.Clauses;
57
using Remotion.Linq.Clauses.Expressions;
@@ -9,6 +11,7 @@
911

1012
namespace NHibernate.Linq.GroupBy
1113
{
14+
//This should be renamed. It handles entire querymodels, not just select clauses
1215
internal class GroupBySelectClauseRewriter : ExpressionTreeVisitor
1316
{
1417
public static Expression ReWrite(Expression expression, GroupResultOperator groupBy, QueryModel model)
@@ -43,7 +46,7 @@ protected override Expression VisitMemberExpression(MemberExpression expression)
4346
return base.VisitMemberExpression(expression);
4447
}
4548

46-
if (expression.Member.Name == "Key")
49+
if (expression.IsGroupingKeyOf(_groupBy))
4750
{
4851
return _groupBy.KeySelector;
4952
}
@@ -105,6 +108,34 @@ private bool IsMemberOfModel(MemberExpression expression)
105108

106109
protected override Expression VisitSubQueryExpression(SubQueryExpression expression)
107110
{
111+
//If the subquery is a Count(*) aggregate with a condition
112+
if (expression.QueryModel.MainFromClause.FromExpression.Type == _groupBy.ItemType)
113+
{
114+
var where = expression.QueryModel.BodyClauses.OfType<WhereClause>().FirstOrDefault();
115+
NhCountExpression countExpression;
116+
if (where != null && (countExpression = expression.QueryModel.SelectClause.Selector as NhCountExpression) !=
117+
null && countExpression.Expression.NodeType == (ExpressionType)NhExpressionType.Star)
118+
{
119+
//return it as a CASE [column] WHEN [predicate] THEN 1 ELSE NULL END
120+
return
121+
countExpression.CreateNew(Expression.Condition(where.Predicate, Expression.Constant(1, typeof(int?)),
122+
Expression.Constant(null, typeof(int?))));
123+
124+
}
125+
}
126+
127+
//In the subquery body clauses, references to the grouping key should be restored to the KeySelector expression. NOT the resolved value.
128+
//This feels a bit backwards, but solving it here is probably a smaller operation than fixing the previous rewriting
129+
if (expression.QueryModel.BodyClauses.Any())
130+
{
131+
foreach (var bodyClause in expression.QueryModel.BodyClauses)
132+
{
133+
bodyClause.TransformExpressions((e) => new KeySelectorVisitor(_groupBy).VisitExpression(e));
134+
}
135+
return base.VisitSubQueryExpression(expression);
136+
}
137+
138+
108139
// TODO - is this safe? All we are extracting is the select clause from the sub-query. Assumes that everything
109140
// else in the subquery has been removed. If there were two subqueries, one aggregating & one not, this may not be a
110141
// valid assumption. Should probably be passed a list of aggregating subqueries that we are flattening so that we can check...

src/NHibernate/Linq/GroupBy/IsNonAggregatingGroupByDetectionVisitor.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ public bool IsNonAggregatingGroupBy(Expression expression)
2323

2424
protected override Expression VisitMemberExpression(MemberExpression expression)
2525
{
26-
return expression.Member.Name == "Key"
26+
return expression.IsGroupingKey()
2727
? expression
2828
: base.VisitMemberExpression(expression);
2929
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
using System.Linq.Expressions;
2+
using Remotion.Linq.Clauses.ResultOperators;
3+
using Remotion.Linq.Parsing;
4+
5+
namespace NHibernate.Linq.GroupBy
6+
{
7+
internal class KeySelectorVisitor : ExpressionTreeVisitor
8+
{
9+
private readonly GroupResultOperator _groupBy;
10+
11+
public KeySelectorVisitor(GroupResultOperator groupBy)
12+
{
13+
_groupBy = groupBy;
14+
}
15+
16+
protected override Expression VisitMemberExpression(MemberExpression expression)
17+
{
18+
if (expression.IsGroupingKeyOf(_groupBy))
19+
{
20+
return _groupBy.KeySelector;
21+
}
22+
return base.VisitMemberExpression(expression);
23+
}
24+
}
25+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
using System.Linq.Expressions;
2+
using Remotion.Linq.Clauses.Expressions;
3+
using Remotion.Linq.Parsing;
4+
5+
namespace NHibernate.Linq.Visitors
6+
{
7+
public class QueryExpressionSourceIdentifer : ExpressionTreeVisitor
8+
{
9+
private readonly QuerySourceIdentifier _identifier;
10+
11+
public QueryExpressionSourceIdentifer(QuerySourceIdentifier identifier)
12+
{
13+
_identifier = identifier;
14+
}
15+
16+
protected override Expression VisitSubQueryExpression(SubQueryExpression expression)
17+
{
18+
_identifier.VisitQueryModel(expression.QueryModel);
19+
return base.VisitSubQueryExpression(expression);
20+
}
21+
}
22+
}

src/NHibernate/Linq/Visitors/QuerySourceIdentifier.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,12 @@ public override void VisitResultOperator(ResultOperatorBase resultOperator, Quer
5959
_namer.Add(groupBy);
6060
}
6161

62+
public override void VisitSelectClause(SelectClause selectClause, QueryModel queryModel)
63+
{
64+
//Find nested query sources
65+
new QueryExpressionSourceIdentifer(this).VisitExpression(selectClause.Selector);
66+
}
67+
6268
public QuerySourceNamer Namer { get { return _namer; } }
6369
}
6470
}

src/NHibernate/Linq/Visitors/SelectClauseNominator.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,16 @@ private bool IsRegisteredFunction(Expression expression)
115115
methodCallExpression.Object.NodeType != ExpressionType.Constant; // does not belong to parameter
116116
}
117117
}
118+
else if (expression.NodeType == (ExpressionType)NhExpressionType.Sum ||
119+
expression.NodeType == (ExpressionType)NhExpressionType.Count ||
120+
expression.NodeType == (ExpressionType)NhExpressionType.Average ||
121+
expression.NodeType == (ExpressionType)NhExpressionType.Max ||
122+
expression.NodeType == (ExpressionType)NhExpressionType.Min)
123+
{
124+
return true;
125+
}
118126
return false;
127+
119128
}
120129

121130
private bool CanBeEvaluatedInHqlSelectStatement(Expression expression, bool projectConstantsInHql)

src/NHibernate/NHibernate.csproj

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,13 +284,15 @@
284284
<Compile Include="Linq\Clauses\NhHavingClause.cs" />
285285
<Compile Include="Linq\Clauses\NhJoinClause.cs" />
286286
<Compile Include="Linq\Clauses\NhWithClause.cs" />
287+
<Compile Include="Linq\ExpressionExtensions.cs" />
287288
<Compile Include="Linq\ExpressionTransformers\RemoveCharToIntConversion.cs" />
288289
<Compile Include="Linq\ExpressionTransformers\RemoveRedundantCast.cs" />
289290
<Compile Include="Linq\Functions\ConvertGenerator.cs" />
290291
<Compile Include="Linq\Functions\GetValueOrDefaultGenerator.cs" />
291292
<Compile Include="Linq\Functions\MathGenerator.cs" />
292293
<Compile Include="Linq\Functions\DictionaryGenerator.cs" />
293294
<Compile Include="Linq\Functions\EqualsGenerator.cs" />
295+
<Compile Include="Linq\GroupBy\KeySelectorVisitor.cs" />
294296
<Compile Include="Linq\GroupBy\PagingRewriter.cs" />
295297
<Compile Include="Linq\NestedSelects\NestedSelectDetector.cs" />
296298
<Compile Include="Linq\NestedSelects\Tuple.cs" />
@@ -311,6 +313,7 @@
311313
<Compile Include="Linq\Visitors\JoinBuilder.cs" />
312314
<Compile Include="Linq\Visitors\PagingRewriterSelectClauseVisitor.cs" />
313315
<Compile Include="Linq\Visitors\PossibleValueSet.cs" />
316+
<Compile Include="Linq\Visitors\QueryExpressionSourceIdentifer.cs" />
314317
<Compile Include="Linq\Visitors\QuerySourceIdentifier.cs" />
315318
<Compile Include="Linq\Visitors\ResultOperatorAndOrderByJoinDetector.cs" />
316319
<Compile Include="Linq\Visitors\ResultOperatorProcessors\ProcessTimeout.cs" />

0 commit comments

Comments
 (0)