Skip to content

Commit d476c4c

Browse files
committed
Merge branch 'pvginkel-NH-3499'
2 parents d27d8ab + c6f0788 commit d476c4c

10 files changed

+157
-0
lines changed
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Linq.Expressions;
5+
using System.Text;
6+
using NHibernate.Linq.Visitors;
7+
using NUnit.Framework;
8+
using Remotion.Linq;
9+
using Remotion.Linq.Clauses;
10+
using Remotion.Linq.Parsing;
11+
12+
namespace NHibernate.Test.Linq
13+
{
14+
public class CustomQueryModelRewriterTests : LinqTestCase
15+
{
16+
protected override void Configure(Cfg.Configuration configuration)
17+
{
18+
configuration.Properties[Cfg.Environment.QueryModelRewriterFactory] = typeof(QueryModelRewriterFactory).AssemblyQualifiedName;
19+
}
20+
21+
[Test]
22+
public void RewriteNullComparison()
23+
{
24+
// This example shows how to use the query model rewriter to
25+
// make radical changes to the query. In this case, we rewrite
26+
// a null comparison (which would translate into a IS NULL)
27+
// into a comparison to "Thomas Hardy" (which translates to a = "Thomas Hardy").
28+
29+
var contacts = (from c in db.Customers where c.ContactName == null select c).ToList();
30+
Assert.Greater(contacts.Count, 0);
31+
Assert.IsTrue(contacts.Select(customer => customer.ContactName).All(c => c == "Thomas Hardy"));
32+
}
33+
34+
[Serializable]
35+
public class QueryModelRewriterFactory : IQueryModelRewriterFactory
36+
{
37+
public QueryModelVisitorBase CreateVisitor(VisitorParameters parameters)
38+
{
39+
return new CustomVisitor();
40+
}
41+
}
42+
43+
public class CustomVisitor : QueryModelVisitorBase
44+
{
45+
public override void VisitWhereClause(WhereClause whereClause, QueryModel queryModel, int index)
46+
{
47+
whereClause.TransformExpressions(new Visitor().VisitExpression);
48+
}
49+
50+
private class Visitor : ExpressionTreeVisitor
51+
{
52+
protected override Expression VisitBinaryExpression(BinaryExpression expression)
53+
{
54+
if (
55+
expression.NodeType == ExpressionType.Equal ||
56+
expression.NodeType == ExpressionType.NotEqual
57+
)
58+
{
59+
var left = expression.Left;
60+
var right = expression.Right;
61+
bool reverse = false;
62+
63+
if (!(left is ConstantExpression) && right is ConstantExpression)
64+
{
65+
var tmp = left;
66+
left = right;
67+
right = tmp;
68+
reverse = true;
69+
}
70+
71+
var constant = left as ConstantExpression;
72+
73+
if (constant != null && constant.Value == null)
74+
{
75+
left = Expression.Constant("Thomas Hardy");
76+
77+
expression = Expression.MakeBinary(
78+
expression.NodeType,
79+
reverse ? right : left,
80+
reverse ? left : right
81+
);
82+
}
83+
}
84+
85+
return base.VisitBinaryExpression(expression);
86+
}
87+
}
88+
}
89+
}
90+
}

src/NHibernate.Test/NHibernate.Test.csproj

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@
8080
<SpecificVersion>False</SpecificVersion>
8181
<HintPath>..\..\lib\net\nunit.framework.dll</HintPath>
8282
</Reference>
83+
<Reference Include="Remotion.Linq, Version=1.13.171.1, Culture=neutral, PublicKeyToken=fee00910d6e5f53b, processorArchitecture=MSIL">
84+
<SpecificVersion>False</SpecificVersion>
85+
<HintPath>..\..\lib\net\Remotion.Linq.dll</HintPath>
86+
</Reference>
8387
<Reference Include="System" />
8488
<Reference Include="System.configuration" />
8589
<Reference Include="System.Core">
@@ -534,6 +538,7 @@
534538
<Compile Include="Linq\ByMethod\GetValueOrDefaultTests.cs" />
535539
<Compile Include="Linq\CasingTest.cs" />
536540
<Compile Include="Linq\CharComparisonTests.cs" />
541+
<Compile Include="Linq\CustomQueryModelRewriterTests.cs" />
537542
<Compile Include="Linq\DateTimeTests.cs" />
538543
<Compile Include="Linq\ExpressionSessionLeakTest.cs" />
539544
<Compile Include="Linq\LoggingTests.cs" />

src/NHibernate/Cfg/Environment.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,8 @@ public static string Version
171171
/// <summary> Enable ordering of insert statements for the purpose of more effecient batching.</summary>
172172
public const string OrderInserts = "order_inserts";
173173

174+
public const string QueryModelRewriterFactory = "query.query_model_rewriter_factory";
175+
174176
/// <summary>
175177
/// If this setting is set to false, exceptions in IInterceptor.BeforeTransactionCompletion bubble to the caller of ITransaction.Commit and abort the commit.
176178
/// If this setting is set to true, exceptions in IInterceptor.BeforeTransactionCompletion are ignored and the commit is performed.

src/NHibernate/Cfg/Loquacious/DbIntegrationConfigurationProperties.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using NHibernate.Connection;
44
using NHibernate.Driver;
55
using NHibernate.Exceptions;
6+
using NHibernate.Linq.Visitors;
67
using NHibernate.Transaction;
78

89
namespace NHibernate.Cfg.Loquacious
@@ -123,6 +124,11 @@ public SchemaAutoAction SchemaAction
123124
set { configuration.SetProperty(Environment.Hbm2ddlAuto, value.ToString()); }
124125
}
125126

