Skip to content

Commit 89e9933

Browse files
committed
Rename ISQLFunctionExtended interface to ISQLFunctionExtended
1 parent 0b5826e commit 89e9933

File tree

13 files changed

+131
-87
lines changed

13 files changed

+131
-87
lines changed

src/NHibernate.Test/Async/QueryTest/CountFixture.cs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
using NHibernate.Cfg;
1313
using NHibernate.Dialect.Function;
1414
using NHibernate.DomainModel;
15-
using NHibernate.Engine;
16-
using NHibernate.Type;
1715
using NUnit.Framework;
1816
using Environment=NHibernate.Cfg.Environment;
1917

@@ -57,4 +55,4 @@ public async Task OverriddenAsync()
5755
await (sf.CloseAsync());
5856
}
5957
}
60-
}
58+
}

src/NHibernate.Test/Hql/SimpleFunctionsTest.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,7 @@ public void ClassicSum()
135135
Assert.Throws<QueryException>(() => csf.Render(args, factoryImpl));
136136
}
137137

138-
// Since v5.3
139-
[Test, Obsolete]
138+
[Test]
140139
public void ClassicCount()
141140
{
142141
//ANSI-SQL92 definition

src/NHibernate.Test/QueryTest/CountFixture.cs

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
using NHibernate.Cfg;
33
using NHibernate.Dialect.Function;
44
using NHibernate.DomainModel;
5-
using NHibernate.Engine;
6-
using NHibernate.Type;
75
using NUnit.Framework;
86
using Environment=NHibernate.Cfg.Environment;
97

@@ -46,17 +44,4 @@ public void Overridden()
4644
sf.Close();
4745
}
4846
}
49-
50-
[Serializable]
51-
internal class ClassicCountFunction : ClassicAggregateFunction
52-
{
53-
public ClassicCountFunction() : base("count", true)
54-
{
55-
}
56-
57-
public override IType ReturnType(IType columnType, IMapping mapping)
58-
{
59-
return NHibernateUtil.Int32;
60-
}
61-
}
62-
}
47+
}

src/NHibernate/Dialect/Function/ClassicAggregateFunction.cs

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
using System;
22
using System.Collections;
3-
using System.Text;
3+
using System.Collections.Generic;
4+
using System.Linq;
45
using NHibernate.Engine;
56
using NHibernate.SqlCommand;
67
using NHibernate.Type;
7-
using NHibernate.Util;
88

99
namespace NHibernate.Dialect.Function
1010
{
1111
[Serializable]
12-
public class ClassicAggregateFunction : ISQLFunction, IFunctionGrammar, ISQLAggregateFunction
12+
public class ClassicAggregateFunction : ISQLFunction, IFunctionGrammar, ISQLFunctionExtended
1313
{
1414
private IType returnType = null;
1515
private readonly string name;
@@ -45,6 +45,15 @@ public virtual IType ReturnType(IType columnType, IMapping mapping)
4545
return returnType ?? columnType;
4646
}
4747

48+
/// <inheritdoc />
49+
public virtual IType GetEffectiveReturnType(IEnumerable<IType> argumentTypes, IMapping mapping, bool throwOnError)
50+
{
51+
return ReturnType(argumentTypes.FirstOrDefault(), mapping);
52+
}
53+
54+
/// <inheritdoc />
55+
public string FunctionName => name;
56+
4857
public bool HasArguments
4958
{
5059
get { return true; }
@@ -110,18 +119,5 @@ bool IFunctionGrammar.IsKnownArgument(string token)
110119
}
111120

112121
#endregion
113-
114-
#region ISQLAggregateFunction Members
115-
116-
/// <inheritdoc />
117-
public string FunctionName => name;
118-
119-
/// <inheritdoc />
120-
public virtual IType GetActualReturnType(IType argumentType, IMapping mapping)
121-
{
122-
return ReturnType(argumentType, mapping);
123-
}
124-
125-
#endregion
126122
}
127123
}

