Skip to content

Commit c6df511

Browse files
committed
Reduce cast usage for SUM aggregate function
1 parent a27b10f commit c6df511

File tree

2 files changed

+189
-2
lines changed

2 files changed

+189
-2
lines changed
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
using System;
2+
using System.Linq;
3+
using NHibernate.Cfg.MappingSchema;
4+
using NHibernate.Mapping.ByCode;
5+
using NUnit.Framework;
6+
7+
namespace NHibernate.Test.NHSpecificTest.GH2029
8+
{
9+
public class TestClass
10+
{
11+
public virtual int Id { get; set; }
12+
public virtual int? NullableInt32Prop { get; set; }
13+
public virtual int Int32Prop { get; set; }
14+
public virtual long? NullableInt64Prop { get; set; }
15+
public virtual long Int64Prop { get; set; }
16+
}
17+
18+
[TestFixture]
19+
public class Fixture : TestCaseMappingByCode
20+
{
21+
protected override HbmMapping GetMappings()
22+
{
23+
var mapper = new ModelMapper();
24+
mapper.Class<TestClass>(rc =>
25+
{
26+
rc.Id(x => x.Id, m => m.Generator(Generators.Native));
27+
rc.Property(x => x.NullableInt32Prop);
28+
rc.Property(x => x.Int32Prop);
29+
rc.Property(x => x.NullableInt64Prop);
30+
rc.Property(x => x.Int64Prop);
31+
});
32+
33+
return mapper.CompileMappingForAllExplicitlyAddedEntities();
34+
}
35+
36+
protected override void OnSetUp()
37+
{
38+
using (var session = OpenSession())
39+
using (var tx = session.BeginTransaction())
40+
{
41+
session.Save(new TestClass
42+
{
43+
Int32Prop = int.MaxValue,
44+
NullableInt32Prop = int.MaxValue,
45+
Int64Prop = int.MaxValue,
46+
NullableInt64Prop = int.MaxValue
47+
});
48+
session.Save(new TestClass
49+
{
50+
Int32Prop = int.MaxValue,
51+
NullableInt32Prop = int.MaxValue,
52+
Int64Prop = int.MaxValue,
53+
NullableInt64Prop = int.MaxValue
54+
});
55+
session.Save(new TestClass
56+
{
57+
Int32Prop = int.MaxValue,
58+
NullableInt32Prop = null,
59+
Int64Prop = int.MaxValue,
60+
NullableInt64Prop = null
61+
});
62+
63+
tx.Commit();
64+
}
65+
}
66+
67+
protected override void OnTearDown()
68+
{
69+
using (var session = OpenSession())
70+
using (var tx = session.BeginTransaction())
71+
{
72+
session.CreateQuery("delete from TestClass").ExecuteUpdate();
73+
74+
tx.Commit();
75+
}
76+
}
77+
78+
[Test]
79+
public void NullableIntOverflow()
80+
{
81+
using (var session = OpenSession())
82+
using (session.BeginTransaction())
83+
using (var sqlLog = new SqlLogSpy())
84+
{
85+
var groups = session.Query<TestClass>()
86+
.GroupBy(i => 1)
87+
.Select(g => new
88+
{
89+
s = g.Sum(i => (long) i.NullableInt32Prop)
90+
})
91+
.ToArray();
92+
93+
Assert.That(FindAllOccurrences(sqlLog.GetWholeLog(), "cast"), Is.EqualTo(1));
94+
Assert.That(groups, Has.Length.EqualTo(1));
95+
Assert.That(groups[0].s, Is.EqualTo((long) int.MaxValue * 2));
96+
}
97+
}
98+
99+
[Test]
100+
public void IntOverflow()
101+
{
102+
using (var session = OpenSession())
103+
using (session.BeginTransaction())
104+
using (var sqlLog = new SqlLogSpy())
105+
{
106+
var groups = session.Query<TestClass>()
107+
.GroupBy(i => 1)
108+
.Select(g => new
109+
{
110+
s = g.Sum(i => (long) i.Int32Prop)
111+
})
112+
.ToArray();
113+
114+
Assert.That(FindAllOccurrences(sqlLog.GetWholeLog(), "cast"), Is.EqualTo(1));
115+
Assert.That(groups, Has.Length.EqualTo(1));
116+
Assert.That(groups[0].s, Is.EqualTo((long) int.MaxValue * 3));
117+
}
118+
}
119+
120+
[Test]
121+
public void NullableInt64NoCast()
122+
{
123+
using (var session = OpenSession())
124+
using (session.BeginTransaction())
125+
using (var sqlLog = new SqlLogSpy())
126+
{
127+
var groups = session.Query<TestClass>()
128+
.GroupBy(i => 1)
129+
.Select(g => new {
130+
s = g.Sum(i => i.NullableInt64Prop)
131+
})
132+
.ToArray();
133+
134+
Assert.That(sqlLog.GetWholeLog(), Does.Not.Contains("cast"));
135+
Assert.That(groups, Has.Length.EqualTo(1));
136+
Assert.That(groups[0].s, Is.EqualTo((long) int.MaxValue * 2));
137+
}
138+
}
139+
140+
[Test]
141+
public void Int64NoCast()
142+
{
143+
using (var session = OpenSession())
144+
using (session.BeginTransaction())
145+
using (var sqlLog = new SqlLogSpy())
146+
{
147+
var groups = session.Query<TestClass>()
148+
.GroupBy(i => 1)
149+
.Select(g => new {
150+
s = g.Sum(i => i.Int64Prop)
151+
})
152+
.ToArray();
153+
154+
Assert.That(sqlLog.GetWholeLog(), Does.Not.Contains("cast"));
155+
Assert.That(groups, Has.Length.EqualTo(1));
156+
Assert.That(groups[0].s, Is.EqualTo((long) int.MaxValue * 3));
157+
}
158+
}
159+
160+
private int FindAllOccurrences(string source, string substring)
161+
{
162+
if (source == null)
163+
{
164+
return 0;
165+
}
166+
int n = 0, count = 0;
167+
while ((n = source.IndexOf(substring, n, StringComparison.InvariantCulture)) != -1)
168+
{
169+
n += substring.Length;
170+
++count;
171+
}
172+
return count;
173+
}
174+
}
175+
}

