Skip to content

[6.2] IRGen: fix failing unconditional class casts #81631

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 18 additions & 11 deletions lib/IRGen/GenCast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1032,7 +1032,7 @@ void irgen::emitScalarCheckedCast(IRGenFunction &IGF,
}

if (llvm::Value *fastResult = emitFastClassCastIfPossible(
IGF, instance, sourceFormalType, targetFormalType,
IGF, instance, sourceFormalType, targetFormalType, mode,
sourceWrappedInOptional, nilCheckBB, nilMergeBB)) {
Explosion fastExplosion;
fastExplosion.add(fastResult);
Expand All @@ -1054,7 +1054,7 @@ void irgen::emitScalarCheckedCast(IRGenFunction &IGF,
/// not required that the metadata is fully initialized.
llvm::Value *irgen::emitFastClassCastIfPossible(
IRGenFunction &IGF, llvm::Value *instance, CanType sourceFormalType,
CanType targetFormalType, bool sourceWrappedInOptional,
CanType targetFormalType, CheckedCastMode mode, bool sourceWrappedInOptional,
llvm::BasicBlock *&nilCheckBB, llvm::BasicBlock *&nilMergeBB) {
if (!doesCastPreserveOwnershipForTypes(IGF.IGM.getSILModule(),
sourceFormalType, targetFormalType)) {
Expand Down Expand Up @@ -1089,15 +1089,18 @@ llvm::Value *irgen::emitFastClassCastIfPossible(

// If the source was originally wrapped in an Optional, check it for nil now.
if (sourceWrappedInOptional) {
auto isNotNil = IGF.Builder.CreateICmpNE(
auto isNil = IGF.Builder.CreateICmpEQ(
instance, llvm::ConstantPointerNull::get(
cast<llvm::PointerType>(instance->getType())));
auto *isNotNilContBB = llvm::BasicBlock::Create(IGF.IGM.getLLVMContext());
nilMergeBB = llvm::BasicBlock::Create(IGF.IGM.getLLVMContext());
nilCheckBB = IGF.Builder.GetInsertBlock();
IGF.Builder.CreateCondBr(isNotNil, isNotNilContBB, nilMergeBB);

IGF.Builder.emitBlock(isNotNilContBB);
if (mode == CheckedCastMode::Unconditional) {
IGF.emitConditionalTrap(isNil, "Unexpectedly found nil while unwrapping an Optional value");
} else {
auto *isNotNilContBB = llvm::BasicBlock::Create(IGF.IGM.getLLVMContext());
nilMergeBB = llvm::BasicBlock::Create(IGF.IGM.getLLVMContext());
nilCheckBB = IGF.Builder.GetInsertBlock();
IGF.Builder.CreateCondBr(isNil, nilMergeBB, isNotNilContBB);
IGF.Builder.emitBlock(isNotNilContBB);
}
}

// Get the metadata pointer of the destination class type.
Expand All @@ -1121,11 +1124,15 @@ llvm::Value *irgen::emitFastClassCastIfPossible(
llvm::Value *rhs = IGF.Builder.CreateBitCast(objMetadata, IGF.IGM.Int8PtrTy);

// return isa_ptr == metadata_ptr ? instance : nullptr
llvm::Value *isEqual = IGF.Builder.CreateCmp(llvm::CmpInst::Predicate::ICMP_EQ,
llvm::Value *isNotEqual = IGF.Builder.CreateCmp(llvm::CmpInst::Predicate::ICMP_NE,
lhs, rhs);
if (mode == CheckedCastMode::Unconditional) {
IGF.emitConditionalTrap(isNotEqual, "Unconditional cast failed");
return instance;
}
auto *instanceTy = cast<llvm::PointerType>(instance->getType());
auto *nullPtr = llvm::ConstantPointerNull::get(instanceTy);
auto *select = IGF.Builder.CreateSelect(isEqual, instance, nullPtr);
auto *select = IGF.Builder.CreateSelect(isNotEqual, nullPtr, instance);
llvm::Type *destTy = IGF.getTypeInfoForUnlowered(targetFormalType).getStorageType();
return IGF.Builder.CreateBitCast(select, destTy);
}
2 changes: 1 addition & 1 deletion lib/IRGen/GenCast.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ namespace irgen {

llvm::Value *emitFastClassCastIfPossible(
IRGenFunction &IGF, llvm::Value *instance, CanType sourceFormalType,
CanType targetFormalType, bool sourceWrappedInOptional,
CanType targetFormalType, CheckedCastMode mode, bool sourceWrappedInOptional,
llvm::BasicBlock *&nilCheckBB, llvm::BasicBlock *&nilMergeBB);

/// Convert a class object to the given destination type,
Expand Down
38 changes: 38 additions & 0 deletions lib/IRGen/IRGenFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@ IRGenFunction::IRGenFunction(IRGenModule &IGM, llvm::Function *Fn,
}

IRGenFunction::~IRGenFunction() {
// Move the trap basic blocks to the end of the function.
for (auto *FailBB : FailBBs) {
CurFn->splice(CurFn->end(), CurFn, FailBB->getIterator());
}

emitEpilogue();

// Restore the debug location.
Expand Down Expand Up @@ -543,6 +548,39 @@ void IRGenFunction::emitTrap(StringRef failureMessage, bool EmitUnreachable) {
Builder.CreateUnreachable();
}

void IRGenFunction::emitConditionalTrap(llvm::Value *condition, StringRef failureMessage,
const SILDebugScope *debugScope) {
// The condition should be false, or we die.
auto expectedCond = Builder.CreateExpect(condition,
llvm::ConstantInt::get(IGM.Int1Ty, 0));

// Emit individual fail blocks so that we can map the failure back to a source
// line.
auto origInsertionPoint = Builder.GetInsertBlock();

llvm::BasicBlock *failBB = llvm::BasicBlock::Create(IGM.getLLVMContext());
llvm::BasicBlock *contBB = llvm::BasicBlock::Create(IGM.getLLVMContext());
auto br = Builder.CreateCondBr(expectedCond, failBB, contBB);

if (IGM.getOptions().AnnotateCondFailMessage && !failureMessage.empty())
br->addAnnotationMetadata(failureMessage);

Builder.SetInsertPoint(&CurFn->back());
Builder.emitBlock(failBB);
if (IGM.DebugInfo && debugScope) {
// If we are emitting DWARF, this does nothing. Otherwise the ``llvm.trap``
// instruction emitted from ``Builtin.condfail`` should have an inlined
// debug location. This is because zero is not an artificial line location
// in CodeView.
IGM.DebugInfo->setInlinedTrapLocation(Builder, debugScope);
}
emitTrap(failureMessage, /*EmitUnreachable=*/true);

Builder.SetInsertPoint(origInsertionPoint);
Builder.emitBlock(contBB);
FailBBs.push_back(failBB);
}

Address IRGenFunction::emitTaskAlloc(llvm::Value *size, Alignment alignment) {
auto *call = Builder.CreateCall(IGM.getTaskAllocFunctionPointer(), {size});
call->setDoesNotThrow();
Expand Down
6 changes: 6 additions & 0 deletions lib/IRGen/IRGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ class IRGenFunction {
OptimizationMode OptMode;
bool isPerformanceConstraint;

// Destination basic blocks for condfail traps.
llvm::SmallVector<llvm::BasicBlock *, 8> FailBBs;

llvm::Function *const CurFn;
ModuleDecl *getSwiftModule() const;
SILModule &getSILModule() const;
Expand Down Expand Up @@ -474,6 +477,9 @@ class IRGenFunction {
/// Emit a non-mergeable trap call, optionally followed by a terminator.
void emitTrap(StringRef failureMessage, bool EmitUnreachable);

void emitConditionalTrap(llvm::Value *condition, StringRef failureMessage,
const SILDebugScope *debugScope = nullptr);

/// Given at least a src address to a list of elements, runs body over each
/// element passing its address. An optional destination address can be
/// provided which this will run over as well to perform things like
Expand Down
45 changes: 1 addition & 44 deletions lib/IRGen/IRGenSIL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -420,9 +420,6 @@ class IRGenSILFunction :

llvm::MapVector<SILBasicBlock *, LoweredBB> LoweredBBs;

// Destination basic blocks for condfail traps.
llvm::SmallVector<llvm::BasicBlock *, 8> FailBBs;

SILFunction *CurSILFn;
// If valid, the address by means of which a return--which is direct in
// SIL--is passed indirectly in IR. Such indirection is necessary when the
Expand Down Expand Up @@ -1190,15 +1187,6 @@ class IRGenSILFunction :
ArtificialKind::RealValue, DbgInstrKind);
}

void emitFailBB() {
if (!FailBBs.empty()) {
// Move the trap basic blocks to the end of the function.
for (auto *FailBB : FailBBs) {
CurFn->splice(CurFn->end(), CurFn, FailBB->getIterator());
}
}
}

//===--------------------------------------------------------------------===//
// SIL instruction lowering
//===--------------------------------------------------------------------===//
Expand Down Expand Up @@ -1966,9 +1954,6 @@ IRGenSILFunction::IRGenSILFunction(IRGenModule &IGM, SILFunction *f)

IRGenSILFunction::~IRGenSILFunction() {
assert(Builder.hasPostTerminatorIP() && "did not terminate BB?!");
// Emit the fail BB if we have one.
if (!FailBBs.empty())
emitFailBB();
LLVM_DEBUG(CurFn->print(llvm::dbgs()));
}

Expand Down Expand Up @@ -8283,35 +8268,7 @@ void IRGenSILFunction::visitDestroyAddrInst(swift::DestroyAddrInst *i) {
void IRGenSILFunction::visitCondFailInst(swift::CondFailInst *i) {
Explosion e = getLoweredExplosion(i->getOperand());
llvm::Value *cond = e.claimNext();

// The condition should be false, or we die.
auto expectedCond = Builder.CreateExpect(cond,
llvm::ConstantInt::get(IGM.Int1Ty, 0));

// Emit individual fail blocks so that we can map the failure back to a source
// line.
auto origInsertionPoint = Builder.GetInsertBlock();

llvm::BasicBlock *failBB = llvm::BasicBlock::Create(IGM.getLLVMContext());
llvm::BasicBlock *contBB = llvm::BasicBlock::Create(IGM.getLLVMContext());
auto br = Builder.CreateCondBr(expectedCond, failBB, contBB);

if (IGM.getOptions().AnnotateCondFailMessage && !i->getMessage().empty())
br->addAnnotationMetadata(i->getMessage());

Builder.SetInsertPoint(&CurFn->back());
Builder.emitBlock(failBB);
if (IGM.DebugInfo)
// If we are emitting DWARF, this does nothing. Otherwise the ``llvm.trap``
// instruction emitted from ``Builtin.condfail`` should have an inlined
// debug location. This is because zero is not an artificial line location
// in CodeView.
IGM.DebugInfo->setInlinedTrapLocation(Builder, i->getDebugScope());
emitTrap(i->getMessage(), /*EmitUnreachable=*/true);

Builder.SetInsertPoint(origInsertionPoint);
Builder.emitBlock(contBB);
FailBBs.push_back(failBB);
emitConditionalTrap(cond, i->getMessage(), i->getDebugScope());
}

void IRGenSILFunction::visitIncrementProfilerCounterInst(
Expand Down
31 changes: 31 additions & 0 deletions test/Casting/CastTraps.swift.gyb
Original file line number Diff line number Diff line change
Expand Up @@ -142,4 +142,35 @@ CastTrapsTestSuite.test("Unexpected Obj-C null")
}
#endif

class Base {}
final class Derived: Base {}
final class Other: Base {}

@inline(never)
func getDerived(_ v: Base) -> Derived {
return v as! Derived
}

@inline(never)
func getDerivedFromOptional(_ v: Base?) -> Derived {
return v as! Derived
}

CastTrapsTestSuite.test("unconditinal fast class cast") {
let c = Other()
expectCrashLater()
_ = getDerived(c)
}

CastTrapsTestSuite.test("unconditinal optional fast class cast") {
let c = Other()
expectCrashLater()
_ = getDerivedFromOptional(c)
}

CastTrapsTestSuite.test("unconditinal optional nil fast class cast") {
expectCrashLater()
_ = getDerivedFromOptional(nil)
}

runAllTests()
13 changes: 12 additions & 1 deletion test/Casting/fast_class_casts.swift
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,15 @@ func unconditionalCastToFinal(_ b: Classes.Base) -> Classes.Final {
return b as! Classes.Final
}

// CHECK-LABEL: define {{.*}} @"$s4Main32unconditionalOptionalCastToFinaly7Classes0F0CAC4BaseCSgF"
// CHECK-NOT: call {{.*}}@object_getClass
// CHECK-NOT: @swift_dynamicCastClass
// CHECK: }
@inline(never)
func unconditionalOptionalCastToFinal(_ b: Classes.Base?) -> Classes.Final {
return b as! Classes.Final
}

// CHECK-LABEL: define {{.*}} @"$s4Main20castToResilientFinaly0D7Classes0E0CSgAC4BaseCF"
// CHECK: @swift_dynamicCastClass
// CHECK: }
Expand Down Expand Up @@ -132,7 +141,9 @@ func test() {
// CHECK-OUTPUT: Optional(Classes.Final)
print(castToFinal(Classes.Final()) as Any)
// CHECK-OUTPUT: Classes.Final
print(unconditionalCastToFinal(Classes.Final()) as Any)
print(unconditionalCastToFinal(Classes.Final()))
// CHECK-OUTPUT: Classes.Final
print(unconditionalOptionalCastToFinal(Classes.Final()))

// CHECK-OUTPUT: nil
print(castToResilientFinal(ResilientClasses.Base()) as Any)
Expand Down
40 changes: 39 additions & 1 deletion test/IRGen/casts.sil
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// RUN: %target-swift-frontend %s -emit-ir -enable-objc-interop -disable-objc-attr-requires-foundation-module | %FileCheck %s -DINT=i%target-ptrsize

// REQUIRES: CPU=i386 || CPU=x86_64
// REQUIRES: CPU=i386 || CPU=x86_64 || CPU=arm64

sil_stage canonical

Expand All @@ -11,9 +11,11 @@ struct NotClass {}

class A {}
class B: A {}
final class F: A {}

sil_vtable A {}
sil_vtable B {}
sil_vtable F {}

// CHECK-LABEL: define{{( dllexport)?}}{{( protected)?}} swiftcc ptr @unchecked_addr_cast(ptr noalias nocapture dereferenceable({{.*}}) %0) {{.*}} {
sil @unchecked_addr_cast : $(@in A) -> B {
Expand Down Expand Up @@ -115,6 +117,42 @@ entry(%a : $@thick Any.Type):
return %p : $@thick (CP & OP & CP2).Type
}

// CHECK-LABEL: define{{( dllexport)?}}{{( protected)?}} swiftcc ptr @unconditional_fast_class_cast(ptr %0)
// CHECK: [[ISA:%.*]] = load ptr, ptr %0
// CHECK: [[NE:%.*]] = icmp ne {{.*}}, [[ISA]]
// CHECK: [[E:%.*]] = call i1 @llvm.expect.i1(i1 [[NE]], i1 false)
// CHECK: br i1 [[E]], label %[[TRAPBB:[0-9]*]], label %[[RETBB:[0-9]*]]
// CHECK: [[RETBB]]:
// CHECK-NEXT: ret ptr %0
// CHECK: [[TRAPBB]]:
// CHECK-NEXT: call void @llvm.trap()
sil @unconditional_fast_class_cast : $@convention(thin) (@owned A) -> @owned F {
entry(%0 : $A):
%1 = unconditional_checked_cast %0 to F
return %1
}

// CHECK-LABEL: define{{( dllexport)?}}{{( protected)?}} swiftcc ptr @unconditional_optional_fast_class_cast(i64 %0)
// CHECK: [[PTR:%.*]] = inttoptr i64 %0 to ptr
// CHECK: [[ISNULL:%.*]] = icmp eq ptr [[PTR]], null
// CHECK: [[ENN:%.*]] = call i1 @llvm.expect.i1(i1 [[ISNULL]], i1 false)
// CHECK: br i1 [[ENN]], label %[[NULLTRAPBB:[0-9]*]], label %[[CONTBB:[0-9]*]]
// CHECK: [[CONTBB]]:
// CHECK: [[ISA:%.*]] = load ptr, ptr [[PTR]]
// CHECK: [[NE:%.*]] = icmp ne {{.*}}, [[ISA]]
// CHECK: [[E:%.*]] = call i1 @llvm.expect.i1(i1 [[NE]], i1 false)
// CHECK: br i1 [[E]], label %[[TRAPBB:[0-9]*]], label %[[RETBB:[0-9]*]]
// CHECK: [[RETBB]]:
// CHECK-NEXT: ret ptr [[PTR]]
// CHECK: [[NULLTRAPBB]]:
// CHECK-NEXT: call void @llvm.trap()
// CHECK: [[TRAPBB]]:
// CHECK-NEXT: call void @llvm.trap()
sil @unconditional_optional_fast_class_cast : $@convention(thin) (@owned Optional<A>) -> @owned F {
entry(%0 : $Optional<A>):
%1 = unconditional_checked_cast %0 to F
return %1
}

// CHECK-LABEL: define{{( dllexport)?}}{{( protected)?}} swiftcc { ptr, ptr } @c_cast_to_class_existential(ptr %0)
// CHECK: call { ptr, ptr } @dynamic_cast_existential_1_conditional(ptr {{.*}}, ptr %.Type, {{.*}} @"$s5casts2CPMp"
Expand Down