Skip to content

Commit 0b5826e

Browse files
committed
Reduce cast usage for COUNT aggregate and add support for Mssql count_big
1 parent d4206e2 commit 0b5826e

File tree

17 files changed

+207
-14
lines changed

17 files changed

+207
-14
lines changed

src/NHibernate.Test/Async/Linq/ByMethod/CountTests.cs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
using System;
1212
using System.Linq;
1313
using NHibernate.Cfg;
14+
using NHibernate.Dialect;
1415
using NUnit.Framework;
1516
using NHibernate.Linq;
1617

@@ -110,5 +111,50 @@ into temp
110111

111112
Assert.That(result.Count, Is.EqualTo(77));
112113
}
114+
115+
[Test]
116+
public async Task CheckSqlFunctionNameLongCountAsync()
117+
{
118+
var name = Dialect is MsSql2000Dialect ? "count_big" : "count";
119+
using (var sqlLog = new SqlLogSpy())
120+
{
121+
var result = await (db.Orders.LongCountAsync());
122+
Assert.That(result, Is.EqualTo(830));
123+
124+
var log = sqlLog.GetWholeLog();
125+
Assert.That(log, Does.Contain($"{name}("));
126+
}
127+
}
128+
129+
[Test]
130+
public async Task CheckSqlFunctionNameForCountAsync()
131+
{
132+
using (var sqlLog = new SqlLogSpy())
133+
{
134+
var result = await (db.Orders.CountAsync());
135+
Assert.That(result, Is.EqualTo(830));
136+
137+
var log = sqlLog.GetWholeLog();
138+
Assert.That(log, Does.Contain("count("));
139+
}
140+
}
141+
142+
[Test]
143+
public async Task CheckMssqlCountCastAsync()
144+
{
145+
if (!(Dialect is MsSql2000Dialect))
146+
{
147+
Assert.Ignore();
148+
}
149+
150+
using (var sqlLog = new SqlLogSpy())
151+
{
152+
var result = await (db.Orders.CountAsync());
153+
Assert.That(result, Is.EqualTo(830));
154+
155+
var log = sqlLog.GetWholeLog();
156+
Assert.That(log, Does.Not.Contain("cast("));
157+
}
158+
}
113159
}
114160
}

src/NHibernate.Test/Async/QueryTest/CountFixture.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
using NHibernate.Cfg;
1313
using NHibernate.Dialect.Function;
1414
using NHibernate.DomainModel;
15+
using NHibernate.Engine;
16+
using NHibernate.Type;
1517
using NUnit.Framework;
1618
using Environment=NHibernate.Cfg.Environment;
1719

@@ -55,4 +57,4 @@ public async Task OverriddenAsync()
5557
await (sf.CloseAsync());
5658
}
5759
}
58-
}
60+
}

src/NHibernate.Test/Hql/SimpleFunctionsTest.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,8 @@ public void ClassicSum()
135135
Assert.Throws<QueryException>(() => csf.Render(args, factoryImpl));
136136
}
137137

