Skip to content

Commit f278278

Browse files
committed
Proper support for IN clause for composite properties in Criteria
1 parent bf2f7f2 commit f278278

File tree

5 files changed

+148
-52
lines changed

5 files changed

+148
-52
lines changed

src/NHibernate.Test/Async/CompositeId/ClassWithCompositeIdFixture.cs

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,6 @@ protected override string[] Mappings
3838
get { return new string[] {"CompositeId.ClassWithCompositeId.hbm.xml"}; }
3939
}
4040

41-
protected override bool AppliesTo(Dialect.Dialect dialect)
42-
{
43-
return !(dialect is Dialect.FirebirdDialect); // Firebird has no CommandTimeout, and locks up during the tear-down of this fixture
44-
}
45-
4641
protected override void OnSetUp()
4742
{
4843
id = new Id("stringKey", 3, firstDateTime);
@@ -52,9 +47,11 @@ protected override void OnSetUp()
5247
protected override void OnTearDown()
5348
{
5449
using (ISession s = Sfi.OpenSession())
50+
using (var t = s.BeginTransaction())
5551
{
5652
s.Delete("from ClassWithCompositeId");
5753
s.Flush();
54+
t.Commit();
5855
}
5956
}
6057

@@ -397,5 +394,26 @@ public async Task QueryOverOrderByAndWhereWithIdProjectionDoesntThrowAsync()
397394
Assert.That(results.Count, Is.EqualTo(1));
398395
}
399396
}
397+
398+
[Test]
399+
public async Task QueryOverInClauseAsync()
400+
{
401+
// insert the new objects
402+
using (ISession s = OpenSession())
403+
using (ITransaction t = s.BeginTransaction())
404+
{
405+
await (s.SaveAsync(new ClassWithCompositeId(id) {OneProperty = 5}));
406+
await (s.SaveAsync(new ClassWithCompositeId(secondId) {OneProperty = 10}));
407+
await (s.SaveAsync(new ClassWithCompositeId(new Id(id.KeyString, id.GetKeyShort(), secondId.KeyDateTime))));
408+
409+
await (t.CommitAsync());
410+
}
411+
412+
using (var s = OpenSession())
413+
{
414+
var results = await (s.QueryOver<ClassWithCompositeId>().WhereRestrictionOn(p => p.Id).IsIn(new[] {id, secondId}).ListAsync());
415+
Assert.That(results.Count, Is.EqualTo(2));
416+
}
417+
}
400418
}
401419
}

src/NHibernate.Test/CompositeId/ClassWithCompositeIdFixture.cs

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,6 @@ protected override string[] Mappings
2727
get { return new string[] {"CompositeId.ClassWithCompositeId.hbm.xml"}; }
2828
}
2929

30-
protected override bool AppliesTo(Dialect.Dialect dialect)
31-
{
32-
return !(dialect is Dialect.FirebirdDialect); // Firebird has no CommandTimeout, and locks up during the tear-down of this fixture
33-
}
34-
3530
protected override void OnSetUp()
3631
{
3732
id = new Id("stringKey", 3, firstDateTime);
@@ -41,9 +36,11 @@ protected override void OnSetUp()
4136
protected override void OnTearDown()
4237
{
4338
using (ISession s = Sfi.OpenSession())
39+
using (var t = s.BeginTransaction())
4440
{
4541
s.Delete("from ClassWithCompositeId");
4642
s.Flush();
43+
t.Commit();
4744
}
4845
}
4946

@@ -386,5 +383,26 @@ public void QueryOverOrderByAndWhereWithIdProjectionDoesntThrow()
386383
Assert.That(results.Count, Is.EqualTo(1));
387384
}
388385
}
386+
387+
[Test]
388+
public void QueryOverInClause()
389+
{
390+
// insert the new objects
391+
using (ISession s = OpenSession())
392+
using (ITransaction t = s.BeginTransaction())
393+
{
394+
s.Save(new ClassWithCompositeId(id) {OneProperty = 5});
395+
s.Save(new ClassWithCompositeId(secondId) {OneProperty = 10});
396+
s.Save(new ClassWithCompositeId(new Id(id.KeyString, id.GetKeyShort(), secondId.KeyDateTime)));
397+
398+
t.Commit();
399+
}
400+
401+
using (var s = OpenSession())
402+
{
403+
var results = s.QueryOver<ClassWithCompositeId>().WhereRestrictionOn(p => p.Id).IsIn(new[] {id, secondId}).List();
404+
Assert.That(results.Count, Is.EqualTo(2));
405+
}
406+
}
389407
}
390408
}