127+
public void QueryModelRewriterFactory<TFactory>() where TFactory : IQueryModelRewriterFactory
128+
{
129+
configuration.SetProperty(Environment.QueryModelRewriterFactory, typeof(TFactory).AssemblyQualifiedName);
130+
}
131+
126132
#endregion
127133
}
128134
}

src/NHibernate/Cfg/Loquacious/IDbIntegrationConfigurationProperties.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using NHibernate.Connection;
44
using NHibernate.Driver;
55
using NHibernate.Exceptions;
6+
using NHibernate.Linq.Visitors;
67
using NHibernate.Transaction;
78

89
namespace NHibernate.Cfg.Loquacious
@@ -35,5 +36,7 @@ public interface IDbIntegrationConfigurationProperties
3536
byte MaximumDepthOfOuterJoinFetching { set; }
3637

3738
SchemaAutoAction SchemaAction { set; }
39+
40+
void QueryModelRewriterFactory<TFactory>() where TFactory : IQueryModelRewriterFactory;
3841
}
3942
}

src/NHibernate/Cfg/Settings.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
using NHibernate.Exceptions;
99
using NHibernate.Hql;
1010
using NHibernate.Linq.Functions;
11+
using NHibernate.Linq.Visitors;
1112
using NHibernate.Transaction;
1213

1314
namespace NHibernate.Cfg
@@ -129,6 +130,8 @@ public Settings()
129130
[Obsolete("This setting is likely to be removed in a future version of NHibernate. The workaround is to catch all exceptions in the IInterceptor implementation.")]
130131
public bool IsInterceptorsBeforeTransactionCompletionIgnoreExceptionsEnabled { get; internal set; }
131132

133+
public IQueryModelRewriterFactory QueryModelRewriterFactory { get; internal set; }
134+
132135
#endregion
133136
}
134137
}

src/NHibernate/Cfg/SettingsFactory.cs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
using NHibernate.Exceptions;
1212
using NHibernate.Hql;
1313
using NHibernate.Linq.Functions;
14+
using NHibernate.Linq.Visitors;
1415
using NHibernate.Transaction;
1516
using NHibernate.Util;
1617

@@ -288,6 +289,8 @@ public Settings BuildSettings(IDictionary<string, string> properties)
288289
settings.IsMinimalPutsEnabled = useMinimalPuts;
289290
// Not ported - JdbcBatchVersionedData
290291

292+
settings.QueryModelRewriterFactory = CreateQueryModelRewriterFactory(properties);
293+
291294
// NHibernate-specific:
292295
settings.IsolationLevel = isolation;
293296

@@ -379,5 +382,26 @@ private static ITransactionFactory CreateTransactionFactory(IDictionary<string,
379382
throw new HibernateException("could not instantiate TransactionFactory: " + className, cnfe);
380383
}
381384
}
385+
386+
private static IQueryModelRewriterFactory CreateQueryModelRewriterFactory(IDictionary<string, string> properties)
387+
{
388+
string className = PropertiesHelper.GetString(Environment.QueryModelRewriterFactory, properties, null);
389+
390+
if (className == null)
391+
return null;
392+
393+
log.Info("Query model rewriter factory factory: " + className);
394+
395+
try
396+
{
397+
return
398+
(IQueryModelRewriterFactory)
399+
Environment.BytecodeProvider.ObjectsFactory.CreateInstance(ReflectHelper.ClassForName(className));
400+
}
401+
catch (Exception cnfe)
402+
{
403+
throw new HibernateException("could not instantiate IQueryModelRewriterFactory: " + className, cnfe);
404+
}
405+
}
382406
}
383407
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
using Remotion.Linq;
6+
7+
namespace NHibernate.Linq.Visitors
8+
{
9+
public interface IQueryModelRewriterFactory
10+
{
11+
QueryModelVisitorBase CreateVisitor(VisitorParameters parameters);
12+
}
13+
}

src/NHibernate/Linq/Visitors/QueryModelVisitor.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,16 @@ public static ExpressionToHqlTranslationResults GenerateHqlQuery(QueryModel quer
5757
// Move OrderBy clauses to end
5858
MoveOrderByToEndRewriter.ReWrite(queryModel);
5959

60+
// Give a rewriter provided by the session factory a chance to
61+
// rewrite the query.
62+
var rewriterFactory = parameters.SessionFactory.Settings.QueryModelRewriterFactory;
63+
if (rewriterFactory != null)
64+
{
65+
var customVisitor = rewriterFactory.CreateVisitor(parameters);
66+
if (customVisitor != null)
67+
customVisitor.VisitQueryModel(queryModel);
68+
}
69+
6070
// rewrite any operators that should be applied on the outer query
6171
// by flattening out the sub-queries that they are located in
6272
var result = ResultOperatorRewriter.Rewrite(queryModel);

src/NHibernate/NHibernate.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@
296296
<Compile Include="Linq\NestedSelects\Tuple.cs" />
297297
<Compile Include="Linq\NestedSelects\SelectClauseRewriter.cs" />
298298
<Compile Include="Linq\NestedSelects\ExpressionHolder.cs" />
299+
<Compile Include="Linq\Visitors\IQueryModelRewriterFactory.cs" />
299300
<Compile Include="Linq\Visitors\LeftJoinRewriter.cs" />
300301
<Compile Include="Linq\Functions\CompareGenerator.cs" />
301302
<Compile Include="Linq\ExpressionTransformers\SimplifyCompareTransformer.cs" />

0 commit comments

Comments
 (0)