138-
[Test]
138+
// Since v5.3
139+
[Test, Obsolete]
139140
public void ClassicCount()
140141
{
141142
//ANSI-SQL92 definition

src/NHibernate.Test/Linq/ByMethod/CountTests.cs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System;
22
using System.Linq;
33
using NHibernate.Cfg;
4+
using NHibernate.Dialect;
45
using NUnit.Framework;
56

67
namespace NHibernate.Test.Linq.ByMethod
@@ -98,5 +99,50 @@ into temp
9899

99100
Assert.That(result.Count, Is.EqualTo(77));
100101
}
102+
103+
[Test]
104+
public void CheckSqlFunctionNameLongCount()
105+
{
106+
var name = Dialect is MsSql2000Dialect ? "count_big" : "count";
107+
using (var sqlLog = new SqlLogSpy())
108+
{
109+
var result = db.Orders.LongCount();
110+
Assert.That(result, Is.EqualTo(830));
111+
112+
var log = sqlLog.GetWholeLog();
113+
Assert.That(log, Does.Contain($"{name}("));
114+
}
115+
}
116+
117+
[Test]
118+
public void CheckSqlFunctionNameForCount()
119+
{
120+
using (var sqlLog = new SqlLogSpy())
121+
{
122+
var result = db.Orders.Count();
123+
Assert.That(result, Is.EqualTo(830));
124+
125+
var log = sqlLog.GetWholeLog();
126+
Assert.That(log, Does.Contain("count("));
127+
}
128+
}
129+
130+
[Test]
131+
public void CheckMssqlCountCast()
132+
{
133+
if (!(Dialect is MsSql2000Dialect))
134+
{
135+
Assert.Ignore();
136+
}
137+
138+
using (var sqlLog = new SqlLogSpy())
139+
{
140+
var result = db.Orders.Count();
141+
Assert.That(result, Is.EqualTo(830));
142+
143+
var log = sqlLog.GetWholeLog();
144+
Assert.That(log, Does.Not.Contain("cast("));
145+
}
146+
}
101147
}
102148
}

src/NHibernate.Test/QueryTest/CountFixture.cs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
using NHibernate.Cfg;
33
using NHibernate.Dialect.Function;
44
using NHibernate.DomainModel;
5+
using NHibernate.Engine;
6+
using NHibernate.Type;
57
using NUnit.Framework;
68
using Environment=NHibernate.Cfg.Environment;
79

@@ -44,4 +46,17 @@ public void Overridden()
4446
sf.Close();
4547
}
4648
}
47-
}
49+
50+
[Serializable]
51+
internal class ClassicCountFunction : ClassicAggregateFunction
52+
{
53+
public ClassicCountFunction() : base("count", true)
54+
{
55+
}
56+
57+
public override IType ReturnType(IType columnType, IMapping mapping)
58+
{
59+
return NHibernateUtil.Int32;
60+
}
61+
}
62+
}

src/NHibernate/Dialect/Dialect.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ public abstract partial class Dialect
5555
static Dialect()
5656
{
5757
StandardAggregateFunctions["count"] = new CountQueryFunctionInfo();
58+
StandardAggregateFunctions["count_big"] = new CountQueryFunctionInfo();
5859
StandardAggregateFunctions["avg"] = new AvgQueryFunctionInfo();
5960
StandardAggregateFunctions["max"] = new ClassicAggregateFunction("max", false);
6061
StandardAggregateFunctions["min"] = new ClassicAggregateFunction("min", false);

src/NHibernate/Dialect/Function/ClassicAggregateFunction.cs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
namespace NHibernate.Dialect.Function
1010
{
1111
[Serializable]
12-
public class ClassicAggregateFunction : ISQLFunction, IFunctionGrammar
12+
public class ClassicAggregateFunction : ISQLFunction, IFunctionGrammar, ISQLAggregateFunction
1313
{
1414
private IType returnType = null;
1515
private readonly string name;
@@ -110,5 +110,18 @@ bool IFunctionGrammar.IsKnownArgument(string token)
110110
}
111111

112112
#endregion
113+
114+
#region ISQLAggregateFunction Members
115+
116+
/// <inheritdoc />
117+
public string FunctionName => name;
118+
119+
/// <inheritdoc />
120+
public virtual IType GetActualReturnType(IType argumentType, IMapping mapping)
121+
{
122+
return ReturnType(argumentType, mapping);
123+
}
124+
125+
#endregion
113126
}
114127
}

