Skip to content

Commit b7a58da

Browse files
committed
Added escape character support in SqlMethods.Like
1 parent a6c78c3 commit b7a58da

File tree

5 files changed

+120
-44
lines changed

5 files changed

+120
-44
lines changed

src/NHibernate.Test/Linq/FunctionTests.cs

Lines changed: 60 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,38 @@ where NHibernate.Linq.SqlMethods.Like(e.FirstName, "Ma%et")
2323
Assert.That(query[0].FirstName, Is.EqualTo("Margaret"));
2424
}
2525

26+
[Test]
27+
public void LikeFunctionWithEscapeCharacter()
28+
{
29+
using (var tx = session.BeginTransaction())
30+
{
31+
var employeeName = "Mar%aret";
32+
var escapeChar = '\\';
33+
var employeeNameEscaped = employeeName.Replace("%", escapeChar + "%");
34+
35+
//This entity will be flushed to the db, but rolled back when the test completes
36+
37+
session.Save(new Employee { FirstName = employeeName, LastName = "" });
38+
session.Flush();
39+
40+
41+
var query = (from e in db.Employees
42+
where NHibernate.Linq.SqlMethods.Like(e.FirstName, employeeNameEscaped, escapeChar)
43+
select e).ToList();
44+
45+
Assert.That(query.Count, Is.EqualTo(1));
46+
Assert.That(query[0].FirstName, Is.EqualTo(employeeName));
47+
48+
Assert.Throws<ArgumentException>(() =>
49+
{
50+
(from e in db.Employees
51+
where NHibernate.Linq.SqlMethods.Like(e.FirstName, employeeNameEscaped, e.FirstName.First())
52+
select e).ToList();
53+
});
54+
tx.Rollback();
55+
}
56+
}
57+
2658
private static class SqlMethods
2759
{
2860
public static bool Like(string expression, string pattern)
@@ -48,8 +80,8 @@ where NHibernate.Test.Linq.FunctionTests.SqlMethods.Like(e.FirstName, "Ma%et")
4880
public void SubstringFunction2()
4981
{
5082
var query = (from e in db.Employees
51-
where e.FirstName.Substring(0, 2) == "An"
52-
select e).ToList();
83+
where e.FirstName.Substring(0, 2) == "An"
84+
select e).ToList();
5385

5486
Assert.That(query.Count, Is.EqualTo(2));
5587
}
@@ -58,8 +90,8 @@ where e.FirstName.Substring(0, 2) == "An"
5890
public void SubstringFunction1()
5991
{
6092
var query = (from e in db.Employees
61-
where e.FirstName.Substring(3) == "rew"
62-
select e).ToList();
93+
where e.FirstName.Substring(3) == "rew"
94+
select e).ToList();
6395

6496
Assert.That(query.Count, Is.EqualTo(1));
6597
Assert.That(query[0].FirstName, Is.EqualTo("Andrew"));
@@ -83,12 +115,12 @@ public void ReplaceFunction()
83115
var query = from e in db.Employees
84116
where e.FirstName.StartsWith("An")
85117
select new
86-
{
87-
Before = e.FirstName,
88-
AfterMethod = e.FirstName.Replace("An", "Zan"),
89-
AfterExtension = ExtensionMethods.Replace(e.FirstName, "An", "Zan"),
90-
AfterExtension2 = e.FirstName.ReplaceExtension("An", "Zan")
91-
};
118+
{
119+
Before = e.FirstName,
120+
AfterMethod = e.FirstName.Replace("An", "Zan"),
121+
AfterExtension = ExtensionMethods.Replace(e.FirstName, "An", "Zan"),
122+
AfterExtension2 = e.FirstName.ReplaceExtension("An", "Zan")
123+
};
92124

