Skip to content

Commit 86f6cd0

Browse files
committed
Reduce cast usage for COUNT aggregate and add support for Mssql count_big
1 parent f8dd4ee commit 86f6cd0

File tree

17 files changed

+211
-16
lines changed

17 files changed

+211
-16
lines changed

src/NHibernate.Test/Async/Linq/ByMethod/CountTests.cs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
using System;
1212
using System.Linq;
1313
using NHibernate.Cfg;
14+
using NHibernate.Dialect;
1415
using NUnit.Framework;
1516
using NHibernate.Linq;
1617

@@ -110,5 +111,50 @@ into temp
110111

111112
Assert.That(result.Count, Is.EqualTo(77));
112113
}
114+
115+
[Test]
116+
public async Task CheckSqlFunctionNameLongCountAsync()
117+
{
118+
var name = Dialect is MsSql2000Dialect ? "count_big" : "count";
119+
using (var sqlLog = new SqlLogSpy())
120+
{
121+
var result = await (db.Orders.LongCountAsync());
122+
Assert.That(result, Is.EqualTo(830));
123+
124+
var log = sqlLog.GetWholeLog();
125+
Assert.That(log, Does.Contain($"{name}("));
126+
}
127+
}
128+
129+
[Test]
130+
public async Task CheckSqlFunctionNameForCountAsync()
131+
{
132+
using (var sqlLog = new SqlLogSpy())
133+
{
134+
var result = await (db.Orders.CountAsync());
135+
Assert.That(result, Is.EqualTo(830));
136+
137+
var log = sqlLog.GetWholeLog();
138+
Assert.That(log, Does.Contain("count("));
139+
}
140+
}
141+
142+
[Test]
143+
public async Task CheckMssqlCountCastAsync()
144+
{
145+
if (!(Dialect is MsSql2000Dialect))
146+
{
147+
Assert.Ignore();
148+
}
149+
150+
using (var sqlLog = new SqlLogSpy())
151+
{
152+
var result = await (db.Orders.CountAsync());
153+
Assert.That(result, Is.EqualTo(830));
154+
155+
var log = sqlLog.GetWholeLog();
156+
Assert.That(log, Does.Not.Contain("cast("));
157+
}
158+
}
113159
}
114160
}

src/NHibernate.Test/Async/QueryTest/CountFixture.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
using NHibernate.Cfg;
1313
using NHibernate.Dialect.Function;
1414
using NHibernate.DomainModel;
15+
using NHibernate.Engine;
16+
using NHibernate.Type;
1517
using NUnit.Framework;
1618
using Environment=NHibernate.Cfg.Environment;
1719

@@ -55,4 +57,4 @@ public async Task OverriddenAsync()
5557
await (sf.CloseAsync());
5658
}
5759
}
58-
}
60+
}

src/NHibernate.Test/Hql/SimpleFunctionsTest.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,8 @@ public void ClassicSum()
167167
}
168168
}
169169

170-
[Test]
170+
// Since v5.3
171+
[Test, Obsolete]
171172
public void ClassicCount()
172173
{
173174
//ANSI-SQL92 definition

src/NHibernate.Test/Linq/ByMethod/CountTests.cs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System;
22
using System.Linq;
33
using NHibernate.Cfg;
4+
using NHibernate.Dialect;
45
using NUnit.Framework;
56

67
namespace NHibernate.Test.Linq.ByMethod
@@ -98,5 +99,50 @@ into temp
9899

99100
Assert.That(result.Count, Is.EqualTo(77));
100101
}
102+
103+
[Test]
104+
public void CheckSqlFunctionNameLongCount()
105+
{
106+
var name = Dialect is MsSql2000Dialect ? "count_big" : "count";
107+
using (var sqlLog = new SqlLogSpy())
108+
{
109+
var result = db.Orders.LongCount();
110+
Assert.That(result, Is.EqualTo(830));
111+
112+
var log = sqlLog.GetWholeLog();
113+
Assert.That(log, Does.Contain($"{name}("));
114+
}
115+
}
116+
117+
[Test]
118+
public void CheckSqlFunctionNameForCount()
119+
{
120+
using (var sqlLog = new SqlLogSpy())
121+
{
122+
var result = db.Orders.Count();
123+
Assert.That(result, Is.EqualTo(830));
124+
125+
var log = sqlLog.GetWholeLog();
126+
Assert.That(log, Does.Contain("count("));
127+
}
128+
}
129+
130+
[Test]
131+
public void CheckMssqlCountCast()
132+
{
133+
if (!(Dialect is MsSql2000Dialect))
134+
{
135+
Assert.Ignore();
136+
}
137+
138+
using (var sqlLog = new SqlLogSpy())
139+
{
140+
var result = db.Orders.Count();
141+
Assert.That(result, Is.EqualTo(830));
142+
143+
var log = sqlLog.GetWholeLog();
144+
Assert.That(log, Does.Not.Contain("cast("));
145+
}
146+
}
101147
}
102148
}

