Skip to content

Add an option to register a custom pre-transformer for a Linq query #2411

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jun 19, 2020
13 changes: 13 additions & 0 deletions doc/reference/modules/configuration.xml
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,19 @@ var session = sessions.OpenSession(conn);
</para>
</entry>
</row>
<row>
<entry>
<literal>query.pre_transformer_registrar</literal>
</entry>
<entry>
The class name of the LINQ query pre-transformer registrar, implementing
<literal>IExpressionTransformerRegistrar</literal>. Defaults to <literal>null</literal> (no registrar).
<para>
<emphasis role="strong">eg.</emphasis>
<literal>classname.of.ExpressionTransformerRegistrar, assembly</literal>
</para>
</entry>
</row>
<row>
<entry>
<literal>linqtohql.generatorsregistry</literal>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
//------------------------------------------------------------------------------
// <auto-generated>
// This code was generated by AsyncGenerator.
//
// Changes to this file may cause incorrect behavior and will be lost if
// the code is regenerated.
// </auto-generated>
//------------------------------------------------------------------------------


using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using NHibernate.Linq;
using NHibernate.Linq.Visitors;
using NHibernate.Util;
using NUnit.Framework;
using Remotion.Linq.Parsing.ExpressionVisitors.Transformation;

namespace NHibernate.Test.Linq
{
using System.Threading.Tasks;
[TestFixture]
public class CustomPreTransformRegistrarTestsAsync : LinqTestCase
{
protected override void Configure(Cfg.Configuration configuration)
{
configuration.Properties[Cfg.Environment.PreTransformerRegistrar] = typeof(PreTransformerRegistrar).AssemblyQualifiedName;
}

[Test]
public async Task RewriteLikeAsync()
{
// This example shows how to use the pre-transformer registrar to rewrite the
// query so that StartsWith, EndsWith and Contains methods will generate the same sql.
var queryPlanCache = GetQueryPlanCache();
queryPlanCache.Clear();
await (db.Customers.Where(o => o.ContactName.StartsWith("A")).ToListAsync());
await (db.Customers.Where(o => o.ContactName.EndsWith("A")).ToListAsync());
await (db.Customers.Where(o => o.ContactName.Contains("A")).ToListAsync());

Assert.That(queryPlanCache.Count, Is.EqualTo(1));
}

[Serializable]
public class PreTransformerRegistrar : IExpressionTransformerRegistrar
{
public void Register(ExpressionTransformerRegistry expressionTransformerRegistry)
{
expressionTransformerRegistry.Register(new LikeTransformer());
}
}

private class LikeTransformer : IExpressionTransformer<MethodCallExpression>
{
private static readonly MethodInfo Like = ReflectHelper.GetMethodDefinition(() => SqlMethods.Like(null, null));
private static readonly MethodInfo EndsWith = ReflectHelper.GetMethodDefinition<string>(x => x.EndsWith(null));
private static readonly MethodInfo StartsWith = ReflectHelper.GetMethodDefinition<string>(x => x.StartsWith(null));
private static readonly MethodInfo Contains = ReflectHelper.GetMethodDefinition<string>(x => x.Contains(null));
private static readonly Dictionary<MethodInfo, Func<object, string>> ValueTransformers =
new Dictionary<MethodInfo, Func<object, string>>
{
{StartsWith, s => $"{s}%"},
{EndsWith, s => $"%{s}"},
{Contains, s => $"%{s}%"},
};

public Expression Transform(MethodCallExpression expression)
{
if (ValueTransformers.TryGetValue(expression.Method, out var valueTransformer) &&
expression.Arguments[0] is ConstantExpression constantExpression)
{
return Expression.Call(
Like,
expression.Object,
Expression.Constant(valueTransformer(constantExpression.Value))
);
}

return expression;
}

public ExpressionType[] SupportedExpressionTypes { get; } = {ExpressionType.Call};
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using NHibernate.Linq;
using NHibernate.Linq.Visitors;
using NHibernate.Util;
using NUnit.Framework;
using Remotion.Linq.Parsing.ExpressionVisitors.Transformation;

namespace NHibernate.Test.Linq
{
[TestFixture]
public class CustomPreTransformRegistrarTests : LinqTestCase
{
protected override void Configure(Cfg.Configuration configuration)
{
configuration.Properties[Cfg.Environment.PreTransformerRegistrar] = typeof(PreTransformerRegistrar).AssemblyQualifiedName;
}

[Test]
public void RewriteLike()
{
// This example shows how to use the pre-transformer registrar to rewrite the
// query so that StartsWith, EndsWith and Contains methods will generate the same sql.
var queryPlanCache = GetQueryPlanCache();
queryPlanCache.Clear();
db.Customers.Where(o => o.ContactName.StartsWith("A")).ToList();
db.Customers.Where(o => o.ContactName.EndsWith("A")).ToList();
db.Customers.Where(o => o.ContactName.Contains("A")).ToList();

Assert.That(queryPlanCache.Count, Is.EqualTo(1));
}

[Serializable]
public class PreTransformerRegistrar : IExpressionTransformerRegistrar
{
public void Register(ExpressionTransformerRegistry expressionTransformerRegistry)
{
expressionTransformerRegistry.Register(new LikeTransformer());
}
}

private class LikeTransformer : IExpressionTransformer<MethodCallExpression>
{
private static readonly MethodInfo Like = ReflectHelper.GetMethodDefinition(() => SqlMethods.Like(null, null));
private static readonly MethodInfo EndsWith = ReflectHelper.GetMethodDefinition<string>(x => x.EndsWith(null));
private static readonly MethodInfo StartsWith = ReflectHelper.GetMethodDefinition<string>(x => x.StartsWith(null));
private static readonly MethodInfo Contains = ReflectHelper.GetMethodDefinition<string>(x => x.Contains(null));
private static readonly Dictionary<MethodInfo, Func<object, string>> ValueTransformers =
new Dictionary<MethodInfo, Func<object, string>>
{
{StartsWith, s => $"{s}%"},
{EndsWith, s => $"%{s}"},
{Contains, s => $"%{s}%"},
};

public Expression Transform(MethodCallExpression expression)
{
if (ValueTransformers.TryGetValue(expression.Method, out var valueTransformer) &&
expression.Arguments[0] is ConstantExpression constantExpression)
{
return Expression.Call(
Like,
expression.Object,
Expression.Constant(valueTransformer(constantExpression.Value))
);
}

return expression;
}

public ExpressionType[] SupportedExpressionTypes { get; } = {ExpressionType.Call};
}
}
}
21 changes: 15 additions & 6 deletions src/NHibernate.Test/TestCase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,15 @@ public abstract class TestCase
private SchemaExport _schemaExport;

private static readonly ILog log = LogManager.GetLogger(typeof(TestCase));
private static readonly FieldInfo PlanCacheField;

static TestCase()
{
PlanCacheField = typeof(QueryPlanCache)
.GetField("planCache", BindingFlags.NonPublic | BindingFlags.Instance)
?? throw new InvalidOperationException(
"planCache field does not exist in QueryPlanCache.");
}

protected Dialect.Dialect Dialect
{
Expand Down Expand Up @@ -488,14 +497,14 @@ protected void AssumeFunctionSupported(string functionName)
$"{dialect} doesn't support {functionName} standard function.");
}

protected void ClearQueryPlanCache()
protected SoftLimitMRUCache GetQueryPlanCache()
{
var planCacheField = typeof(QueryPlanCache)
.GetField("planCache", BindingFlags.NonPublic | BindingFlags.Instance)
?? throw new InvalidOperationException("planCache field does not exist in QueryPlanCache.");
return (SoftLimitMRUCache) PlanCacheField.GetValue(Sfi.QueryPlanCache);
}

var planCache = (SoftLimitMRUCache) planCacheField.GetValue(Sfi.QueryPlanCache);
planCache.Clear();
protected void ClearQueryPlanCache()
{
GetQueryPlanCache().Clear();
}

protected Substitute<Dialect.Dialect> SubstituteDialect()
Expand Down
6 changes: 6 additions & 0 deletions src/NHibernate/Cfg/Environment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using NHibernate.Cfg.ConfigurationSchema;
using NHibernate.Engine;
using NHibernate.Linq;
using NHibernate.Linq.Visitors;
using NHibernate.Util;

namespace NHibernate.Cfg
Expand Down Expand Up @@ -276,6 +277,11 @@ public static string Version

public const string QueryModelRewriterFactory = "query.query_model_rewriter_factory";

/// <summary>
/// The class name of the LINQ query pre-transformer registrar, implementing <see cref="IExpressionTransformerRegistrar"/>.
/// </summary>
public const string PreTransformerRegistrar = "query.pre_transformer_registrar";

/// <summary>
/// Set the default length used in casting when the target type is length bound and
/// does not specify it. <c>4000</c> by default, automatically trimmed down according to dialect type registration.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,15 @@ public void QueryModelRewriterFactory<TFactory>() where TFactory : IQueryModelRe
configuration.SetProperty(Environment.QueryModelRewriterFactory, typeof(TFactory).AssemblyQualifiedName);
}

/// <summary>
/// Set the class of the LINQ query pre-transformer registrar.
/// </summary>
/// <typeparam name="TRegistrar">The class of the LINQ query pre-transformer registrar.</typeparam>
public void PreTransformerRegistrar<TRegistrar>() where TRegistrar : IExpressionTransformerRegistrar
{
configuration.SetProperty(Environment.PreTransformerRegistrar, typeof(TRegistrar).AssemblyQualifiedName);
}

#endregion
}
}
10 changes: 9 additions & 1 deletion src/NHibernate/Cfg/Settings.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Data;
using System.Linq.Expressions;
using NHibernate.AdoNet;
using NHibernate.AdoNet.Util;
using NHibernate.Cache;
Expand Down Expand Up @@ -182,7 +183,14 @@ public Settings()
public bool LinqToHqlFallbackOnPreEvaluation { get; internal set; }