src/NHibernate/Dialect/Function/ClassicCountFunction.cs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@ namespace NHibernate.Dialect.Function
77
/// <summary>
88
/// Classic COUNT sqlfunction that return types as it was done in Hibernate 3.1
99
/// </summary>
10-
// Since v5.3
11-
[Obsolete("This class has no more usages in NHibernate and will be removed in a future version.")]
1210
[Serializable]
1311
public class ClassicCountFunction : ClassicAggregateFunction
1412
{
@@ -21,4 +19,4 @@ public override IType ReturnType(IType columnType, IMapping mapping)
2119
return NHibernateUtil.Int32;
2220
}
2321
}
24-
}
22+
}

src/NHibernate/Dialect/Function/ISQLAggregateFunction.cs

Lines changed: 0 additions & 22 deletions
This file was deleted.

src/NHibernate/Dialect/Function/ISQLFunction.cs

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
using System.Collections;
2+
using System.Collections.Generic;
3+
using System.Linq;
24
using NHibernate.Engine;
35
using NHibernate.SqlCommand;
46
using NHibernate.Type;
@@ -41,4 +43,47 @@ public interface ISQLFunction
4143
/// <returns>SQL fragment for the function.</returns>
4244
SqlString Render(IList args, ISessionFactoryImplementor factory);
4345
}
46+
47+
// 6.0 TODO: Remove
48+
internal static class SQLFunctionExtensions
49+
{
50+
/// <summary>
51+
/// Get the type that will be effectively returned by the underlying database.
52+
/// </summary>
53+
/// <param name="sqlFunction">The sql function.</param>
54+
/// <param name="argumentTypes">The types of arguments.</param>
55+
/// <param name="mapping">The mapping for retrieving the argument sql types.</param>
56+
/// <param name="throwOnError">Whether to throw when the number of arguments is invalid or they are not supported.</param>
57+
/// <returns>The type returned by the underlying database or <see langword="null"/> when the number of arguments
58+
/// is invalid or they are not supported.</returns>
59+
/// <exception cref="QueryException">When <paramref name="throwOnError"/> is set to <see langword="true"/> and the
60+
/// number of arguments is invalid or they are not supported.</exception>
61+
public static IType GetEffectiveReturnType(
62+
this ISQLFunction sqlFunction,
63+
IEnumerable<IType> argumentTypes,
64+
IMapping mapping,
65+
bool throwOnError)
66+
{
67+
if (!(sqlFunction is ISQLFunctionExtended extendedSqlFunction))
68+
{
69+
try
70+
{
71+
#pragma warning disable 618
72+
return sqlFunction.ReturnType(argumentTypes.FirstOrDefault(), mapping);
73+
#pragma warning restore 618
74+
}
75+
catch (QueryException)
76+
{
77+
if (throwOnError)
78+
{
79+
throw;
80+
}
81+
82+
return null;
83+
}
84+
}
85+
86+
return extendedSqlFunction.GetEffectiveReturnType(argumentTypes, mapping, throwOnError);
87+
}
88+
}
4489
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
using System.Collections.Generic;
2+
using NHibernate.Engine;
3+
using NHibernate.Type;
4+
5+
namespace NHibernate.Dialect.Function
6+
{
7+
// 6.0 TODO: Merge into ISQLFunction
8+
internal interface ISQLFunctionExtended : ISQLFunction
9+
{
10+
/// <summary>
11+
/// The function name or <see langword="null"/> when multiple functions/operators/statements are used.
12+
/// </summary>
13+
string FunctionName { get; }
14+
15+
/// <summary>
16+
/// Get the type that will be effectively returned by the underlying database.
17+
/// </summary>
18+
/// <param name="argumentTypes">The types of arguments.</param>
19+
/// <param name="mapping">The mapping for retrieving the argument sql types.</param>
20+
/// <param name="throwOnError">Whether to throw when the number of arguments is invalid or they are not supported.</param>
21+
/// <returns>The type returned by the underlying database or <see langword="null"/> when the number of arguments
22+
/// is invalid or they are not supported.</returns>
23+
/// <exception cref="QueryException">When <paramref name="throwOnError"/> is set to <see langword="true"/> and the
24+
/// number of arguments is invalid or they are not supported.</exception>
25+
IType GetEffectiveReturnType(IEnumerable<IType> argumentTypes, IMapping mapping, bool throwOnError);
26+
}
27+
}

