Skip to content

Commit 1f3bff3

Browse files
committed
IRGen: fix failing unconditional class casts
When unconditionally casting from a base to a final derived class, e.g. `base as! Derived`, the program did not abort with a trap. Instead the resulting null-pointer caused a crash later in the program. This fix inserts a trap condition for the failing case of such a cast. rdar://151462303
1 parent 442db1b commit 1f3bff3

File tree

5 files changed

+101
-14
lines changed

5 files changed

+101
-14
lines changed

lib/IRGen/GenCast.cpp

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1032,7 +1032,7 @@ void irgen::emitScalarCheckedCast(IRGenFunction &IGF,
10321032
}
10331033

10341034
if (llvm::Value *fastResult = emitFastClassCastIfPossible(
1035-
IGF, instance, sourceFormalType, targetFormalType,
1035+
IGF, instance, sourceFormalType, targetFormalType, mode,
10361036
sourceWrappedInOptional, nilCheckBB, nilMergeBB)) {
10371037
Explosion fastExplosion;
10381038
fastExplosion.add(fastResult);
@@ -1054,7 +1054,7 @@ void irgen::emitScalarCheckedCast(IRGenFunction &IGF,
10541054
/// not required that the metadata is fully initialized.
10551055
llvm::Value *irgen::emitFastClassCastIfPossible(
10561056
IRGenFunction &IGF, llvm::Value *instance, CanType sourceFormalType,
1057-
CanType targetFormalType, bool sourceWrappedInOptional,
1057+
CanType targetFormalType, CheckedCastMode mode, bool sourceWrappedInOptional,
10581058
llvm::BasicBlock *&nilCheckBB, llvm::BasicBlock *&nilMergeBB) {
10591059
if (!doesCastPreserveOwnershipForTypes(IGF.IGM.getSILModule(),
10601060
sourceFormalType, targetFormalType)) {
@@ -1089,15 +1089,18 @@ llvm::Value *irgen::emitFastClassCastIfPossible(
10891089

10901090
// If the source was originally wrapped in an Optional, check it for nil now.
10911091
if (sourceWrappedInOptional) {
1092-
auto isNotNil = IGF.Builder.CreateICmpNE(
1092+
auto isNil = IGF.Builder.CreateICmpEQ(
10931093
instance, llvm::ConstantPointerNull::get(
10941094
cast<llvm::PointerType>(instance->getType())));
1095-
auto *isNotNilContBB = llvm::BasicBlock::Create(IGF.IGM.getLLVMContext());
1096-
nilMergeBB = llvm::BasicBlock::Create(IGF.IGM.getLLVMContext());
1097-
nilCheckBB = IGF.Builder.GetInsertBlock();
1098-
IGF.Builder.CreateCondBr(isNotNil, isNotNilContBB, nilMergeBB);
1099-
1100-
IGF.Builder.emitBlock(isNotNilContBB);
1095+
if (mode == CheckedCastMode::Unconditional) {
1096+
IGF.emitConditionalTrap(isNil, "Unexpectedly found nil while unwrapping an Optional value");
1097+
} else {
1098+
auto *isNotNilContBB = llvm::BasicBlock::Create(IGF.IGM.getLLVMContext());
1099+
nilMergeBB = llvm::BasicBlock::Create(IGF.IGM.getLLVMContext());
1100+
nilCheckBB = IGF.Builder.GetInsertBlock();
1101+
IGF.Builder.CreateCondBr(isNil, nilMergeBB, isNotNilContBB);
1102+
IGF.Builder.emitBlock(isNotNilContBB);
1103+
}
11011104
}
11021105

11031106
// Get the metadata pointer of the destination class type.
@@ -1121,11 +1124,15 @@ llvm::Value *irgen::emitFastClassCastIfPossible(
11211124
llvm::Value *rhs = IGF.Builder.CreateBitCast(objMetadata, IGF.IGM.Int8PtrTy);
11221125

11231126
// return isa_ptr == metadata_ptr ? instance : nullptr
1124-
llvm::Value *isEqual = IGF.Builder.CreateCmp(llvm::CmpInst::Predicate::ICMP_EQ,
1127+
llvm::Value *isNotEqual = IGF.Builder.CreateCmp(llvm::CmpInst::Predicate::ICMP_NE,
11251128
lhs, rhs);
1129+
if (mode == CheckedCastMode::Unconditional) {
1130+
IGF.emitConditionalTrap(isNotEqual, "Unconditional cast failed");
1131+
return instance;
1132+
}
11261133
auto *instanceTy = cast<llvm::PointerType>(instance->getType());
11271134
auto *nullPtr = llvm::ConstantPointerNull::get(instanceTy);
1128-
auto *select = IGF.Builder.CreateSelect(isEqual, instance, nullPtr);
1135+
auto *select = IGF.Builder.CreateSelect(isNotEqual, nullPtr, instance);
11291136
llvm::Type *destTy = IGF.getTypeInfoForUnlowered(targetFormalType).getStorageType();
11301137
return IGF.Builder.CreateBitCast(select, destTy);
11311138
}

lib/IRGen/GenCast.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ namespace irgen {
6161

6262
llvm::Value *emitFastClassCastIfPossible(
6363
IRGenFunction &IGF, llvm::Value *instance, CanType sourceFormalType,
64-
CanType targetFormalType, bool sourceWrappedInOptional,
64+
CanType targetFormalType, CheckedCastMode mode, bool sourceWrappedInOptional,
6565
llvm::BasicBlock *&nilCheckBB, llvm::BasicBlock *&nilMergeBB);
6666

6767
/// Convert a class object to the given destination type,

test/Casting/CastTraps.swift.gyb

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,4 +142,35 @@ CastTrapsTestSuite.test("Unexpected Obj-C null")
142142
}
143143
#endif
144144

145+
class Base {}
146+
final class Derived: Base {}
147+
final class Other: Base {}
148+
149+
@inline(never)
150+
func getDerived(_ v: Base) -> Derived {
151+
return v as! Derived
152+
}
153+
154+
@inline(never)
155+
func getDerivedFromOptional(_ v: Base?) -> Derived {
156+
return v as! Derived
157+
}
158+
159+
CastTrapsTestSuite.test("unconditinal fast class cast") {
160+
let c = Other()
161+
expectCrashLater()
162+
_ = getDerived(c)
163+
}
164+
165+
CastTrapsTestSuite.test("unconditinal optional fast class cast") {
166+
let c = Other()
167+
expectCrashLater()
168+
_ = getDerivedFromOptional(c)
169+
}
170+
171+
CastTrapsTestSuite.test("unconditinal optional nil fast class cast") {
172+
expectCrashLater()
173+
_ = getDerivedFromOptional(nil)
174+
}
175+
145176
runAllTests()

test/Casting/fast_class_casts.swift

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,15 @@ func unconditionalCastToFinal(_ b: Classes.Base) -> Classes.Final {
5656
return b as! Classes.Final
5757
}
5858

59+
// CHECK-LABEL: define {{.*}} @"$s4Main32unconditionalOptionalCastToFinaly7Classes0F0CAC4BaseCSgF"
60+
// CHECK-NOT: call {{.*}}@object_getClass
61+
// CHECK-NOT: @swift_dynamicCastClass
62+
// CHECK: }
63+
@inline(never)
64+
func unconditionalOptionalCastToFinal(_ b: Classes.Base?) -> Classes.Final {
65+
return b as! Classes.Final
66+
}
67+
5968
// CHECK-LABEL: define {{.*}} @"$s4Main20castToResilientFinaly0D7Classes0E0CSgAC4BaseCF"
6069
// CHECK: @swift_dynamicCastClass
6170
// CHECK: }
@@ -132,7 +141,9 @@ func test() {
132141
// CHECK-OUTPUT: Optional(Classes.Final)
133142
print(castToFinal(Classes.Final()) as Any)
134143
// CHECK-OUTPUT: Classes.Final
135-
print(unconditionalCastToFinal(Classes.Final()) as Any)
144+
print(unconditionalCastToFinal(Classes.Final()))
145+
// CHECK-OUTPUT: Classes.Final
146+
print(unconditionalOptionalCastToFinal(Classes.Final()))
136147

137148
// CHECK-OUTPUT: nil
138149
print(castToResilientFinal(ResilientClasses.Base()) as Any)

test/IRGen/casts.sil

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
// RUN: %target-swift-frontend %s -emit-ir -enable-objc-interop -disable-objc-attr-requires-foundation-module | %FileCheck %s -DINT=i%target-ptrsize
22

3-
// REQUIRES: CPU=i386 || CPU=x86_64
3+
// REQUIRES: CPU=i386 || CPU=x86_64 || CPU=arm64
44

55
sil_stage canonical
66

@@ -11,9 +11,11 @@ struct NotClass {}
1111

1212
class A {}
1313
class B: A {}
14+
final class F: A {}
1415

1516
sil_vtable A {}
1617
sil_vtable B {}
18+
sil_vtable F {}
1719

1820
// CHECK-LABEL: define{{( dllexport)?}}{{( protected)?}} swiftcc ptr @unchecked_addr_cast(ptr noalias nocapture dereferenceable({{.*}}) %0) {{.*}} {
1921
sil @unchecked_addr_cast : $(@in A) -> B {
@@ -115,6 +117,42 @@ entry(%a : $@thick Any.Type):
115117
return %p : $@thick (CP & OP & CP2).Type
116118
}
117119

120+
// CHECK-LABEL: define{{( dllexport)?}}{{( protected)?}} swiftcc ptr @unconditional_fast_class_cast(ptr %0)
121+
// CHECK: [[ISA:%.*]] = load ptr, ptr %0
122+
// CHECK: [[NE:%.*]] = icmp ne {{.*}}, [[ISA]]
123+
// CHECK: [[E:%.*]] = call i1 @llvm.expect.i1(i1 [[NE]], i1 false)
124+
// CHECK: br i1 [[E]], label %[[TRAPBB:[0-9]*]], label %[[RETBB:[0-9]*]]
125+
// CHECK: [[RETBB]]:
126+
// CHECK-NEXT: ret ptr %0
127+
// CHECK: [[TRAPBB]]:
128+
// CHECK-NEXT: call void @llvm.trap()
129+
sil @unconditional_fast_class_cast : $@convention(thin) (@owned A) -> @owned F {
130+
entry(%0 : $A):
131+
%1 = unconditional_checked_cast %0 to F
132+
return %1
133+
}
134+
135+
// CHECK-LABEL: define{{( dllexport)?}}{{( protected)?}} swiftcc ptr @unconditional_optional_fast_class_cast(i64 %0)
136+
// CHECK: [[PTR:%.*]] = inttoptr i64 %0 to ptr
137+
// CHECK: [[ISNULL:%.*]] = icmp eq ptr [[PTR]], null
138+
// CHECK: [[ENN:%.*]] = call i1 @llvm.expect.i1(i1 [[ISNULL]], i1 false)
139+
// CHECK: br i1 [[ENN]], label %[[NULLTRAPBB:[0-9]*]], label %[[CONTBB:[0-9]*]]
140+
// CHECK: [[CONTBB]]:
141+
// CHECK: [[ISA:%.*]] = load ptr, ptr [[PTR]]
142+
// CHECK: [[NE:%.*]] = icmp ne {{.*}}, [[ISA]]
143+
// CHECK: [[E:%.*]] = call i1 @llvm.expect.i1(i1 [[NE]], i1 false)
144+
// CHECK: br i1 [[E]], label %[[TRAPBB:[0-9]*]], label %[[RETBB:[0-9]*]]
145+
// CHECK: [[RETBB]]:
146+
// CHECK-NEXT: ret ptr [[PTR]]
147+
// CHECK: [[NULLTRAPBB]]:
148+
// CHECK-NEXT: call void @llvm.trap()
149+
// CHECK: [[TRAPBB]]:
150+
// CHECK-NEXT: call void @llvm.trap()
151+
sil @unconditional_optional_fast_class_cast : $@convention(thin) (@owned Optional<A>) -> @owned F {
152+
entry(%0 : $Optional<A>):
153+
%1 = unconditional_checked_cast %0 to F
154+
return %1
155+
}
118156

119157
// CHECK-LABEL: define{{( dllexport)?}}{{( protected)?}} swiftcc { ptr, ptr } @c_cast_to_class_existential(ptr %0)
120158
// CHECK: call { ptr, ptr } @dynamic_cast_existential_1_conditional(ptr {{.*}}, ptr %.Type, {{.*}} @"$s5casts2CPMp"

0 commit comments

Comments
 (0)