Skip to content

IRGen: Fix preservation of error result in async dispatch thunks #35666

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions lib/IRGen/GenThunk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,10 @@ void IRGenThunk::prepareArguments() {
}

for (unsigned i = 0, e = asyncLayout->getArgumentCount(); i < e; ++i) {
Address addr = asyncLayout->getArgumentLayout(i).project(
IGF, context, llvm::None);
params.add(IGF.Builder.CreateLoad(addr));
auto layout = asyncLayout->getArgumentLayout(i);
Address addr = layout.project(IGF, context, llvm::None);
auto &ti = cast<LoadableTypeInfo>(layout.getType());
ti.loadAsTake(IGF, addr, params);
}

if (asyncLayout->hasBindings()) {
Expand Down Expand Up @@ -329,8 +330,20 @@ void IRGenThunk::emit() {
emission->emitToExplosion(result, /*isOutlined=*/false);
}

llvm::Value *errorValue = nullptr;

if (isAsync && origTy->hasErrorResult()) {
SILType errorType = conv.getSILErrorType(expansionContext);
Address calleeErrorSlot = emission->getCalleeErrorSlot(errorType);
errorValue = IGF.Builder.CreateLoad(calleeErrorSlot);
}

emission->end();

if (isAsync && errorValue) {
IGF.Builder.CreateStore(errorValue, IGF.getCallerErrorResultSlot());
}

if (isAsync) {
emitAsyncReturn(IGF, *asyncLayout, origTy);
IGF.emitCoroutineOrAsyncExit();
Expand Down
9 changes: 9 additions & 0 deletions test/Concurrency/Runtime/Inputs/resilient_class.swift
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
public enum MyError : Error {
case bad
}

open class BaseClass<T> {
let value: T

Expand All @@ -9,4 +13,9 @@ open class BaseClass<T> {
open func wait() async -> T {
return value
}
open func wait(orThrow: Bool) async throws {
if orThrow {
throw MyError.bad
}
}
}
1 change: 1 addition & 0 deletions test/Concurrency/Runtime/Inputs/resilient_protocol.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ public protocol Awaitable {
associatedtype Result
func waitForNothing() async
func wait() async -> Result
func wait(orThrow: Bool) async throws
}
15 changes: 13 additions & 2 deletions test/Concurrency/Runtime/class_resilience.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,22 @@ class MyDerived : BaseClass<Int> {
override func wait() async -> Int {
return await super.wait() * 2
}

override func wait(orThrow: Bool) async throws {
return try await super.wait(orThrow: orThrow)
}
}

func virtualWaitForNothing<T>(_ c: BaseClass<T>) async {
await c.waitForNothing()
}

func virtualWait<T>(_ t: BaseClass<T>) async -> T {
return await t.wait()
func virtualWait<T>(_ c: BaseClass<T>) async -> T {
return await c.wait()
}

func virtualWait<T>(orThrow: Bool, _ c: BaseClass<T>) async throws {
return try await c.wait(orThrow: orThrow)
}

var AsyncVTableMethodSuite = TestSuite("ResilientClass")
Expand All @@ -43,6 +51,9 @@ AsyncVTableMethodSuite.test("AsyncVTableMethod") {
await virtualWaitForNothing(x)

expectEqual(642, await virtualWait(x))

expectNil(try? await virtualWait(orThrow: true, x))
try! await virtualWait(orThrow: false, x)
}
}

Expand Down
17 changes: 17 additions & 0 deletions test/Concurrency/Runtime/protocol_resilience.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,22 @@
import StdlibUnittest
import resilient_protocol

enum MyError : Error {
case bad
}

struct IntAwaitable : Awaitable {
func waitForNothing() async {}

func wait() async -> Int {
return 123
}

func wait(orThrow: Bool) async throws {
if (orThrow) {
throw MyError.bad
}
}
}

func genericWaitForNothing<T : Awaitable>(_ t: T) async {
Expand All @@ -32,6 +42,10 @@ func genericWait<T : Awaitable>(_ t: T) async -> T.Result {
return await t.wait()
}

func genericWait<T : Awaitable>(orThrow: Bool, _ t: T) async throws {
return try await t.wait(orThrow: orThrow)
}

var AsyncProtocolRequirementSuite = TestSuite("ResilientProtocol")

AsyncProtocolRequirementSuite.test("AsyncProtocolRequirement") {
Expand All @@ -41,6 +55,9 @@ AsyncProtocolRequirementSuite.test("AsyncProtocolRequirement") {
await genericWaitForNothing(x)

expectEqual(123, await genericWait(x))

expectNil(try? await genericWait(orThrow: true, x))
try! await genericWait(orThrow: false, x)
}
}

Expand Down