src/NHibernate/Dialect/Function/ClassicCountFunction.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ namespace NHibernate.Dialect.Function
77
/// <summary>
88
/// Classic COUNT sqlfunction that return types as it was done in Hibernate 3.1
99
/// </summary>
10+
// Since v5.3
11+
[Obsolete("This class has no more usages in NHibernate and will be removed in a future version.")]
1012
[Serializable]
1113
public class ClassicCountFunction : ClassicAggregateFunction
1214
{
@@ -19,4 +21,4 @@ public override IType ReturnType(IType columnType, IMapping mapping)
1921
return NHibernateUtil.Int32;
2022
}
2123
}
22-
}
24+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
using NHibernate.Engine;
2+
using NHibernate.Type;
3+
4+
namespace NHibernate.Dialect.Function
5+
{
6+
/// <inheritdoc />
7+
internal interface ISQLAggregateFunction : ISQLFunction
8+
{
9+
/// <summary>
10+
/// The name of the aggregate function.
11+
/// </summary>
12+
string FunctionName { get; }
13+
14+
/// <summary>
15+
/// Get the type that will be effectively returned by the underlying database.
16+
/// </summary>
17+
/// <param name="argumentType">The type of the first argument</param>
18+
/// <param name="mapping">The mapping for retrieving the argument sql types.</param>
19+
/// <returns></returns>
20+
IType GetActualReturnType(IType argumentType, IMapping mapping);
21+
}
22+
}

src/NHibernate/Dialect/MsSql2000Dialect.cs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,8 @@ protected virtual void RegisterKeywords()
286286

287287
protected virtual void RegisterFunctions()
288288
{
289-
RegisterFunction("count", new CountBigQueryFunction());
289+
RegisterFunction("count", new CountQueryFunction());
290+
RegisterFunction("count_big", new CountBigQueryFunction());
290291

291292
RegisterFunction("abs", new StandardSQLFunction("abs"));
292293
RegisterFunction("absval", new StandardSQLFunction("absval"));
@@ -706,6 +707,16 @@ public override IType ReturnType(IType columnType, IMapping mapping)
706707
}
707708
}
708709

710+
[Serializable]
711+
private class CountQueryFunction : CountQueryFunctionInfo
712+
{
713+
/// <inheritdoc />
714+
public override IType GetActualReturnType(IType columnType, IMapping mapping)
715+
{
716+
return NHibernateUtil.Int32;
717+
}
718+
}
719+
709720
public override bool SupportsCircularCascadeDeleteConstraints
710721
{
711722
get

src/NHibernate/Hql/Ast/ANTLR/SqlGenerator.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,13 @@ private void EndFunctionTemplate(IASTNode m)
310310
}
311311
}
312312

313+
private void OutAggregateFunctionName(IASTNode m)
314+
{
315+
var aggregateNode = (AggregateNode) m;
316+
var template = aggregateNode.SqlFunction;
317+
Out(template == null ? aggregateNode.Text : template.FunctionName);
318+
}
319+
313320
private void CommaBetweenParameters(String comma)
314321
{
315322
writer.CommaBetweenParameters(comma);

src/NHibernate/Hql/Ast/ANTLR/SqlGenerator.g

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ selectExpr
150150
;
151151
152152
count
153-
: ^(COUNT { Out("count("); } ( distinctOrAll ) ? countExpr { Out(")"); } )
153+
: ^(c=COUNT { OutAggregateFunctionName(c); Out("("); } ( distinctOrAll ) ? countExpr { Out(")"); } )
154154
;
155155
156156
distinctOrAll
@@ -344,7 +344,7 @@ caseExpr
344344
;
345345
346346
aggregate
347-
: ^(a=AGGREGATE { Out(a); Out("("); } expr { Out(")"); } )
347+
: ^(a=AGGREGATE { OutAggregateFunctionName(a); Out("("); } expr { Out(")"); } )
348348
;
349349
350350

src/NHibernate/Hql/Ast/ANTLR/Tree/AggregateNode.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using Antlr.Runtime;
3+
using NHibernate.Dialect.Function;
34
using NHibernate.Type;
45
using NHibernate.Hql.Ast.ANTLR.Util;
56

@@ -19,6 +20,8 @@ public AggregateNode(IToken token)
1920
{
2021
}
2122

23+
internal ISQLAggregateFunction SqlFunction => SessionFactoryHelper.FindSQLFunction(Text) as ISQLAggregateFunction;
24+
2225
public override IType DataType
2326
{
2427
get

src/NHibernate/Hql/Ast/ANTLR/Tree/CountNode.cs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using Antlr.Runtime;
2+
using NHibernate.Dialect.Function;
23
using NHibernate.Hql.Ast.ANTLR.Util;
34
using NHibernate.Type;
45

@@ -9,7 +10,7 @@ namespace NHibernate.Hql.Ast.ANTLR.Tree
910
/// Author: josh
1011
/// Ported by: Steve Strong
1112
/// </summary>
12-
class CountNode : AbstractSelectExpression, ISelectExpression
13+
class CountNode : AggregateNode, ISelectExpression
1314
{
1415
public CountNode(IToken token) : base(token)
1516
{
@@ -26,9 +27,5 @@ public override IType DataType
2627
base.DataType = value;
2728
}
2829
}
29-
public override void SetScalarColumnText(int i)
30-
{
31-
ColumnHelper.GenerateSingleScalarColumn(ASTFactory, this, i);
32-
}
3330
}
3431
}