src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55
using System.Runtime.CompilerServices;
66
using NHibernate.Engine.Query;
77
using NHibernate.Hql.Ast;
8+
using NHibernate.Hql.Ast.ANTLR;
89
using NHibernate.Linq.Expressions;
910
using NHibernate.Linq.Functions;
1011
using NHibernate.Param;
12+
using NHibernate.Type;
1113
using NHibernate.Util;
1214
using Remotion.Linq.Clauses.Expressions;
1315

@@ -263,6 +265,14 @@ protected HqlTreeNode VisitNhMax(NhMaxExpression expression)
263265

264266
protected HqlTreeNode VisitNhSum(NhSumExpression expression)
265267
{
268+
var type = expression.Type.UnwrapIfNullable();
269+
var nhType = TypeFactory.GetDefaultTypeFor(type);
270+
if (nhType != null && _parameters.SessionFactory.SQLFunctionRegistry.FindSQLFunction("sum")
271+
?.ReturnType(nhType, _parameters.SessionFactory)?.ReturnedClass == type)
272+
{
273+
return _hqlTreeBuilder.Sum(VisitExpression(expression.Expression).AsExpression());
274+
}
275+
266276
return _hqlTreeBuilder.Cast(_hqlTreeBuilder.Sum(VisitExpression(expression.Expression).AsExpression()), expression.Type);
267277
}
268278

@@ -477,8 +487,10 @@ protected HqlTreeNode VisitUnaryExpression(UnaryExpression expression)
477487
case ExpressionType.Convert:
478488
case ExpressionType.ConvertChecked:
479489
case ExpressionType.TypeAs:
480-
if ((expression.Operand.Type.IsPrimitive || expression.Operand.Type == typeof(Decimal)) &&
481-
(expression.Type.IsPrimitive || expression.Type == typeof(Decimal)))
490+
var operandType = expression.Operand.Type.UnwrapIfNullable();
491+
if ((operandType.IsPrimitive || operandType == typeof(decimal)) &&
492+
(expression.Type.IsPrimitive || expression.Type == typeof(decimal)) &&
493+
expression.Type != operandType)
482494
{
483495
return _hqlTreeBuilder.Cast(VisitExpression(expression.Operand).AsExpression(), expression.Type);
484496
}

0 commit comments

Comments
 (0)