Skip to content

NH-3747 - Unit tests and fixes for aggregate predicates #392

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 22, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 100 additions & 11 deletions src/NHibernate.Test/Linq/ByMethod/GroupByTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -285,19 +285,19 @@ public void SelectTupleKeyCountOfOrderLines()
group o by o.OrderDate
into g
select new
{
g.Key,
Count = g.SelectMany(x => x.OrderLines).Count()
}).ToList();
{
g.Key,
Count = g.SelectMany(x => x.OrderLines).Count()
}).ToList();

var query = (from o in db.Orders
group o by o.OrderDate
into g
select new
{
g.Key,
Count = g.SelectMany(x => x.OrderLines).Count()
}).ToList();
{
g.Key,
Count = g.SelectMany(x => x.OrderLines).Count()
}).ToList();

Assert.That(query.Count, Is.EqualTo(481));
Assert.That(query, Is.EquivalentTo(list));
Expand Down Expand Up @@ -333,9 +333,9 @@ public void GroupByAndTake2()
{
//NH-2566
var results = (from o in db.Orders
group o by o.Customer
into g
select g.Key.CustomerId)
group o by o.Customer
into g
select g.Key.CustomerId)
.OrderBy(customerId => customerId)
.Skip(10)
.Take(10)
Expand Down Expand Up @@ -418,6 +418,95 @@ public void SelectSingleOrDefaultElementFromProductsGroupedByUnitPrice()
Assert.That(result.Count, Is.EqualTo(1));
}

[Test]
public void ProjectingCountWithPredicate()
{
var result = db.Products
.GroupBy(x => x.Supplier.CompanyName)
.Select(x => new { x.Key, Count = x.Count(y => y.UnitPrice == 9.50M) })
.OrderByDescending(x => x.Key)
.First();

Assert.That(result.Key, Is.EqualTo("Zaanse Snoepfabriek"));
Assert.That(result.Count, Is.EqualTo(1));
}

[Test]
public void FilteredByCountWithPredicate()
{
var result = db.Products
.GroupBy(x => x.Supplier.CompanyName)
.Where(x => x.Count(y => y.UnitPrice == 12.75M) == 1)
.Select(x => new { x.Key, Count = x.Count() })
.First();

Assert.That(result.Key, Is.EqualTo("Zaanse Snoepfabriek"));
Assert.That(result.Count, Is.EqualTo(2));
}

[Test]
public void FilteredByCountFromSubQuery()
{
//Not really an aggregate filter, but included to ensure that this kind of query still works
var result = db.Products
.GroupBy(x => x.Supplier.CompanyName)
.Where(x => db.Products.Count(y => y.Supplier.CompanyName==x.Key && y.UnitPrice == 12.75M) == 1)
.Select(x => new { x.Key, Count = x.Count() })
.First();

Assert.That(result.Key, Is.EqualTo("Zaanse Snoepfabriek"));
Assert.That(result.Count, Is.EqualTo(2));
}

[Test]
public void FilteredByAndProjectingSumWithPredicate()
{
var result = db.Products
.GroupBy(x => x.Supplier.CompanyName)
.Where(x => x.Sum(y => y.UnitPrice == 12.75M ? y.UnitPrice : 0M) == 12.75M)
.Select(x => new { x.Key, Sum = x.Sum(y => y.UnitPrice) })
.First();

Assert.That(result.Key, Is.EqualTo("Zaanse Snoepfabriek"));
Assert.That(result.Sum, Is.EqualTo(12.75M + 9.50M));
}

[Test]
public void FilteredByKeyAndProjectedWithAggregatePredicates()
{
var result = db.Products
.GroupBy(x => x.Supplier.CompanyName)
.Where(x => x.Key == "Zaanse Snoepfabriek")
.Select(x => new { x.Key,
Sum = x.Sum(y => y.UnitPrice == 12.75M ? y.UnitPrice : 0M),
Avg = x.Average(y => y.UnitPrice == 12.75M ? y.UnitPrice : 0M),
Count = x.Count(y => y.UnitPrice == 12.75M),
Max = x.Max(y => y.UnitPrice == 12.75M ? y.UnitPrice : 0M),
Min = x.Min(y => y.UnitPrice == 12.75M ? y.UnitPrice : 0M)
})
.First();

Assert.That(result.Key, Is.EqualTo("Zaanse Snoepfabriek"));
Assert.That(result.Sum, Is.EqualTo(12.75M));
Assert.That(result.Count, Is.EqualTo(1));
Assert.That(result.Avg, Is.EqualTo(12.75M/2));
Assert.That(result.Max, Is.EqualTo(12.75M));
Assert.That(result.Min, Is.EqualTo(0M));
}

