Skip to content

[IRGen] Apply correct type to conversion for direct values and erorrs… #79369

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions lib/IRGen/GenCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4534,8 +4534,8 @@ void CallEmission::emitToUnmappedExplosionWithDirectTypedError(

Explosion errorExplosion;
if (!errorSchema.empty()) {
if (auto *structTy =
dyn_cast<llvm::StructType>(errorSchema.getExpandedType(IGF.IGM))) {
auto *expandedType = errorSchema.getExpandedType(IGF.IGM);
if (auto *structTy = dyn_cast<llvm::StructType>(expandedType)) {
for (unsigned i = 0, e = structTy->getNumElements(); i < e; ++i) {
llvm::Value *elt = values[combined.errorValueMapping[i]];
auto *nativeTy = structTy->getElementType(i);
Expand All @@ -4545,7 +4545,7 @@ void CallEmission::emitToUnmappedExplosionWithDirectTypedError(
} else {
auto *converted =
convertForDirectError(IGF, values[combined.errorValueMapping[0]],
combined.combinedTy, /*forExtraction*/ true);
expandedType, /*forExtraction*/ true);
errorExplosion.add(converted);
}

Expand All @@ -4558,17 +4558,17 @@ void CallEmission::emitToUnmappedExplosionWithDirectTypedError(
// If the regular result type is void, there is nothing to explode
if (!nativeSchema.empty()) {
Explosion resultExplosion;
if (auto *structTy =
dyn_cast<llvm::StructType>(nativeSchema.getExpandedType(IGF.IGM))) {
auto *expandedType = nativeSchema.getExpandedType(IGF.IGM);
if (auto *structTy = dyn_cast<llvm::StructType>(expandedType)) {
for (unsigned i = 0, e = structTy->getNumElements(); i < e; ++i) {
auto *nativeTy = structTy->getElementType(i);
auto *converted = convertForDirectError(IGF, values[i], nativeTy,
/*forExtraction*/ true);
resultExplosion.add(converted);
}
} else {
auto *converted = convertForDirectError(
IGF, values[0], combined.combinedTy, /*forExtraction*/ true);
auto *converted = convertForDirectError(IGF, values[0], expandedType,
/*forExtraction*/ true);
resultExplosion.add(converted);
}
out = nativeSchema.mapFromNative(IGF.IGM, IGF, resultExplosion, resultType);
Expand Down
39 changes: 39 additions & 0 deletions test/IRGen/typed_throws.swift
Original file line number Diff line number Diff line change
Expand Up @@ -310,3 +310,42 @@ func callAsyncIndirectResult<A>(p: any AsyncGenProto<A>, x: Int) async throws(Sm
return try await p.fn(arg: x)
}

@inline(never)
func smallResultLargerError() throws(SmallError) -> Int8? {
return 10
}

// CHECK: [[COERCED:%.*]] = alloca { i16 }, align 2
// CHECK: [[RES:%.*]] = call swiftcc i64 @"$s12typed_throws22smallResultLargerErrors4Int8VSgyAA05SmallF0VYKF"(ptr swiftself undef, ptr noalias nocapture swifterror dereferenceable(8) %swifterror)
// CHECK: [[TRUNC:%.*]] = trunc i64 [[RES]] to i16
// CHECK: [[COERCED_PTR:%.*]] = getelementptr inbounds { i16 }, ptr [[COERCED]], i32 0, i32 0
// CHECK: store i16 [[TRUNC]], ptr [[COERCED_PTR]], align 2
func callSmallResultLargerError() {
let res = try! smallResultLargerError()
precondition(res! == 10)
}

enum UInt8OptSingletonError: Error {
case a(Int8?)
}

@inline(never)
func smallErrorLargerResult() throws(UInt8OptSingletonError) -> Int {
throw .a(10)
}

// CHECK: [[COERCED:%.*]] = alloca { i16 }, align 2
// CHECK: [[RES:%.*]] = call swiftcc i64 @"$s12typed_throws22smallErrorLargerResultSiyAA017UInt8OptSingletonD0OYKF"(ptr swiftself undef, ptr noalias nocapture swifterror dereferenceable(8) %swifterror)
// CHECK: [[TRUNC:%.*]] = trunc i64 [[RES]] to i16
// CHECK: [[COERCED_PTR:%.*]] = getelementptr inbounds { i16 }, ptr [[COERCED]], i32 0, i32 0
// CHECK: store i16 [[TRUNC]], ptr [[COERCED_PTR]], align 2
func callSmallErrorLargerResult() {
do {
_ = try smallErrorLargerResult()
} catch {
switch error {
case .a(let x):
precondition(x! == 10)
}
}
}
39 changes: 39 additions & 0 deletions test/Interpreter/typed_throws_abi.swift
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,40 @@ func checkAsync() async {
await invoke { try await impl.nonMatching_f1(false) }
}

enum SmallError: Error {
case a(Int)
}

@inline(never)
func smallResultLargerError() throws(SmallError) -> Int8? {
return 10
}

func callSmallResultLargerError() {
let res = try! smallResultLargerError()
print("Result is: \(String(describing: res))")
}

enum UInt8OptSingletonError: Error {
case a(Int8?)
}

@inline(never)
func smallErrorLargerResult() throws(UInt8OptSingletonError) -> Int {
throw .a(10)
}

func callSmallErrorLargerResult() {
do {
_ = try smallErrorLargerResult()
} catch {
switch error {
case .a(let x):
print("Error value is: \(String(describing: x))")
}
}
}

enum MyError: Error {
case x
case y
Expand Down Expand Up @@ -315,5 +349,10 @@ public struct Main {
await checkAsync()
// CHECK: Arg is 10
print(try! await callAsyncIndirectResult(p: AsyncGenProtoImpl(), x: 10))

// CHECK: Result is: Optional(10)
callSmallResultLargerError()
// CHECK: Error value is: Optional(10)
callSmallErrorLargerResult()
}
}