Skip to content

Commit 667e9fd

Browse files
maca88hazzik
authored andcommitted
addded async support for linq with generated tests
1 parent 925d5fd commit 667e9fd

File tree

199 files changed

+12439
-421
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

199 files changed

+12439
-421
lines changed

src/NHibernate.AsyncGenerator/AsyncConfiguration.cs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,5 @@ public class AsyncConfiguration
1111
public AsyncLockConfiguration Lock { get; set; }
1212

1313
public AsyncCustomTaskTypeConfiguration CustomTaskType { get; set; } = new AsyncCustomTaskTypeConfiguration();
14-
15-
public string AttributeName { get; set; }
1614
}
1715
}

src/NHibernate.AsyncGenerator/DocumentTransformer.cs

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ namespace NHibernate.AsyncGenerator
1616
{
1717
public class DocumentTransformer
1818
{
19-
private List<ITransformerPlugin> _plugins = new List<ITransformerPlugin>();
19+
private readonly List<ITransformerPlugin> _plugins = new List<ITransformerPlugin>();
2020

2121
private class TransformedNode
2222
{
@@ -167,7 +167,6 @@ private MethodDeclarationSyntax TransformMethod(MethodInfo methodInfo)
167167
if (!methodInfo.HasBody)
168168
{
169169
return methodInfo.Node
170-
.WithoutAttribute("Async")
171170
.ReturnAsTask(methodInfo.Symbol, taskConflict)
172171
.WithIdentifier(Identifier(methodInfo.Node.Identifier.Value + "Async"))
173172
.RemoveLeadingRegions();
@@ -327,13 +326,26 @@ private MethodDeclarationSyntax TransformMethod(MethodInfo methodInfo)
327326
methodInfo.Node.TypeParameterList.Parameters.Select(o => IdentifierName(o.Identifier.ValueText))
328327
)))
329328
: (SimpleNameSyntax)IdentifierName(methodInfo.Node.Identifier.ValueText);
329+
MemberAccessExpressionSyntax accessExpression = null;
330+
if (methodInfo.Symbol.MethodKind == MethodKind.ExplicitInterfaceImplementation)
331+
{
332+
// Explicit implementations needs an explicit cast (ie. ((Type)this).SyncMethod() )
333+
accessExpression = MemberAccessExpression(
334+
SyntaxKind.SimpleMemberAccessExpression,
335+
ParenthesizedExpression(
336+
CastExpression(
337+
IdentifierName(methodInfo.Symbol.ExplicitInterfaceImplementations.Single().ContainingType.Name),
338+
ThisExpression())),
339+
name);
340+
}
341+
330342

331-
var invocation = InvocationExpression(name)
343+
var invocation = InvocationExpression(accessExpression ?? (ExpressionSyntax)name)
332344
.WithArgumentList(
333345
ArgumentList(
334346
SeparatedList(
335347
methodInfo.Node.ParameterList.Parameters
336-
.Select(o => Argument(IdentifierName(o.Identifier.ValueText)))
348+
.Select(o => Argument(IdentifierName(o.Identifier.Text)))
337349
)));
338350
if (methodInfo.Symbol.ReturnsVoid)
339351
{
@@ -364,7 +376,6 @@ private MethodDeclarationSyntax TransformMethod(MethodInfo methodInfo)
364376
.WithoutAttribute("MethodImpl");
365377
}
366378
methodNode = methodNode
367-
.WithoutAttribute("Async")
368379
.ReturnAsTask(methodInfo.Symbol, taskConflict)
369380
.WithIdentifier(Identifier(methodNode.Identifier.Value + "Async"));
370381

@@ -416,7 +427,7 @@ private TransformTypeResult TransformType(TypeInfo rootTypeInfo, bool onlyMissin
416427
var refSpanLength = reference.Location.SourceSpan.Length;
417428
if (refSpanStart < 0)
418429
{
419-
// cref
430+
// TODO: cref
420431
//var startSpan = reference.Location.SourceSpan.Start - rootTypeInfo.Node.GetLeadingTrivia().Span.Start;
421432
//var crefNode = leadingTrivia.First(o => o.SpanStart == startSpan && o.Span.Length == refSpanLength);
422433
continue;
@@ -512,7 +523,7 @@ private TransformTypeResult TransformType(TypeInfo rootTypeInfo, bool onlyMissin
512523
}
513524
// add missing members
514525
var typeNode = rootTypeNode.GetAnnotatedNodes(metadata.NodeAnnotation).OfType<TypeDeclarationSyntax>().First();
515-
526+
// TODO: we should not include all members, we need to skip the methods that are not required
516527
rootTypeNode = rootTypeNode.ReplaceNode(typeNode, typeNode.WithMembers(
517528
typeNode.Members.Select(o => o.RemoveLeadingRegions()).ToSyntaxList()
518529
.AddRange(
@@ -616,6 +627,7 @@ public DocumentTransformationResult Transform()
616627
var taskConflict = false;
617628
var asyncLockUsed = false;
618629
var rewrittenNodes = new List<TransformedNode>();
630+
var projectConfig = DocumentInfo.ProjectInfo.Configuration;
619631

620632
// TODO: handle global namespace
621633
foreach (var namespaceInfo in DocumentInfo.Values.OrderBy(o => o.Node.SpanStart))
@@ -695,6 +707,16 @@ public DocumentTransformationResult Transform()
695707
rootNode = rootNode.AddUsings(UsingDirective(NameSyntax(lockNamespace)));
696708
}
697709

710+
var usingList = projectConfig.GetAdditionalUsings?.Invoke(DocumentInfo.RootNode);
711+
if (usingList != null)
712+
{
713+
foreach (var us in usingList.Where(o => DocumentInfo.RootNode.Usings.All(u => u.Name.ToString() != o)))
714+
{
715+
rootNode = rootNode.AddUsings(UsingDirective(NameSyntax(us)));
716+
}
717+
}
718+
719+
698720
if (!taskConflict && rootNode.Usings.All(o => o.Name.ToString() != "System.Threading.Tasks"))
699721
{
700722
rootNode = rootNode.AddUsings(UsingDirective(NameSyntax("System.Threading.Tasks")));

src/NHibernate.AsyncGenerator/MethodInfo.cs

Lines changed: 54 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ public MethodInfo(TypeInfo typeInfo, IMethodSymbol symbol, MethodDeclarationSynt
3434

3535
public bool Missing { get; set; }
3636

37+
public bool Required { get; set; }
38+
39+
public MethodAsyncConversion? Conversion { get; set; }
40+
3741
public HashSet<MethodInfo> InvokedBy { get; } = new HashSet<MethodInfo>();
3842

3943
/// <summary>
@@ -103,7 +107,7 @@ public void CalculateIgnore(int deep = 0, HashSet<MethodInfo> processedMethodInf
103107
return;
104108
}
105109

106-
if (Missing || ImplementsAbstractMethod)
110+
if (Missing || ImplementsAbstractMethod || Required)
107111
{
108112
foreach (var refResult in ReferenceResults)
109113
{
@@ -148,20 +152,30 @@ public void CalculateIgnore(int deep = 0, HashSet<MethodInfo> processedMethodInf
148152
return;
149153
}
150154

151-
var dependencies = ReferenceResults
155+
var relRefResults = ReferenceResults
152156
.Union(GetAllRelatedMethods().ToList().SelectMany(o => o.ReferenceResults))
153157
.ToList();
154158

155-
foreach (var refResult in dependencies)
159+
foreach (var refResult in relRefResults)
160+
{
161+
refResult.CalculateIgnore(deep, processedMethodInfos);
162+
}
163+
Ignore = relRefResults.All(o => o.Ignore);
164+
165+
if (Ignore)
166+
{
167+
relRefResults = relRefResults.Except(ReferenceResults).ToList();
168+
}
169+
170+
foreach (var refResult in relRefResults)
156171
{
157172
// if the reference cannot be async we should preserve the original method. TODO: what if the method has dependecies that are async
158173
if (!refResult.CanBeAsync && refResult.MethodInfo != null)
159174
{
160175
refResult.MethodInfo.CanBeAsnyc = false;
161176
}
162-
refResult.CalculateIgnore(deep, processedMethodInfos);
163177
}
164-
Ignore = dependencies.All(o => o.Ignore);
178+
165179
_ignoreCalculating = false;
166180
_ignoreCalculated = true;
167181
}
@@ -272,12 +286,13 @@ public IEnumerable<MethodInfo> GetAllDependencies()
272286

273287
private void AnalyzeInvocationExpression(SyntaxNode node, SimpleNameSyntax nameNode, MethodReferenceResult result)
274288
{
275-
var selectClause = node.Ancestors()
276-
.FirstOrDefault(o => o.IsKind(SyntaxKind.SelectClause));
277-
if (selectClause != null) // await is not supported in select clause
289+
var queryExpression = node.Ancestors()
290+
.OfType<QueryExpressionSyntax>()
291+
.FirstOrDefault();
292+
if (queryExpression != null) // await is not supported in linq query
278293
{
279294
result.CanBeAsync = false;
280-
Logger.Warn($"Cannot await async method in a select clause:\r\n{selectClause}\r\n");
295+
Logger.Warn($"Cannot await async method in a query expression:\r\n{queryExpression}\r\n");
281296
return;
282297
}
283298
var docInfo = TypeInfo.NamespaceInfo.DocumentInfo;
@@ -311,13 +326,31 @@ private void AnalyzeInvocationExpression(SyntaxNode node, SimpleNameSyntax nameN
311326
return;
312327
}
313328

329+
330+
// Custom code TODO: move
331+
if (nameNode.Identifier.ToString() == "ToList")
332+
{
333+
var beforeToListExpression = ((MemberAccessExpressionSyntax)((InvocationExpressionSyntax)node).Expression).Expression;
334+
var operation = docInfo.SemanticModel.GetOperation(beforeToListExpression);
335+
if (operation == null)
336+
{
337+
result.CanBeAsync = false;
338+
Logger.Warn($"Cannot find operation for previous node of ToList:\r\n{beforeToListExpression}\r\n");
339+
}
340+
else if(operation.Type.Name != "IQueryable")
341+
{
342+
result.CanBeAsync = false;
343+
Logger.Warn($"Operation for previous node of ToList is not IQueryable:\r\n{operation.Type.Name}\r\n");
344+
}
345+
}
346+
// End custom code
347+
314348
var anonFunctionNode = node.Ancestors()
315349
.OfType<AnonymousFunctionExpressionSyntax>()
316350
.FirstOrDefault();
317351
if (anonFunctionNode?.AsyncKeyword.IsMissing == false)
318352
{
319-
// Custom code
320-
353+
// Custom code TODO: move
321354
var methodArgTypeInfo = ModelExtensions.GetTypeInfo(docInfo.SemanticModel, anonFunctionNode);
322355
var convertedType = methodArgTypeInfo.ConvertedType;
323356
if (convertedType != null && convertedType.ContainingAssembly.Name == "nunit.framework" &&
@@ -329,7 +362,7 @@ private void AnalyzeInvocationExpression(SyntaxNode node, SimpleNameSyntax nameN
329362
result.MakeAnonymousFunctionAsync = true;
330363
return;
331364
}
332-
//end
365+
// End custom code
333366

334367
result.CanBeAsync = false;
335368
Logger.Warn($"Cannot await async method in an non async anonymous function:\r\n{anonFunctionNode}\r\n");
@@ -346,14 +379,14 @@ private void AnalyzeArgumentExpression(SyntaxNode node, SimpleNameSyntax nameNod
346379
return;
347380
}
348381

349-
// Custom code
382+
// Custom code TODO: move
350383
var convertedType = methodArgTypeInfo.ConvertedType;
351384
if (convertedType.ContainingAssembly.Name == "nunit.framework" && convertedType.Name == "TestDelegate")
352385
{
353386
result.WrapInsideAsyncFunction = true;
354387
return;
355388
}
356-
//end
389+
// End custom code
357390

358391
var delegateMethod = (IMethodSymbol)methodArgTypeInfo.ConvertedType.GetMembers("Invoke").First();
359392

@@ -479,7 +512,7 @@ public MethodReferenceResult AnalyzeReference(ReferenceLocation reference)
479512
return result;
480513
}
481514

482-
public List<AsyncCounterpartMethod> FindAsyncCounterpartMethodsWhitinBody(Dictionary<IMethodSymbol, IMethodSymbol> methodAsyncConterparts = null)
515+
public async Task<List<AsyncCounterpartMethod>> FindAsyncCounterpartMethodsWhitinBody(Dictionary<IMethodSymbol, IMethodSymbol> methodAsyncConterparts = null)
483516
{
484517
var result = new List<AsyncCounterpartMethod>();
485518
if (Node.Body == null)
@@ -503,10 +536,11 @@ public List<AsyncCounterpartMethod> FindAsyncCounterpartMethodsWhitinBody(Dictio
503536
}
504537
else
505538
{
506-
var config = TypeInfo.NamespaceInfo.DocumentInfo.ProjectInfo.Configuration;
539+
var projectInfo = TypeInfo.NamespaceInfo.DocumentInfo.ProjectInfo;
540+
var config = projectInfo.Configuration;
507541
if (config.FindAsyncCounterpart != null)
508542
{
509-
asyncMethodSymbol = config.FindAsyncCounterpart.Invoke(methodSymbol.OriginalDefinition);
543+
asyncMethodSymbol = await config.FindAsyncCounterpart.Invoke(projectInfo.Project, methodSymbol.OriginalDefinition).ConfigureAwait(false);
510544
}
511545
else
512546
{
@@ -613,7 +647,9 @@ public void Analyze()
613647
{
614648
ReferenceResults.Add(AnalyzeReference(reference));
615649
}
616-
CanBeAsnyc = ReferenceResults.Any(o => o.CanBeAsync);
650+
// TODO: TypeTransformation should be removed
651+
CanBeAsnyc = ReferenceResults.Any(o => o.CanBeAsync) ||
652+
(Conversion == MethodAsyncConversion.ToAsync && TypeInfo.TypeTransformation == TypeTransformation.Partial);
617653

618654
// TODO: check if this is correct
619655
if (TypeInfo.TypeTransformation == TypeTransformation.Partial && GetAllRelatedMethods().ToList().Any(o => o.InvokedBy.Any()))

0 commit comments

Comments
 (0)