[Test]
public void ProjectingWithSubQueriesFilteredByTheAggregateKey()
{
var result=db.Products.GroupBy(x => x.Supplier.Address.Country)
.OrderBy(x=>x.Key)
.Select(x => new { x.Key, MaxFreight = db.Orders.Where(y => y.ShippingAddress.Country == x.Key).Max(y => y.Freight), FirstOrder = db.Orders.Where(o => o.Employee.FirstName.StartsWith("A")).OrderBy(o => o.OrderId).Select(y => y.OrderId).First() })
.ToList();

Assert.That(result.Count,Is.EqualTo(16));
Assert.That(result[15].MaxFreight, Is.EqualTo(830.75M));
Assert.That(result[15].FirstOrder, Is.EqualTo(10255));
}

private static void CheckGrouping<TKey, TElement>(IEnumerable<IGrouping<TKey, TElement>> groupedItems, Func<TElement, TKey> groupBy)
{
var used = new HashSet<object>();
Expand Down
36 changes: 36 additions & 0 deletions src/NHibernate/Linq/ExpressionExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
using System.Linq;
using System.Linq.Expressions;
using Remotion.Linq.Clauses;
using Remotion.Linq.Clauses.Expressions;
using Remotion.Linq.Clauses.ResultOperators;

namespace NHibernate.Linq
{
public static class ExpressionExtensions
{
public static bool IsGroupingKey(this MemberExpression expression)
{
return expression.Member.Name == "Key" && expression.Member.DeclaringType!=null &&
expression.Member.DeclaringType.IsGenericType && expression.Member.DeclaringType.GetGenericTypeDefinition() == typeof(IGrouping<,>);
}

public static bool IsGroupingKeyOf(this MemberExpression expression,GroupResultOperator groupBy)
{
if (!expression.IsGroupingKey())
{
return false;
}

var querySource = expression.Expression as QuerySourceReferenceExpression;
if (querySource == null) return false;

var fromClause = querySource.ReferencedQuerySource as MainFromClause;
if (fromClause == null) return false;

var query = fromClause.FromExpression as SubQueryExpression;
if (query == null) return false;

return query.QueryModel.ResultOperators.Contains(groupBy);
}
}
}
33 changes: 32 additions & 1 deletion src/NHibernate/Linq/GroupBy/GroupBySelectClauseRewriter.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using System;
using System.Linq;
using System.Linq.Expressions;
using NHibernate.Linq.Expressions;
using Remotion.Linq;
using Remotion.Linq.Clauses;
using Remotion.Linq.Clauses.Expressions;
Expand All @@ -9,6 +11,7 @@

namespace NHibernate.Linq.GroupBy
{
//This should be renamed. It handles entire querymodels, not just select clauses
internal class GroupBySelectClauseRewriter : ExpressionTreeVisitor
{
public static Expression ReWrite(Expression expression, GroupResultOperator groupBy, QueryModel model)
Expand Down Expand Up @@ -43,7 +46,7 @@ protected override Expression VisitMemberExpression(MemberExpression expression)
return base.VisitMemberExpression(expression);
}

if (expression.Member.Name == "Key")
if (expression.IsGroupingKeyOf(_groupBy))
{
return _groupBy.KeySelector;
}
Expand Down Expand Up @@ -105,6 +108,34 @@ private bool IsMemberOfModel(MemberExpression expression)

protected override Expression VisitSubQueryExpression(SubQueryExpression expression)
{
//If the subquery is a Count(*) aggregate with a condition
if (expression.QueryModel.MainFromClause.FromExpression.Type == _groupBy.ItemType)
{
var where = expression.QueryModel.BodyClauses.OfType<WhereClause>().FirstOrDefault();
NhCountExpression countExpression;
if (where != null && (countExpression = expression.QueryModel.SelectClause.Selector as NhCountExpression) !=
null && countExpression.Expression.NodeType == (ExpressionType)NhExpressionType.Star)
{
//return it as a CASE [column] WHEN [predicate] THEN 1 ELSE NULL END
return
countExpression.CreateNew(Expression.Condition(where.Predicate, Expression.Constant(1, typeof(int?)),
Expression.Constant(null, typeof(int?))));

}
}

//In the subquery body clauses, references to the grouping key should be restored to the KeySelector expression. NOT the resolved value.
//This feels a bit backwards, but solving it here is probably a smaller operation than fixing the previous rewriting
if (expression.QueryModel.BodyClauses.Any())
{
foreach (var bodyClause in expression.QueryModel.BodyClauses)
{
bodyClause.TransformExpressions((e) => new KeySelectorVisitor(_groupBy).VisitExpression(e));
}
return base.VisitSubQueryExpression(expression);
}


// TODO - is this safe? All we are extracting is the select clause from the sub-query. Assumes that everything
// else in the subquery has been removed. If there were two subqueries, one aggregating & one not, this may not be a
// valid assumption. Should probably be passed a list of aggregating subqueries that we are flattening so that we can check...
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public bool IsNonAggregatingGroupBy(Expression expression)

protected override Expression VisitMemberExpression(MemberExpression expression)
{
return expression.Member.Name == "Key"
return expression.IsGroupingKey()
? expression
: base.VisitMemberExpression(expression);
}
Expand Down
25 changes: 25 additions & 0 deletions src/NHibernate/Linq/GroupBy/KeySelectorVisitor.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
using System.Linq.Expressions;
using Remotion.Linq.Clauses.ResultOperators;
using Remotion.Linq.Parsing;

namespace NHibernate.Linq.GroupBy
{
internal class KeySelectorVisitor : ExpressionTreeVisitor
{
private readonly GroupResultOperator _groupBy;

public KeySelectorVisitor(GroupResultOperator groupBy)
{
_groupBy = groupBy;
}

protected override Expression VisitMemberExpression(MemberExpression expression)
{
if (expression.IsGroupingKeyOf(_groupBy))
{
return _groupBy.KeySelector;
}
return base.VisitMemberExpression(expression);
}
}
}
22 changes: 22 additions & 0 deletions src/NHibernate/Linq/Visitors/QueryExpressionSourceIdentifer.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
using System.Linq.Expressions;
using Remotion.Linq.Clauses.Expressions;
using Remotion.Linq.Parsing;

namespace NHibernate.Linq.Visitors
{
public class QueryExpressionSourceIdentifer : ExpressionTreeVisitor
{
private readonly QuerySourceIdentifier _identifier;

public QueryExpressionSourceIdentifer(QuerySourceIdentifier identifier)
{
_identifier = identifier;
}

protected override Expression VisitSubQueryExpression(SubQueryExpression expression)
{
_identifier.VisitQueryModel(expression.QueryModel);
return base.VisitSubQueryExpression(expression);
}
}
}
6 changes: 6 additions & 0 deletions src/NHibernate/Linq/Visitors/QuerySourceIdentifier.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ public override void VisitResultOperator(ResultOperatorBase resultOperator, Quer
_namer.Add(groupBy);
}

public override void VisitSelectClause(SelectClause selectClause, QueryModel queryModel)
{
//Find nested query sources
new QueryExpressionSourceIdentifer(this).VisitExpression(selectClause.Selector);
}

public QuerySourceNamer Namer { get { return _namer; } }
}
}
9 changes: 9 additions & 0 deletions src/NHibernate/Linq/Visitors/SelectClauseNominator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,16 @@ private bool IsRegisteredFunction(Expression expression)
methodCallExpression.Object.NodeType != ExpressionType.Constant; // does not belong to parameter
}
}
else if (expression.NodeType == (ExpressionType)NhExpressionType.Sum ||
expression.NodeType == (ExpressionType)NhExpressionType.Count ||
expression.NodeType == (ExpressionType)NhExpressionType.Average ||
expression.NodeType == (ExpressionType)NhExpressionType.Max ||
expression.NodeType == (ExpressionType)NhExpressionType.Min)
{
return true;
}
return false;

}