public IQueryModelRewriterFactory QueryModelRewriterFactory { get; internal set; }


/// <summary>
/// The pre-transformer registrar used to register custom expression transformers.
/// </summary>
public IExpressionTransformerRegistrar PreTransformerRegistrar { get; internal set; }

internal Func<Expression, Expression> LinqPreTransformer { get; set; }

#endregion

internal string GetFullCacheRegionName(string name)
Expand Down
29 changes: 28 additions & 1 deletion src/NHibernate/Cfg/SettingsFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,14 @@ public Settings BuildSettings(IDictionary<string, string> properties)
// Not ported - JdbcBatchVersionedData

settings.QueryModelRewriterFactory = CreateQueryModelRewriterFactory(properties);

settings.PreTransformerRegistrar = CreatePreTransformerRegistrar(properties);

// Avoid dependency on re-linq assembly when PreTransformerRegistrar is null
if (settings.PreTransformerRegistrar != null)
{
settings.LinqPreTransformer = NhRelinqQueryParser.CreatePreTransformer(settings.PreTransformerRegistrar);
}

// NHibernate-specific:
settings.IsolationLevel = isolation;

Expand Down Expand Up @@ -442,5 +449,25 @@ private static IQueryModelRewriterFactory CreateQueryModelRewriterFactory(IDicti
throw new HibernateException("could not instantiate IQueryModelRewriterFactory: " + className, cnfe);
}
}