src/NHibernate/Hql/Ast/HqlTreeBuilder.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,11 @@ public HqlCount Count(HqlExpression child)
307307
return new HqlCount(_factory, child);
308308
}
309309

310+
public HqlCountBig CountBig(HqlExpression child)
311+
{
312+
return new HqlCountBig(_factory, child);
313+
}
314+
310315
public HqlRowStar RowStar()
311316
{
312317
return new HqlRowStar(_factory);

src/NHibernate/Hql/Ast/HqlTreeNode.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -697,6 +697,19 @@ public HqlCount(IASTFactory factory, HqlExpression child)
697697
}
698698
}
699699

700+
public class HqlCountBig : HqlExpression
701+
{
702+
public HqlCountBig(IASTFactory factory)
703+
: base(HqlSqlWalker.COUNT, "count_big", factory)
704+
{
705+
}
706+
707+
public HqlCountBig(IASTFactory factory, HqlExpression child)
708+
: base(HqlSqlWalker.COUNT, "count_big", factory, child)
709+
{
710+
}
711+
}
712+
700713
public class HqlAs : HqlExpression
701714
{
702715
public HqlAs(IASTFactory factory, HqlExpression expression, System.Type type)

src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,16 @@ protected HqlTreeNode VisitNhAverage(NhAverageExpression expression)
255255

256256
protected HqlTreeNode VisitNhCount(NhCountExpression expression)
257257
{
258-
return _hqlTreeBuilder.Cast(_hqlTreeBuilder.Count(VisitExpression(expression.Expression).AsExpression()), expression.Type);
258+
if (expression is NhLongCountExpression)
259+
{
260+
return IsCastRequired(expression.Type, "count_big", out _)
261+
? (HqlTreeNode) _hqlTreeBuilder.Cast(_hqlTreeBuilder.CountBig(VisitExpression(expression.Expression).AsExpression()), expression.Type)
262+
: _hqlTreeBuilder.TransparentCast(_hqlTreeBuilder.CountBig(VisitExpression(expression.Expression).AsExpression()), expression.Type);
263+
}
264+
265+
return IsCastRequired(expression.Type, "count", out _)
266+
? (HqlTreeNode) _hqlTreeBuilder.Cast(_hqlTreeBuilder.Count(VisitExpression(expression.Expression).AsExpression()), expression.Type)
267+
: _hqlTreeBuilder.TransparentCast(_hqlTreeBuilder.Count(VisitExpression(expression.Expression).AsExpression()), expression.Type);
259268
}
260269

261270
protected HqlTreeNode VisitNhMin(NhMinExpression expression)

0 commit comments

Comments
 (0)