Skip to content

NH-3659 - Strongly Typed Delete in Linq #370

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions src/NHibernate.Test/Linq/DeleteTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
using System.Linq;
using NHibernate.DomainModel.Northwind.Entities;
using NUnit.Framework;
using NHibernate.Linq;

namespace NHibernate.Test.Linq
{
[TestFixture]
public class DeleteTests : LinqTestCase
{
[Test]
public void CanDeleteSimpleExpression()
{
//NH-3659
using (this.session.BeginTransaction())
{
var beforeDeleteCount = this.session.Query<User>().Count(u => u.Id > 0);

var deletedCount = this.session.Delete<User>(u => u.Id > 0);

var afterDeleteCount = this.session.Query<User>().Count(u => u.Id > 0);

Assert.AreEqual(beforeDeleteCount, deletedCount);

Assert.AreEqual(0, afterDeleteCount);
}
}

[Test]
public void CanDeleteComplexExpression()
{
//NH-3659
using (this.session.BeginTransaction())
{
var cities = new string[] { "Paris", "Madrid" };

var beforeDeleteCount = this.session.Query<Customer>().Count(c => c.Orders.Count() == 0 && cities.Contains(c.Address.City));

var deletedCount = this.session.Delete<Customer>(c => c.Orders.Count() == 0 && cities.Contains(c.Address.City));

var afterDeleteCount = this.session.Query<Customer>().Count(c => c.Orders.Count() == 0 && cities.Contains(c.Address.City));

Assert.AreEqual(beforeDeleteCount, deletedCount);

Assert.AreEqual(0, afterDeleteCount);
}
}
}
}
1 change: 1 addition & 0 deletions src/NHibernate.Test/NHibernate.Test.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,7 @@
<Compile Include="Linq\CharComparisonTests.cs" />
<Compile Include="Linq\CustomQueryModelRewriterTests.cs" />
<Compile Include="Linq\DateTimeTests.cs" />
<Compile Include="Linq\DeleteTests.cs" />
<Compile Include="Linq\ExpressionSessionLeakTest.cs" />
<Compile Include="Linq\LoggingTests.cs" />
<Compile Include="Linq\QueryTimeoutTests.cs" />
Expand Down
98 changes: 98 additions & 0 deletions src/NHibernate/Linq/LinqExtensionMethods.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Text;
using System.Text.RegularExpressions;
using NHibernate.Engine;
using NHibernate.Exceptions;
using NHibernate.Hql.Ast.ANTLR;
using NHibernate.Impl;
using NHibernate.SqlCommand;
using NHibernate.Type;
using Remotion.Linq;
using Remotion.Linq.Parsing.ExpressionTreeVisitors;
Expand All @@ -11,6 +19,96 @@ namespace NHibernate.Linq
{
public static class LinqExtensionMethods
{
public static Int32 Delete<T>(this ISession session, Expression<Func<T, Boolean>> condition)
{
//these could be cached as static readonly fields
var instanceBindingFlags = BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Instance;
var staticBindingFlags = BindingFlags.Public | BindingFlags.Static;
var selectMethod = typeof(Queryable).GetMethods(staticBindingFlags).First(x => x.Name == "Select");
var whereMethod = typeof(Queryable).GetMethods(staticBindingFlags).First(x => x.Name == "Where");
var translatorFactory = new ASTQueryTranslatorFactory();
var aliasRegex = new Regex(" from (\\w+) (\\w+) ");
var parameterTokensRegex = new Regex("\\?");

var entityType = typeof(T);
var queryable = session.Query<T>();
var sessionImpl = session.GetSessionImplementation();
var persister = sessionImpl.GetEntityPersister(entityType.FullName, null);
var idName = persister.IdentifierPropertyName;
var idType = persister.IdentifierType.ReturnedClass;
var idProperty = entityType.GetProperty(idName, instanceBindingFlags);
var idMember = idProperty as MemberInfo;

if (idProperty == null)
{
var fieldEntityType = entityType;

//if the property is null, it means the the id is implemented as a field
while ((fieldEntityType != typeof(Object)) && (idMember == null))
{
//try to find the field recursively
idMember = fieldEntityType.GetField(idName, instanceBindingFlags);

fieldEntityType = fieldEntityType.BaseType;
}
}

if (idMember == null)
{
throw new InvalidOperationException(string.Format("Could not find identity property {0} in entity {1}.", idName, entityType.FullName));
}

var delegateType = typeof(Func<,>).MakeGenericType(entityType, idType);
var parm = Expression.Parameter(entityType, "x");
var lambda = Expression.Lambda(delegateType, Expression.MakeMemberAccess(parm, idMember), new ParameterExpression[] { parm });
var where = Expression.Call(null, whereMethod.MakeGenericMethod(entityType), queryable.Expression, condition);
var call = Expression.Call(null, selectMethod.MakeGenericMethod(entityType, idType), where, lambda);

var nhLinqExpression = new NhLinqExpression(call, sessionImpl.Factory);
var translator = translatorFactory.CreateQueryTranslators(nhLinqExpression, null, false, sessionImpl.EnabledFilters, sessionImpl.Factory).Single();
var parameters = nhLinqExpression.ParameterValuesByName.Select(x => x.Value.Item1).ToArray();
//we need to turn positional parameters into named parameters because of SetParameterList
var count = 0;
var replacedSql = parameterTokensRegex.Replace(translator.SQLString, m => ":p" + count++);
var sql = new StringBuilder(replacedSql);
//find from
var fromIndex = sql.ToString().IndexOf(" from ", StringComparison.InvariantCultureIgnoreCase);
//find alias
var alias = aliasRegex.Match(sql.ToString()).Groups[2].Value;

//make a string in the form DELETE alias FROM table alias WHERE condition
sql.Remove(0, fromIndex);
sql.Insert(0, string.Concat("delete ", alias, " "));

using (var childSession = session.GetSession(session.ActiveEntityMode))
{
try
{
var query = childSession.CreateSQLQuery(sql.ToString());

for (var i = 0; i < parameters.Length; ++i)
{
var parameter = parameters[i];

if (!(parameter is IEnumerable) || (parameter is string) || (parameter is byte[]))
{
query.SetParameter(String.Format("p{0}", i), parameter);
}
else
{
query.SetParameterList(String.Format("p{0}", i), parameter as IEnumerable);
}
}

return query.ExecuteUpdate();
}
catch (Exception ex)
{
throw ADOExceptionHelper.Convert(sessionImpl.Factory.SQLExceptionConverter, ex, "Error deleting records.", new SqlString(sql.ToString()), parameters, nhLinqExpression.ParameterValuesByName.ToDictionary(x => x.Key, x => new TypedValue(x.Value.Item2, x.Value.Item1, session.ActiveEntityMode)));
}
}
}

public static IQueryable<T> Query<T>(this ISession session)
{
return new NhQueryable<T>(session.GetSessionImplementation());
Expand Down