src/NHibernate.Test/QueryTest/CountFixture.cs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
using NHibernate.Cfg;
33
using NHibernate.Dialect.Function;
44
using NHibernate.DomainModel;
5+
using NHibernate.Engine;
6+
using NHibernate.Type;
57
using NUnit.Framework;
68
using Environment=NHibernate.Cfg.Environment;
79

@@ -44,4 +46,17 @@ public void Overridden()
4446
sf.Close();
4547
}
4648
}
47-
}
49+
50+
[Serializable]
51+
internal class ClassicCountFunction : ClassicAggregateFunction
52+
{
53+
public ClassicCountFunction() : base("count", true)
54+
{
55+
}
56+
57+
public override IType ReturnType(IType columnType, IMapping mapping)
58+
{
59+
return NHibernateUtil.Int32;
60+
}
61+
}
62+
}

src/NHibernate/Dialect/Dialect.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ public abstract partial class Dialect
5555
static Dialect()
5656
{
5757
StandardAggregateFunctions["count"] = new CountQueryFunctionInfo();
58+
StandardAggregateFunctions["count_big"] = new CountQueryFunctionInfo();
5859
StandardAggregateFunctions["avg"] = new AvgQueryFunctionInfo();
5960
StandardAggregateFunctions["max"] = new ClassicAggregateFunction("max", false);
6061
StandardAggregateFunctions["min"] = new ClassicAggregateFunction("min", false);

src/NHibernate/Dialect/Function/ClassicAggregateFunction.cs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
namespace NHibernate.Dialect.Function
1010
{
1111
[Serializable]
12-
public class ClassicAggregateFunction : ISQLFunction, IFunctionGrammar
12+
public class ClassicAggregateFunction : ISQLFunction, IFunctionGrammar, ISQLAggregateFunction
1313
{
1414
private IType returnType = null;
1515
private readonly string name;
@@ -111,5 +111,18 @@ bool IFunctionGrammar.IsKnownArgument(string token)
111111
}
112112

113113
#endregion
114+
115+
#region ISQLAggregateFunction Members
116+
117+
/// <inheritdoc />
118+
public string FunctionName => name;
119+
120+
/// <inheritdoc />
121+
public virtual IType GetActualReturnType(IType argumentType, IMapping mapping)
122+
{
123+
return ReturnType(argumentType, mapping);
124+
}
125+
126+
#endregion
114127
}
115128
}

src/NHibernate/Dialect/Function/ClassicCountFunction.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ namespace NHibernate.Dialect.Function
77
/// <summary>
88
/// Classic COUNT sqlfunction that return types as it was done in Hibernate 3.1
99
/// </summary>
10+
// Since v5.3
11+
[Obsolete("This class has no more usages in NHibernate and will be removed in a future version.")]
1012
[Serializable]
1113
public class ClassicCountFunction : ClassicAggregateFunction
1214
{
@@ -19,4 +21,4 @@ public override IType ReturnType(IType columnType, IMapping mapping)
1921
return NHibernateUtil.Int32;
2022
}
2123
}
22-
}
24+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
using NHibernate.Engine;
2+
using NHibernate.Type;
3+
4+
namespace NHibernate.Dialect.Function
5+
{
6+
/// <inheritdoc />
7+
internal interface ISQLAggregateFunction : ISQLFunction
8+
{
9+
/// <summary>
10+
/// The name of the aggregate function.
11+
/// </summary>
12+
string FunctionName { get; }
13+
14+
/// <summary>
15+
/// Get the type that will be effectively returned by the underlying database.
16+
/// </summary>
17+
/// <param name="argumentType">The type of the first argument</param>
18+
/// <param name="mapping">The mapping for retrieving the argument sql types.</param>
19+
/// <returns></returns>
20+
IType GetActualReturnType(IType argumentType, IMapping mapping);
21+
}
22+
}

src/NHibernate/Dialect/MsSql2000Dialect.cs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,8 @@ protected virtual void RegisterKeywords()
286286

287287
protected virtual void RegisterFunctions()
288288
{
289-
RegisterFunction("count", new CountBigQueryFunction());
289+
RegisterFunction("count", new CountQueryFunction());
290+
RegisterFunction("count_big", new CountBigQueryFunction());
290291

291292
RegisterFunction("abs", new StandardSQLFunction("abs"));
292293
RegisterFunction("absval", new StandardSQLFunction("absval"));
@@ -705,6 +706,16 @@ public override IType ReturnType(IType columnType, IMapping mapping)
705706
}
706707
}
707708