src/NHibernate/Dialect/MsSql2000Dialect.cs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -701,7 +701,7 @@ protected class CountBigQueryFunction : ClassicAggregateFunction
701701
{
702702
public CountBigQueryFunction() : base("count_big", true) { }
703703

704-
public override IType ReturnType(IType columnType, IMapping mapping)
704+
public override IType GetEffectiveReturnType(IEnumerable<IType> argumentTypes, IMapping mapping, bool throwOnError)
705705
{
706706
return NHibernateUtil.Int64;
707707
}
@@ -710,8 +710,7 @@ public override IType ReturnType(IType columnType, IMapping mapping)
710710
[Serializable]
711711
private class CountQueryFunction : CountQueryFunctionInfo
712712
{
713-
/// <inheritdoc />
714-
public override IType GetActualReturnType(IType columnType, IMapping mapping)
713+
public override IType GetEffectiveReturnType(IEnumerable<IType> argumentTypes, IMapping mapping, bool throwOnError)
715714
{
716715
return NHibernateUtil.Int32;
717716
}

src/NHibernate/Hql/Ast/ANTLR/SqlGenerator.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -313,8 +313,7 @@ private void EndFunctionTemplate(IASTNode m)
313313
private void OutAggregateFunctionName(IASTNode m)
314314
{
315315
var aggregateNode = (AggregateNode) m;
316-
var template = aggregateNode.SqlFunction;
317-
Out(template == null ? aggregateNode.Text : template.FunctionName);
316+
Out(aggregateNode.SqlFunction?.FunctionName ?? aggregateNode.Text);
318317
}
319318

320319
private void CommaBetweenParameters(String comma)

src/NHibernate/Hql/Ast/ANTLR/Tree/AggregateNode.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ public AggregateNode(IToken token)
2020
{
2121
}
2222

23-
internal ISQLAggregateFunction SqlFunction => SessionFactoryHelper.FindSQLFunction(Text) as ISQLAggregateFunction;
23+
internal ISQLFunctionExtended SqlFunction => SessionFactoryHelper.FindSQLFunction(Text) as ISQLFunctionExtended;
2424

2525
public override IType DataType
2626
{

src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System.Linq;
55
using System.Linq.Expressions;
66
using System.Runtime.CompilerServices;
7+
using NHibernate.Dialect.Function;
78
using NHibernate.Engine.Query;
89
using NHibernate.Hql.Ast;
910
using NHibernate.Hql.Ast.ANTLR;
@@ -255,16 +256,22 @@ protected HqlTreeNode VisitNhAverage(NhAverageExpression expression)
255256

256257
protected HqlTreeNode VisitNhCount(NhCountExpression expression)
257258
{
259+
string functionName;
260+
HqlExpression countHqlExpression;
258261
if (expression is NhLongCountExpression)
259262
{
260-
return IsCastRequired(expression.Type, "count_big", out _)
261-
? (HqlTreeNode) _hqlTreeBuilder.Cast(_hqlTreeBuilder.CountBig(VisitExpression(expression.Expression).AsExpression()), expression.Type)
262-
: _hqlTreeBuilder.TransparentCast(_hqlTreeBuilder.CountBig(VisitExpression(expression.Expression).AsExpression()), expression.Type);
263+
functionName = "count_big";
264+
countHqlExpression = _hqlTreeBuilder.CountBig(VisitExpression(expression.Expression).AsExpression());
265+
}
266+
else
267+
{
268+
functionName = "count";
269+
countHqlExpression = _hqlTreeBuilder.Count(VisitExpression(expression.Expression).AsExpression());
263270
}
264271

265-
return IsCastRequired(expression.Type, "count", out _)
266-
? (HqlTreeNode) _hqlTreeBuilder.Cast(_hqlTreeBuilder.Count(VisitExpression(expression.Expression).AsExpression()), expression.Type)
267-
: _hqlTreeBuilder.TransparentCast(_hqlTreeBuilder.Count(VisitExpression(expression.Expression).AsExpression()), expression.Type);
272+
return IsCastRequired(functionName, expression.Expression, expression.Type)
273+
? (HqlTreeNode) _hqlTreeBuilder.Cast(countHqlExpression, expression.Type)
274+
: _hqlTreeBuilder.TransparentCast(countHqlExpression, expression.Type);
268275
}
269276

270277
protected HqlTreeNode VisitNhMin(NhMinExpression expression)
@@ -606,7 +613,7 @@ private bool IsCastRequired(Expression expression, System.Type toType, out bool
606613
{
607614
existType = false;
608615
return toType != typeof(object) &&
609-
IsCastRequired(GetType(expression), TypeFactory.GetDefaultTypeFor(toType), out existType);
616+
IsCastRequired(ExpressionsHelper.GetType(_parameters, expression), TypeFactory.GetDefaultTypeFor(toType), out existType);
610617
}
611618

612619
private bool IsCastRequired(IType type, IType toType, out bool existType)
@@ -650,7 +657,7 @@ private bool IsCastRequired(IType type, IType toType, out bool existType)
650657

651658
private bool IsCastRequired(string sqlFunctionName, Expression argumentExpression, System.Type returnType)
652659
{
653-
var argumentType = GetType(argumentExpression);
660+
var argumentType = ExpressionsHelper.GetType(_parameters, argumentExpression);
654661
if (argumentType == null || returnType == typeof(object))
655662
{
656663
return false;
@@ -668,18 +675,8 @@ private bool IsCastRequired(string sqlFunctionName, Expression argumentExpressio
668675
return true; // Fallback to the old behavior
669676
}
670677

671-
var fnReturnType = sqlFunction.ReturnType(argumentType, _parameters.SessionFactory);
678+
var fnReturnType = sqlFunction.GetEffectiveReturnType(new[] {argumentType}, _parameters.SessionFactory, false);
672679
return fnReturnType == null || IsCastRequired(fnReturnType, returnNhType, out _);
673680
}
674-
675-
private IType GetType(Expression expression)
676-
{
677-
// Try to get the mapped type for the member as it may be a non default one
678-
return expression.Type == typeof(object)
679-
? null
680-
: (ExpressionsHelper.TryGetMappedType(_parameters.SessionFactory, expression, out var type, out _, out _, out _)
681-
? type
682-
: TypeFactory.GetDefaultTypeFor(expression.Type));
683-
}
684681
}
685682
}

src/NHibernate/Util/ExpressionsHelper.cs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,29 @@ public static MemberInfo DecodeMemberAccessExpression<TEntity, TResult>(Expressi
2828
return ((MemberExpression)expression.Body).Member;
2929
}
3030

31+
/// <summary>
32+
/// Get the mapped type for the given expression.
33+
/// </summary>
34+
/// <param name="parameters">The query parameters.</param>
35+
/// <param name="expression">The expression.</param>
36+
/// <returns>The mapped type of the expression or <see langword="null"/> when the mapped type was not
37+
/// found and the <paramref name="expression"/> type is <see cref="object"/>.</returns>
38+
internal static IType GetType(VisitorParameters parameters, Expression expression)
39+
{
40+
if (expression is ConstantExpression constantExpression &&
41+
parameters.ConstantToParameterMap.TryGetValue(constantExpression, out var param))
42+
{
43+
return param.Type;
44+
}
45+
46+
if (TryGetMappedType(parameters.SessionFactory, expression, out var type, out _, out _, out _))
47+
{
48+
return type;
49+
}
50+
51+
return expression.Type == typeof(object) ? null : TypeFactory.HeuristicType(expression.Type);
52+
}
53+
3154
/// <summary>
3255
/// Try to get the mapped nullability from the given expression.
3356
/// </summary>

0 commit comments

Comments
 (0)