93125
var s = ObjectDumper.Write(query);
94126
}
@@ -124,7 +156,7 @@ public void IndexOfFunctionProjection()
124156
{
125157
if (!TestDialect.SupportsLocate)
126158
Assert.Ignore("Locate function not supported.");
127-
159+
128160
var query = from e in db.Employees
129161
where e.FirstName.Contains("a")
130162
select e.FirstName.IndexOf('A', 3);
@@ -139,7 +171,7 @@ public void TwoFunctionExpression()
139171
Assert.Ignore("Locate function not supported.");
140172

141173
var query = from e in db.Employees
142-
where e.FirstName.IndexOf("A") == e.BirthDate.Value.Month
174+
where e.FirstName.IndexOf("A") == e.BirthDate.Value.Month
143175
select e.FirstName;
144176

145177
ObjectDumper.Write(query);
@@ -176,9 +208,9 @@ public void Trim()
176208
{
177209
using (session.BeginTransaction())
178210
{
179-
AnotherEntity ae1 = new AnotherEntity {Input = " hi "};
180-
AnotherEntity ae2 = new AnotherEntity {Input = "hi"};
181-
AnotherEntity ae3 = new AnotherEntity {Input = "heh"};
211+
AnotherEntity ae1 = new AnotherEntity { Input = " hi " };
212+
AnotherEntity ae2 = new AnotherEntity { Input = "hi" };
213+
AnotherEntity ae3 = new AnotherEntity { Input = "heh" };
182214
session.Save(ae1);
183215
session.Save(ae2);
184216
session.Save(ae3);
@@ -201,9 +233,9 @@ public void TrimTrailingWhitespace()
201233
{
202234
using (session.BeginTransaction())
203235
{
204-
session.Save(new AnotherEntity {Input = " hi "});
205-
session.Save(new AnotherEntity {Input = "hi"});
206-
session.Save(new AnotherEntity {Input = "heh"});
236+
session.Save(new AnotherEntity { Input = " hi " });
237+
session.Save(new AnotherEntity { Input = "hi" });
238+
session.Save(new AnotherEntity { Input = "heh" });
207239
session.Flush();
208240

209241
Assert.AreEqual(TestDialect.IgnoresTrailingWhitespace ? 2 : 1, session.Query<AnotherEntity>().Where(e => e.Input.TrimStart() == "hi ").Count());
@@ -256,7 +288,7 @@ public void WhereBoolConstantEqual()
256288
var query = from item in db.Role
257289
where item.IsActive.Equals(true)
258290
select item;
259-
291+
260292
ObjectDumper.Write(query);
261293
}
262294

@@ -266,7 +298,7 @@ public void WhereBoolParameterEqual()
266298
var query = from item in db.Role
267299
where item.IsActive.Equals(1 == 1)
268300
select item;
269-
301+
270302
ObjectDumper.Write(query);
271303
}
272304

@@ -286,8 +318,8 @@ where item.IsActive.Equals(f())
286318
public void WhereLongEqual()
287319
{
288320
var query = from item in db.PatientRecords
289-
where item.Id.Equals(-1)
290-
select item;
321+
where item.Id.Equals(-1)
322+
select item;
291323

292324
ObjectDumper.Write(query);
293325
}
@@ -301,7 +333,7 @@ where item.RegisteredAt.Equals(DateTime.Today)
301333

302334
ObjectDumper.Write(query);
303335
}
304-
336+
305337
[Test]
306338
public void WhereGuidEqual()
307339
{
@@ -310,7 +342,7 @@ where item.Reference.Equals(Guid.Empty)
310342
select item;
311343

312344
ObjectDumper.Write(query);
313-
}
345+
}
314346

315347
[Test]
316348
public void WhereDoubleEqual()
@@ -320,8 +352,8 @@ where item.BodyWeight.Equals(-1)
320352
select item;
321353

322354
ObjectDumper.Write(query);
323-
}
324-
355+
}
356+
325357
[Test]
326358
public void WhereFloatEqual()
327359
{
@@ -330,7 +362,7 @@ where item.Float.Equals(-1)
330362
select item;
331363

332364
ObjectDumper.Write(query);
333-
}
365+
}
334366

335367
[Test]
336368
public void WhereCharEqual()
@@ -362,4 +394,4 @@ where item.Discount.Equals(-1)
362394
ObjectDumper.Write(query);
363395
}
364396
}
365-
}
397+
}

