Skip to content

Commit 4d47799

Browse files
bahusoidhazzik
authored andcommitted
Support LINQ GroupBy subquery (#2210)
1 parent 4138171 commit 4d47799

File tree

5 files changed

+98
-4
lines changed

5 files changed

+98
-4
lines changed

src/NHibernate.Test/Async/Linq/NestedSelectsTests.cs

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -362,5 +362,38 @@ public async Task OrdersIdWithOrderLinesNestedWhereIdAsync()
362362
Assert.That(orders.Count, Is.EqualTo(830));
363363
Assert.That(orders[0].OrderLinesIds, Is.Empty);
364364
}
365+
366+
[Test]
367+
public async Task NoNestedSelects_AnyOnGroupBySubqueryAsync()
368+
{
369+
var subQuery = from vms in db.Animals
370+
group vms by vms.Father
371+
into vmReqs
372+
select vmReqs.Max(x => x.Id);
373+
374+
var outerQuery = from vm in db.Animals
375+
where subQuery.Any(x => vm.Id == x)
376+
select vm;
377+
var animals = await (outerQuery.ToListAsync());
378+
Assert.That(animals.Count, Is.EqualTo(2));
379+
}
380+
381+
//NH-3155
382+
[Test]
383+
public async Task NoNestedSelects_ContainsOnGroupBySubqueryAsync()
384+
{
385+
var subQuery = from vms in db.Animals
386+
where vms.BodyWeight > 0
387+
group vms by vms.Father
388+
into vmReqs
389+
select vmReqs.Max(x => x.Id);
390+
391+
var outerQuery = from vm in db.Animals
392+
where subQuery.Contains(vm.Id)
393+
select vm;
394+
395+
var animals = await (outerQuery.ToListAsync());
396+
Assert.That(animals.Count, Is.EqualTo(2));
397+
}
365398
}
366-
}
399+
}

src/NHibernate.Test/Linq/NestedSelectsTests.cs

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,5 +364,38 @@ public void OrdersIdWithOrderLinesNestedWhereId()
364364
Assert.That(orders.Count, Is.EqualTo(830));
365365
Assert.That(orders[0].OrderLinesIds, Is.Empty);
366366
}
367+
368+
[Test]
369+
public void NoNestedSelects_AnyOnGroupBySubquery()
370+
{
371+
var subQuery = from vms in db.Animals
372+
group vms by vms.Father
373+
into vmReqs
374+
select vmReqs.Max(x => x.Id);
375+
376+
var outerQuery = from vm in db.Animals
377+
where subQuery.Any(x => vm.Id == x)
378+
select vm;
379+
var animals = outerQuery.ToList();
380+
Assert.That(animals.Count, Is.EqualTo(2));
381+
}
382+
383+
//NH-3155
384+
[Test]
385+
public void NoNestedSelects_ContainsOnGroupBySubquery()
386+
{
387+
var subQuery = from vms in db.Animals
388+
where vms.BodyWeight > 0
389+
group vms by vms.Father
390+
into vmReqs
391+
select vmReqs.Max(x => x.Id);
392+
393+
var outerQuery = from vm in db.Animals
394+
where subQuery.Contains(vm.Id)
395+
select vm;
396+
397+
var animals = outerQuery.ToList();
398+
Assert.That(animals.Count, Is.EqualTo(2));
399+
}
367400
}
368-
}
401+
}

src/NHibernate/Linq/GroupBy/AggregatingGroupByRewriter.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ public static class AggregatingGroupByRewriter
3838
typeof(FirstResultOperator),
3939
typeof(SingleResultOperator),
4040
typeof(AnyResultOperator),
41-
typeof(AllResultOperator)
41+
typeof(AllResultOperator),
42+
typeof(ContainsResultOperator),
4243
};
4344

4445
public static void ReWrite(QueryModel queryModel)

src/NHibernate/Linq/NestedSelects/NestedSelectRewriter.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ public static void ReWrite(QueryModel queryModel, ISessionFactory sessionFactory
4444
replacements.Add(expression, processed);
4545
}
4646

47+
if (!replacements.Any())
48+
return;
49+
4750
var key = Expression.Property(group, IGroupingKeyProperty);
4851

4952
var expressions = new List<ExpressionHolder>();
@@ -92,6 +95,10 @@ private static Expression ProcessExpression(QueryModel queryModel, ISessionFacto
9295

9396
private static Expression ProcessSubquery(ISessionFactory sessionFactory, ICollection<ExpressionHolder> elementExpression, QueryModel queryModel, Expression @group, QueryModel subQueryModel)
9497
{
98+
var resultTypeOverride = subQueryModel.ResultTypeOverride;
99+
if (resultTypeOverride != null && !resultTypeOverride.IsArray && !resultTypeOverride.IsEnumerableOfT())
100+
return null;
101+
95102
var subQueryMainFromClause = subQueryModel.MainFromClause;
96103

97104
var restrictions = subQueryModel.BodyClauses

src/NHibernate/Linq/Visitors/QueryModelVisitor.cs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,7 @@ public override void VisitNhJoinClause(NhJoinClause joinClause, QueryModel query
365365
public override void VisitResultOperator(ResultOperatorBase resultOperator, QueryModel queryModel, int index)
366366
{
367367
PreviousEvaluationType = CurrentEvaluationType;
368-
CurrentEvaluationType = resultOperator.GetOutputDataInfo(PreviousEvaluationType);
368+
CurrentEvaluationType = GetOutputDataInfo(resultOperator, PreviousEvaluationType);
369369

370370
if (resultOperator is ClientSideTransformOperator)
371371
{
@@ -382,6 +382,26 @@ public override void VisitResultOperator(ResultOperatorBase resultOperator, Quer
382382
ResultOperatorMap.Process(resultOperator, this, _hqlTree);
383383
}
384384

385+
private static IStreamedDataInfo GetOutputDataInfo(ResultOperatorBase resultOperator, IStreamedDataInfo evaluationType)
386+
{
387+
//ContainsResultOperator contains data integrity check so for `values.Contains(x)` it checks that 'x' is proper type to be used inside 'values.Contains()'
388+
//Due to some reasons (possibly NH expression rewritings) those types might be incompatible (known case NH-3155 - group by subquery). So resultOperator.GetOutputDataInfo throws something like:
389+
//System.ArgumentException : The items of the input sequence of type 'System.Linq.IGrouping`2[System.Object[],EntityType]' are not compatible with the item expression of type 'System.Int32'.
390+
//Parameter name: inputInfo
391+
//at Remotion.Linq.Clauses.ResultOperators.ContainsResultOperator.GetOutputDataInfo(StreamedSequenceInfo inputInfo)
392+
//But in this place we don't really care about types involving inside expression, all we need to know is operation result which is bool for Contains
393+
//So let's skip possible type exception mismatch if it allows to generate proper SQL
394+
switch (resultOperator)
395+
{
396+
case ContainsResultOperator _:
397+
case AnyResultOperator _:
398+
case AllResultOperator _:
399+
return new StreamedScalarValueInfo(typeof(bool));
400+
}
401+
402+
return resultOperator.GetOutputDataInfo(evaluationType);
403+
}
404+
385405
public override void VisitSelectClause(SelectClause selectClause, QueryModel queryModel)
386406
{
387407
CurrentEvaluationType = selectClause.GetOutputDataInfo();

0 commit comments

Comments
 (0)