src/NHibernate/Criterion/InExpression.cs

Lines changed: 43 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
using System;
2-
using System.Collections;
32
using System.Collections.Generic;
43
using System.Linq;
54
using NHibernate.Engine;
@@ -13,9 +12,6 @@ namespace NHibernate.Criterion
1312
/// An <see cref="ICriterion"/> that constrains the property
1413
/// to a specified list of values.
1514
/// </summary>
16-
/// <remarks>
17-
/// InExpression - should only be used with a Single Value column - no multicolumn properties...
18-
/// </remarks>
1915
[Serializable]
2016
public class InExpression : AbstractCriterion
2117
{
@@ -62,41 +58,50 @@ public override SqlString ToSqlString(ICriteria criteria, ICriteriaQuery criteri
6258
return new SqlString("1=0");
6359
}
6460

65-
//TODO: add default capacity
66-
SqlStringBuilder result = new SqlStringBuilder();
67-
SqlString[] columnNames =
68-
CriterionUtil.GetColumnNames(_propertyName, _projection, criteriaQuery, criteria);
69-
70-
// Generate SqlString of the form:
71-
// columnName1 in (values) and columnName2 in (values) and ...
72-
Parameter[] parameters = GetParameterTypedValues(criteria, criteriaQuery).SelectMany(t => criteriaQuery.NewQueryParameter(t)).ToArray();
61+
SqlString[] columns = CriterionUtil.GetColumnNames(_propertyName, _projection, criteriaQuery, criteria);
7362

74-
for (int columnIndex = 0; columnIndex < columnNames.Length; columnIndex++)
63+
var list = new List<Parameter>(columns.Length * Values.Length);
64+
foreach (var typedValue in GetParameterTypedValues(criteria, criteriaQuery))
7565
{
76-
SqlString columnName = columnNames[columnIndex];
77-
78-
if (columnIndex > 0)
79-
{
80-
result.Add(" and ");
81-
}
66+
//Must be executed after CriterionUtil.GetColumnNames (as it might add _projection parameters to criteria)
67+
list.AddRange(criteriaQuery.NewQueryParameter(typedValue));
68+
}
8269

83-
result
84-
.Add(columnName)
85-
.Add(" in (");
70+
var bogusParam = Parameter.Placeholder;
8671

87-
for (int i = 0; i < _values.Length; i++)
88-
{
89-
if (i > 0)
90-
{
91-
result.Add(StringHelper.CommaSpace);
92-
}
93-
result.Add(parameters[i]);
94-
}
72+
var sqlString = GetSqlString(criteriaQuery, columns, bogusParam);
73+
sqlString.SubstituteBogusParameters(list, bogusParam);
74+
return sqlString;
75+
}
9576

96-
result.Add(")");
77+
private SqlString GetSqlString(ICriteriaQuery criteriaQuery, SqlString[] columns, Parameter bogusParam)
78+
{
79+
if (columns.Length <= 1 || criteriaQuery.Factory.Dialect.SupportsRowValueConstructorSyntaxInInList)
80+
{
81+
var parens = columns.Length > 1 ? new[] {new SqlString("("), new SqlString(")"),} : null;
82+
SqlString comaSeparator = new SqlString(", ");
83+
var singleValueParam = SqlStringHelper.Repeat(new SqlString(bogusParam), columns.Length, comaSeparator, parens);
84+
85+
var parameters = SqlStringHelper.Repeat(singleValueParam, Values.Length, comaSeparator, null);
86+
87+
//single column: col1 in (?, ?)
88+
//multi column: (col1, col2) in ((?, ?), (?, ?))
89+
return new SqlString(
90+
parens?[0] ?? SqlString.Empty,
91+
SqlStringHelper.Join(comaSeparator, columns),
92+
parens?[1] ?? SqlString.Empty,
93+
" in (",
94+
parameters,
95+
")");
9796
}
9897

99-
return result.ToSqlString();
98+
//((col1 = ? and col2 = ?) or (col1 = ? and col2 = ?))
99+
var cols = new SqlString(
100+
" ( ",
101+
SqlStringHelper.Join(new SqlString(" = ", bogusParam, " and "), columns),
102+
new SqlString("= ", bogusParam, " ) "));
103+
cols = SqlStringHelper.Repeat(cols, Values.Length, "or ", new[] {" ( ", " ) "});
104+
return cols;
100105
}
101106