src/NHibernate/Hql/Ast/HqlTreeBuilder.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,11 @@ public HqlLike Like(HqlExpression lhs, HqlExpression rhs)
356356
return new HqlLike(_factory, lhs, rhs);
357357
}
358358

359+
public HqlLike Like(HqlExpression lhs, HqlExpression rhs, HqlConstant escapeCharacter)
360+
{
361+
return new HqlLike(_factory, lhs, rhs, escapeCharacter);
362+
}
363+
359364
public HqlConcat Concat(params HqlExpression[] args)
360365
{
361366
return new HqlConcat(_factory, args);
@@ -466,4 +471,4 @@ public HqlTreeNode Indices(HqlExpression dictionary)
466471
return new HqlIndices(_factory, dictionary);
467472
}
468473
}
469-
}
474+
}

src/NHibernate/Hql/Ast/HqlTreeNode.cs

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ protected HqlTreeNode(int type, string text, IASTFactory factory, IEnumerable<Hq
2222
AddChildren(children);
2323
}
2424

25-
protected HqlTreeNode(int type, string text, IASTFactory factory, params HqlTreeNode[] children) : this(type, text, factory, (IEnumerable<HqlTreeNode>) children)
25+
protected HqlTreeNode(int type, string text, IASTFactory factory, params HqlTreeNode[] children) : this(type, text, factory, (IEnumerable<HqlTreeNode>)children)
2626
{
2727
}
2828

@@ -92,7 +92,7 @@ internal IASTNode AstNode
9292

9393
internal void AddChild(HqlTreeNode child)
9494
{
95-
if (child is HqlExpressionSubTreeHolder)
95+
if (child is HqlExpressionSubTreeHolder)
9696
{
9797
AddChildren(child.Children);
9898
}
@@ -116,7 +116,7 @@ public static HqlBooleanExpression AsBooleanExpression(this HqlTreeNode node)
116116
{
117117
if (node is HqlDot)
118118
{
119-
return new HqlBooleanDot(node.Factory, (HqlDot) node);
119+
return new HqlBooleanDot(node.Factory, (HqlDot)node);
120120
}
121121

122122
// TODO - nice error handling if cast fails
@@ -220,8 +220,8 @@ internal HqlIdent(IASTFactory factory, System.Type type)
220220
}
221221
if (type == typeof(DateTimeOffset))
222222
{
223-
SetText("datetimeoffset");
224-
break;
223+
SetText("datetimeoffset");
224+
break;
225225
}
226226
throw new NotSupportedException(string.Format("Don't currently support idents of type {0}", type.Name));
227227
}
@@ -373,7 +373,7 @@ public HqlSkip(IASTFactory factory, HqlExpression parameter)
373373
public class HqlTake : HqlStatement
374374
{
375375
public HqlTake(IASTFactory factory, HqlExpression parameter)
376-
: base(HqlSqlWalker.TAKE, "take", factory, parameter) {}
376+
: base(HqlSqlWalker.TAKE, "take", factory, parameter) { }
377377
}
378378

379379
public class HqlConstant : HqlExpression
@@ -690,7 +690,7 @@ public HqlMax(IASTFactory factory, HqlExpression expression)
690690
: base(HqlSqlWalker.AGGREGATE, "max", factory, expression)
691691
{
692692
}
693-
}
693+
}
694694

695695
public class HqlMin : HqlExpression
696696
{
@@ -818,6 +818,19 @@ public HqlLike(IASTFactory factory, HqlExpression lhs, HqlExpression rhs)
818818
: base(HqlSqlWalker.LIKE, "like", factory, lhs, rhs)
819819
{
820820
}
821+
822+
public HqlLike(IASTFactory factory, HqlExpression lhs, HqlExpression rhs, HqlConstant escapeCharacter)
823+
: base(HqlSqlWalker.LIKE, "like", factory, lhs, rhs, new HqlEscape(factory, escapeCharacter))
824+
{
825+
}
826+
}
827+
828+
public class HqlEscape : HqlStatement
829+
{
830+
public HqlEscape(IASTFactory factory, HqlConstant escapeCharacter)
831+
: base(HqlSqlWalker.ESCAPE, "escape", factory, escapeCharacter)
832+
{
833+
}
821834
}
822835