private static IExpressionTransformerRegistrar CreatePreTransformerRegistrar(IDictionary<string, string> properties)
{
var className = PropertiesHelper.GetString(Environment.PreTransformerRegistrar, properties, null);
if (className == null)
return null;

log.Info("Pre-transformer registrar: {0}", className);

try
{
return
(IExpressionTransformerRegistrar)
Environment.ObjectsFactory.CreateInstance(ReflectHelper.ClassForName(className));
}
catch (Exception e)
{
throw new HibernateException("could not instantiate IExpressionTransformerRegistrar: " + className, e);
}
}
}
}
20 changes: 12 additions & 8 deletions src/NHibernate/Linq/NhRelinqQueryParser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,9 @@ namespace NHibernate.Linq
public static class NhRelinqQueryParser
{
private static readonly QueryParser QueryParser;
private static readonly IExpressionTreeProcessor PreProcessor;

static NhRelinqQueryParser()
{
var preTransformerRegistry = new ExpressionTransformerRegistry();
// NH-3247: must remove .Net compiler char to int conversion before
// parameterization occurs.
preTransformerRegistry.Register(new RemoveCharToIntConversion());
PreProcessor = new TransformingExpressionTreeProcessor(preTransformerRegistry);

var transformerRegistry = ExpressionTransformerRegistry.CreateDefault();
transformerRegistry.Register(new RemoveRedundantCast());
transformerRegistry.Register(new SimplifyCompareTransformer());
Expand Down Expand Up @@ -78,7 +71,7 @@ public static PreTransformationResult PreTransform(Expression expression, PreTra
.EvaluateIndependentSubtrees(expression, parameters);

return new PreTransformationResult(
PreProcessor.Process(partiallyEvaluatedExpression),
parameters.PreTransformer.Invoke(partiallyEvaluatedExpression),
parameters.SessionFactory,
parameters.QueryVariables);
}
Expand All @@ -87,6 +80,17 @@ public static QueryModel Parse(Expression expression)
{
return QueryParser.GetParsedQuery(expression);
}

internal static Func<Expression, Expression> CreatePreTransformer(IExpressionTransformerRegistrar expressionTransformerRegistrar)
{
var preTransformerRegistry = new ExpressionTransformerRegistry();
// NH-3247: must remove .Net compiler char to int conversion before
// parameterization occurs.
preTransformerRegistry.Register(new RemoveCharToIntConversion());
expressionTransformerRegistrar?.Register(preTransformerRegistry);

return new TransformingExpressionTreeProcessor(preTransformerRegistry).Process;
}
}

public class NHibernateNodeTypeProvider : INodeTypeProvider
Expand Down
16 changes: 16 additions & 0 deletions src/NHibernate/Linq/Visitors/IExpressionTransformerRegistrar.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
using Remotion.Linq.Parsing.ExpressionVisitors.Transformation;

namespace NHibernate.Linq.Visitors
{
/// <summary>
/// Provides a way to register custom transformers for expressions.
/// </summary>
public interface IExpressionTransformerRegistrar
{
/// <summary>
/// Registers additional transformers on the expression transformer registry.
/// </summary>
/// <param name="expressionTransformerRegistry">The expression transformer registry.</param>
void Register(ExpressionTransformerRegistry expressionTransformerRegistry);
}
}
Loading