102107
private void AssertPropertyIsNotCollection(ICriteriaQuery criteriaQuery, ICriteria criteria)
@@ -127,16 +132,13 @@ private List<TypedValue> GetParameterTypedValues(ICriteria criteria, ICriteriaQu
127132
List<TypedValue> list = new List<TypedValue>();
128133
IAbstractComponentType actype = (IAbstractComponentType) type;
129134
IType[] types = actype.Subtypes;
130-
131-
for (int i = 0; i < types.Length; i++)
135+
for (int vi = 0; vi < _values.Length; vi++)
136+
for (int ti = 0; ti < types.Length; ti++)
132137
{
133-
for (int j = 0; j < _values.Length; j++)
134-
{
135-
object subval = _values[j] == null
136-
? null
137-
: actype.GetPropertyValues(_values[j])[i];
138-
list.Add(new TypedValue(types[i], subval, false));
139-
}
138+
object subval = _values[vi] == null
139+
? null
140+
: actype.GetPropertyValues(_values[vi])[ti];
141+
list.Add(new TypedValue(types[ti], subval, false));
140142
}
141143

142144
return list;

src/NHibernate/SqlCommand/SqlString.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1024,6 +1024,19 @@ public SqlString GetSubselectString()
10241024
return new SubselectClauseExtractor(this).GetSqlString();
10251025
}
10261026

1027+
internal void SubstituteBogusParameters(IReadOnlyList<Parameter> actualParams, Parameter bogusParam)
1028+
{
1029+
int index = 0;
1030+
var keys = _parameters.Keys;
1031+
// ReSharper disable once ForCanBeConvertedToForeach
1032+
for (var i = 0; i < keys.Count; i++)
1033+
{
1034+
var key = keys[i];
1035+
if (ReferenceEquals(_parameters[key], bogusParam))
1036+
_parameters[key] = actualParams[index++];
1037+
}
1038+
}
1039+
10271040
[Serializable]
10281041
private struct Part : IEquatable<Part>
10291042
{

src/NHibernate/SqlCommand/SqlStringHelper.cs

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ public static bool IsNotEmpty(SqlString str)
5656
return !IsEmpty(str);
5757
}
5858

59-
6059
public static bool IsEmpty(SqlString str)
6160
{
6261
return str == null || str.Count == 0;
@@ -90,5 +89,51 @@ internal static SqlString ParametersList(List<Parameter> parameters)
9089

9190
return builder.ToSqlString();
9291
}
92+
93+
internal static SqlString Repeat(SqlString placeholder, int count, string separator, string[] wrapResult)
94+
{
95+
return Repeat(
96+
placeholder,
97+
count,
98+
new SqlString(separator),
99+
wrapResult == null
100+
? null
101+
: new[]
102+
{
103+
new SqlString(wrapResult[0]),
104+
new SqlString(wrapResult[1]),
105+
});
106+
}
107+
108+
internal static SqlString Repeat(SqlString placeholder, int count, SqlString separator, SqlString[] wrapResult)
109+
{
110+
if (wrapResult == null)
111+
{
112+
if (count == 0)
113+
return SqlString.Empty;
114+
if (count == 1)
115+
return placeholder;
116+
}
117+
118+
var builder = new SqlStringBuilder(count * 2 + 1);
119+
if (wrapResult != null)
120+
{
121+
builder.Add(wrapResult[0]);
122+
}
123+
124+
if (count > 0)
125+
builder.Add(placeholder);
126+
127+
for (int i = 1; i < count; i++)
128+
{
129+
builder.Add(separator).Add(placeholder);
130+
}
131+
132+
if (wrapResult != null)
133+
{
134+
builder.Add(wrapResult[1]);
135+
}
136+
return builder.ToSqlString();
137+
}
93138
}
94139
}

0 commit comments

Comments
 (0)