823836
public class HqlConcat : HqlExpression
@@ -899,4 +912,4 @@ public HqlInList(IASTFactory factory, HqlTreeNode source)
899912
{
900913
}
901914
}
902-
}
915+
}

src/NHibernate/Linq/Functions/StringGenerator.cs

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,22 @@ public IEnumerable<MethodInfo> SupportedMethods
1818
public HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection<Expression> arguments,
1919
HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor)
2020
{
21-
return treeBuilder.Like(
22-
visitor.Visit(arguments[0]).AsExpression(),
23-
visitor.Visit(arguments[1]).AsExpression());
21+
if (arguments.Count == 2)
22+
{
23+
return treeBuilder.Like(
24+
visitor.Visit(arguments[0]).AsExpression(),
25+
visitor.Visit(arguments[1]).AsExpression());
26+
}
27+
if (arguments[2].NodeType == ExpressionType.Constant)
28+
{
29+
var escapeCharExpression = (ConstantExpression)arguments[2];
30+
return treeBuilder.Like(
31+
visitor.Visit(arguments[0]).AsExpression(),
32+
visitor.Visit(arguments[1]).AsExpression(),
33+
treeBuilder.Constant(escapeCharExpression.Value));
34+
}
35+
throw new ArgumentException("The escape character must be specified as literal value or a string variable");
36+
2437
}
2538

2639
public bool SupportsMethod(MethodInfo method)
@@ -34,8 +47,8 @@ public bool SupportsMethod(MethodInfo method)
3447
// to avoid referencing Linq2Sql or Linq2NHibernate, if they so wish.
3548

3649
return method != null && method.Name == "Like" &&
37-
method.GetParameters().Length == 2 &&
38-
method.DeclaringType != null &&
50+
(method.GetParameters().Length == 2 || method.GetParameters().Length == 3) &&
51+
method.DeclaringType != null &&
3952
method.DeclaringType.FullName.EndsWith("SqlMethods");
4053
}
4154

@@ -284,4 +297,4 @@ public HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnly
284297
return treeBuilder.MethodCall("str", visitor.Visit(targetObject).AsExpression());
285298
}
286299
}
287-
}
300+
}

src/NHibernate/Linq/SqlMethods.cs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ namespace NHibernate.Linq
55
public static class SqlMethods
66
{
77
/// <summary>
8-
/// Use the SqlMethods.Like() method in a Linq2NHibernate expression to generate
8+
/// Use this method in a Linq2NHibernate expression to generate
99
/// an SQL LIKE expression. (If you want to avoid depending on the NHibernate.Linq namespace,
1010
/// you can define your own replica of this method. Any 2-argument method named Like in a class named SqlMethods
1111
/// will be translated.) This method can only be used in Linq2NHibernate expression, and will throw
@@ -16,5 +16,18 @@ public static bool Like(this string matchExpression, string sqlLikePattern)
1616
throw new NotSupportedException(
1717
"The NHibernate.Linq.SqlMethods.Like(string, string) method can only be used in Linq2NHibernate expressions.");
1818
}
19+
20+
/// <summary>
21+
/// Use this method in a Linq2NHibernate expression to generate
22+
/// an SQL LIKE expression with an escpare character defined. (If you want to avoid depending on the NHibernate.Linq namespace,
23+
/// you can define your own replica of this method. Any 3-argument method named Like in a class named SqlMethods
24+
/// will be translated.) This method can only be used in Linq2NHibernate expression, and will throw
25+
/// if called directly.
26+
/// </summary>
27+
public static bool Like(this string matchExpression, string sqlLikePattern, char escapeCharacter)
28+
{
29+
throw new NotSupportedException(
30+
"The NHibernate.Linq.SqlMethods.Like(string, string) method can only be used in Linq2NHibernate expressions.");
31+
}
1932
}
2033
}

0 commit comments

Comments
 (0)