Skip to content

Support late bound results #34300

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 76 additions & 5 deletions src/Http/Http.Extensions/src/RequestDelegateFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ public static partial class RequestDelegateFactory
private static readonly MethodInfo ExecuteValueTaskOfStringMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteValueTaskOfString), BindingFlags.NonPublic | BindingFlags.Static)!;
private static readonly MethodInfo ExecuteTaskResultOfTMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteTaskResult), BindingFlags.NonPublic | BindingFlags.Static)!;
private static readonly MethodInfo ExecuteValueResultTaskOfTMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteValueTaskResult), BindingFlags.NonPublic | BindingFlags.Static)!;
private static readonly MethodInfo ExecuteObjectReturnMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteObjectReturn), BindingFlags.NonPublic | BindingFlags.Static)!;
private static readonly MethodInfo GetRequiredServiceMethod = typeof(ServiceProviderServiceExtensions).GetMethod(nameof(ServiceProviderServiceExtensions.GetRequiredService), BindingFlags.Public | BindingFlags.Static, new Type[] { typeof(IServiceProvider) })!;
private static readonly MethodInfo ResultWriteResponseAsyncMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteResultWriteResponse), BindingFlags.NonPublic | BindingFlags.Static)!;
private static readonly MethodInfo StringResultWriteResponseAsyncMethod = GetMethodInfo<Func<HttpResponse, string, Task>>((response, text) => HttpResponseWritingExtensions.WriteAsync(response, text, default));
Expand Down Expand Up @@ -338,6 +339,21 @@ private static Expression AddResponseWritingToMethodCall(Expression methodCall,
{
return Expression.Block(methodCall, CompletedTaskExpr);
}
else if (returnType == typeof(object))
{
return Expression.Call(ExecuteObjectReturnMethod, methodCall, HttpContextExpr);
}
else if (returnType == typeof(ValueTask<object>))
{
// REVIEW: We can avoid this box if it becomes a performance issue
var box = Expression.TypeAs(methodCall, typeof(object));
return Expression.Call(ExecuteObjectReturnMethod, box, HttpContextExpr);
}
else if (returnType == typeof(Task<object>))
{
var convert = Expression.Convert(methodCall, typeof(object));
return Expression.Call(ExecuteObjectReturnMethod, convert, HttpContextExpr);
}
else if (AwaitableInfo.IsTypeAwaitable(returnType, out _))
{
if (returnType == typeof(Task))
Expand Down Expand Up @@ -632,6 +648,61 @@ private static MemberInfo GetMemberInfo<T>(Expression<T> expr)
return mc.Member;
}

// The result of the method is null so we fallback to some runtime logic.
// First we check if the result is IResult, Task<IResult> or ValueTask<IResult>. If
// it is, we await if necessary then execute the result.
// Then we check to see if it's Task<object> or ValueTask<object>. If it is, we await
// if necessary and restart the cycle until we've reached a terminal state (unknown type).
// We currently don't handle Task<unknown> or ValueTask<unknown>. We can support this later if this
// ends up being a common scenario.
private static async Task ExecuteObjectReturn(object? obj, HttpContext httpContext)
{
// See if we need to unwrap Task<object> or ValueTask<object>
if (obj is Task<object> taskObj)
{
obj = await taskObj;
}
else if (obj is ValueTask<object> valueTaskObj)
{
obj = await valueTaskObj;
}
else if (obj is Task<IResult?> task)
{
await ExecuteTaskResult(task, httpContext);
return;
}
else if (obj is ValueTask<IResult?> valueTask)
{
await ExecuteValueTaskResult(valueTask, httpContext);
return;
}
else if (obj is Task<string?> taskString)
{
await ExecuteTaskOfString(taskString, httpContext);
return;
}
else if (obj is ValueTask<string?> valueTaskString)
{
await ExecuteValueTaskOfString(valueTaskString, httpContext);
return;
}

// Terminal built ins
if (obj is IResult result)
{
await ExecuteResultWriteResponse(result, httpContext);
}
else if (obj is string stringValue)
{
await httpContext.Response.WriteAsync(stringValue);
}
else
{
// Otherwise, we JSON serialize when we reach the terminal state
await httpContext.Response.WriteAsJsonAsync(obj);
}
}

private static Task ExecuteTask<T>(Task<T> task, HttpContext httpContext)
{
EnsureRequestTaskNotNull(task);
Expand Down Expand Up @@ -715,12 +786,12 @@ private static Task ExecuteValueTaskResult<T>(ValueTask<T?> task, HttpContext ht
{
static async Task ExecuteAwaited(ValueTask<T> task, HttpContext httpContext)
{
await EnsureRequestResultNotNull(await task)!.ExecuteAsync(httpContext);
await EnsureRequestResultNotNull(await task).ExecuteAsync(httpContext);
}

if (task.IsCompletedSuccessfully)
{
return EnsureRequestResultNotNull(task.GetAwaiter().GetResult())!.ExecuteAsync(httpContext);
return EnsureRequestResultNotNull(task.GetAwaiter().GetResult()).ExecuteAsync(httpContext);
}

return ExecuteAwaited(task!, httpContext);
Expand All @@ -730,12 +801,12 @@ private static async Task ExecuteTaskResult<T>(Task<T?> task, HttpContext httpCo
{
EnsureRequestTaskOfNotNull(task);

await EnsureRequestResultNotNull(await task)!.ExecuteAsync(httpContext);
await EnsureRequestResultNotNull(await task).ExecuteAsync(httpContext);
}

private static async Task ExecuteResultWriteResponse(IResult result, HttpContext httpContext)
private static async Task ExecuteResultWriteResponse(IResult? result, HttpContext httpContext)
{
await EnsureRequestResultNotNull(result)!.ExecuteAsync(httpContext);
await EnsureRequestResultNotNull(result).ExecuteAsync(httpContext);
}

private class FactoryContext
Expand Down
46 changes: 45 additions & 1 deletion src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ public void NonStaticTestAction(HttpContext httpContext)

[Fact]
public async Task NonStaticMethodInfoOverloadWorksWithBasicReflection()

{
var methodInfo = typeof(TestNonStaticActionClass).GetMethod(
nameof(TestNonStaticActionClass.NonStaticTestAction),
Expand Down Expand Up @@ -1026,6 +1025,21 @@ public static IEnumerable<object[]> CustomResults
static Task<CustomResult> StaticTaskTestAction() => Task.FromResult(new CustomResult("Still not enough tests!"));
static ValueTask<CustomResult> StaticValueTaskTestAction() => ValueTask.FromResult(new CustomResult("Still not enough tests!"));

// Object return type where the object is IResult
static object StaticResultAsObject() => new CustomResult("Still not enough tests!");
static object StaticResultAsTaskObject() => Task.FromResult<object>(new CustomResult("Still not enough tests!"));
static object StaticResultAsValueTaskObject() => ValueTask.FromResult<object>(new CustomResult("Still not enough tests!"));

// Object return type where the object is Task<IResult>
static object StaticResultAsTaskIResult() => Task.FromResult<IResult>(new CustomResult("Still not enough tests!"));

// Object return type where the object is ValueTask<IResult>
static object StaticResultAsValueTaskIResult() => ValueTask.FromResult<IResult>(new CustomResult("Still not enough tests!"));

// Task<object> return type
static Task<object> StaticTaskOfIResultAsObject() => Task.FromResult<object>(new CustomResult("Still not enough tests!"));
static ValueTask<object> StaticValueTaskOfIResultAsObject() => ValueTask.FromResult<object>(new CustomResult("Still not enough tests!"));

return new List<object[]>
{
new object[] { (Func<CustomResult>)TestAction },
Expand All @@ -1034,6 +1048,16 @@ public static IEnumerable<object[]> CustomResults
new object[] { (Func<CustomResult>)StaticTestAction},
new object[] { (Func<Task<CustomResult>>)StaticTaskTestAction},
new object[] { (Func<ValueTask<CustomResult>>)StaticValueTaskTestAction},

new object[] { (Func<object>)StaticResultAsObject},
new object[] { (Func<object>)StaticResultAsTaskObject},
new object[] { (Func<object>)StaticResultAsValueTaskObject},

new object[] { (Func<object>)StaticResultAsTaskIResult},
new object[] { (Func<object>)StaticResultAsValueTaskIResult},

new object[] { (Func<Task<object>>)StaticTaskOfIResultAsObject},
new object[] { (Func<ValueTask<object>>)StaticValueTaskOfIResultAsObject},
};
}
}
Expand Down Expand Up @@ -1069,6 +1093,17 @@ public static IEnumerable<object[]> StringResult
static Task<string> StaticTaskTestAction() => Task.FromResult("String Test");
static ValueTask<string> StaticValueTaskTestAction() => ValueTask.FromResult("String Test");

// Dynamic via object
static object StaticStringAsObjectTestAction() => "String Test";
static object StaticTaskStringAsObjectTestAction() => Task.FromResult("String Test");
static object StaticValueTaskStringAsObjectTestAction() => ValueTask.FromResult("String Test");

// Dynamic via Task<object>
static Task<object> StaticStringAsTaskObjectTestAction() => Task.FromResult<object>("String Test");

// Dynamic via ValueTask<object>
static ValueTask<object> StaticStringAsValueTaskObjectTestAction() => ValueTask.FromResult<object>("String Test");

return new List<object[]>
{
new object[] { (Func<string>)TestAction },
Expand All @@ -1077,6 +1112,15 @@ public static IEnumerable<object[]> StringResult
new object[] { (Func<string>)StaticTestAction },
new object[] { (Func<Task<string>>)StaticTaskTestAction },
new object[] { (Func<ValueTask<string>>)StaticValueTaskTestAction },

new object[] { (Func<object>)StaticStringAsObjectTestAction },
new object[] { (Func<object>)StaticTaskStringAsObjectTestAction },
new object[] { (Func<object>)StaticValueTaskStringAsObjectTestAction },

new object[] { (Func<Task<object>>)StaticStringAsTaskObjectTestAction },
new object[] { (Func<ValueTask<object>>)StaticStringAsValueTaskObjectTestAction },


};
}
}
Expand Down