private bool CanBeEvaluatedInHqlSelectStatement(Expression expression, bool projectConstantsInHql)
Expand Down
3 changes: 3 additions & 0 deletions src/NHibernate/NHibernate.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -284,13 +284,15 @@
<Compile Include="Linq\Clauses\NhHavingClause.cs" />
<Compile Include="Linq\Clauses\NhJoinClause.cs" />
<Compile Include="Linq\Clauses\NhWithClause.cs" />
<Compile Include="Linq\ExpressionExtensions.cs" />
<Compile Include="Linq\ExpressionTransformers\RemoveCharToIntConversion.cs" />
<Compile Include="Linq\ExpressionTransformers\RemoveRedundantCast.cs" />
<Compile Include="Linq\Functions\ConvertGenerator.cs" />
<Compile Include="Linq\Functions\GetValueOrDefaultGenerator.cs" />
<Compile Include="Linq\Functions\MathGenerator.cs" />
<Compile Include="Linq\Functions\DictionaryGenerator.cs" />
<Compile Include="Linq\Functions\EqualsGenerator.cs" />
<Compile Include="Linq\GroupBy\KeySelectorVisitor.cs" />
<Compile Include="Linq\GroupBy\PagingRewriter.cs" />
<Compile Include="Linq\NestedSelects\NestedSelectDetector.cs" />
<Compile Include="Linq\NestedSelects\Tuple.cs" />
Expand All @@ -310,6 +312,7 @@
<Compile Include="Linq\Visitors\JoinBuilder.cs" />
<Compile Include="Linq\Visitors\PagingRewriterSelectClauseVisitor.cs" />
<Compile Include="Linq\Visitors\PossibleValueSet.cs" />
<Compile Include="Linq\Visitors\QueryExpressionSourceIdentifer.cs" />
<Compile Include="Linq\Visitors\QuerySourceIdentifier.cs" />
<Compile Include="Linq\Visitors\ResultOperatorAndOrderByJoinDetector.cs" />
<Compile Include="Linq\Visitors\ResultOperatorProcessors\ProcessTimeout.cs" />
Expand Down