709+
[Serializable]
710+
private class CountQueryFunction : CountQueryFunctionInfo
711+
{
712+
/// <inheritdoc />
713+
public override IType GetActualReturnType(IType columnType, IMapping mapping)
714+
{
715+
return NHibernateUtil.Int32;
716+
}
717+
}
718+
708719
public override bool SupportsCircularCascadeDeleteConstraints
709720
{
710721
get

src/NHibernate/Hql/Ast/ANTLR/SqlGenerator.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,13 @@ private void EndFunctionTemplate(IASTNode m)
303303
}
304304
}
305305

306+
private void OutAggregateFunctionName(IASTNode m)
307+
{
308+
var aggregateNode = (AggregateNode) m;
309+
var template = aggregateNode.SqlFunction;
310+
Out(template == null ? aggregateNode.Text : template.FunctionName);
311+
}
312+
306313
private void CommaBetweenParameters(String comma)
307314
{
308315
writer.CommaBetweenParameters(comma);

src/NHibernate/Hql/Ast/ANTLR/SqlGenerator.g

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ selectExpr
150150
;
151151
152152
count
153-
: ^(COUNT { Out("count("); } ( distinctOrAll ) ? countExpr { Out(")"); } )
153+
: ^(c=COUNT { OutAggregateFunctionName(c); Out("("); } ( distinctOrAll ) ? countExpr { Out(")"); } )
154154
;
155155
156156
distinctOrAll
@@ -343,7 +343,7 @@ caseExpr
343343
;
344344
345345
aggregate
346-
: ^(a=AGGREGATE { Out(a); Out("("); } expr { Out(")"); } )
346+
: ^(a=AGGREGATE { OutAggregateFunctionName(a); Out("("); } expr { Out(")"); } )
347347
;
348348
349349

src/NHibernate/Hql/Ast/ANTLR/Tree/AggregateNode.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using Antlr.Runtime;
3+
using NHibernate.Dialect.Function;
34
using NHibernate.Type;
45
using NHibernate.Hql.Ast.ANTLR.Util;
56

@@ -19,6 +20,8 @@ public AggregateNode(IToken token)
1920
{
2021
}
2122

23+
internal ISQLAggregateFunction SqlFunction => SessionFactoryHelper.FindSQLFunction(Text) as ISQLAggregateFunction;
24+
2225
public override IType DataType
2326
{
2427
get

src/NHibernate/Hql/Ast/ANTLR/Tree/CountNode.cs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using Antlr.Runtime;
2+
using NHibernate.Dialect.Function;
23
using NHibernate.Hql.Ast.ANTLR.Util;
34
using NHibernate.Type;
45

@@ -9,13 +10,12 @@ namespace NHibernate.Hql.Ast.ANTLR.Tree
910
/// Author: josh
1011
/// Ported by: Steve Strong
1112
/// </summary>
12-
class CountNode : AbstractSelectExpression, ISelectExpression
13+
class CountNode : AggregateNode, ISelectExpression
1314
{
1415
public CountNode(IToken token) : base(token)
1516
{
1617
}
1718

18-
1919
public override IType DataType
2020
{
2121
get
@@ -27,9 +27,5 @@ public override IType DataType
2727
base.DataType = value;
2828
}
2929
}
30-
public override void SetScalarColumnText(int i)
31-
{
32-
ColumnHelper.GenerateSingleScalarColumn(ASTFactory, this, i);
33-
}
3430
}
3531
}

src/NHibernate/Hql/Ast/HqlTreeBuilder.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,11 @@ public HqlCount Count(HqlExpression child)
307307
return new HqlCount(_factory, child);
308308
}
309309

310+
public HqlCountBig CountBig(HqlExpression child)
311+
{
312+
return new HqlCountBig(_factory, child);
313+
}
314+
310315
public HqlRowStar RowStar()
311316
{
312317
return new HqlRowStar(_factory);

src/NHibernate/Hql/Ast/HqlTreeNode.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -675,6 +675,19 @@ public HqlCount(IASTFactory factory, HqlExpression child)
675675
}
676676
}
677677

678+
public class HqlCountBig : HqlExpression
679+
{
680+
public HqlCountBig(IASTFactory factory)
681+
: base(HqlSqlWalker.COUNT, "count_big", factory)
682+
{
683+
}
684+
685+
public HqlCountBig(IASTFactory factory, HqlExpression child)
686+
: base(HqlSqlWalker.COUNT, "count_big", factory, child)
687+
{
688+
}
689+
}
690+
678691
public class HqlAs : HqlExpression
679692
{
680693
public HqlAs(IASTFactory factory, HqlExpression expression, System.Type type)

0 commit comments

Comments
 (0)