Skip to content

[SIL Opaque Value] Support checked_cast_br in Address Lowering with loadable source & opaque target type #58441

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 118 additions & 92 deletions lib/SILOptimizer/Mandatory/AddressLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,11 @@ struct AddressLoweringState {
// parameters are rewritten.
SmallBlotSetVector<FullApplySite, 16> indirectApplies;

// checked_cast_br instructions with loadable source type and opaque target
// type need to be rewritten in a post-pass, once all the uses of the opaque
// target value are rewritten to their address forms.
SmallVector<CheckedCastBranchInst *, 8> opaqueResultCCBs;

// All function-exiting terminators (return or throw instructions).
SmallVector<TermInst *, 8> exitingInsts;

Expand Down Expand Up @@ -606,6 +611,15 @@ void OpaqueValueVisitor::mapValueStorage() {
if (auto apply = FullApplySite::isa(&inst))
checkForIndirectApply(apply);

// Collect all checked_cast_br instructions that have a loadable source
// type and opaque target type
if (auto *ccb = dyn_cast<CheckedCastBranchInst>(&inst)) {
if (!ccb->getSourceLoweredType().isAddressOnly(*ccb->getFunction()) &&
ccb->getTargetLoweredType().isAddressOnly(*ccb->getFunction())) {
pass.opaqueResultCCBs.push_back(ccb);
}
}

for (auto result : inst.getResults()) {
if (isPseudoCallResult(result) || isPseudoReturnValue(result))
continue;
Expand Down Expand Up @@ -2252,6 +2266,99 @@ void ApplyRewriter::replaceDirectResults(DestructureTupleInst *oldDestructure) {
}
}

//===----------------------------------------------------------------------===//
// CheckedCastBrRewriter
//
// Utilities for rewriting checked_cast_br with opaque source/target type
// ===---------------------------------------------------------------------===//
class CheckedCastBrRewriter {
CheckedCastBranchInst *ccb;
AddressLoweringState &pass;
SILLocation castLoc;
SILFunction *func;
SILBasicBlock *successBB;
SILBasicBlock *failureBB;
SILArgument *origSuccessVal;
SILArgument *origFailureVal;
SILBuilder termBuilder;
SILBuilder successBuilder;
SILBuilder failureBuilder;

public:
CheckedCastBrRewriter(CheckedCastBranchInst *ccb, AddressLoweringState &pass)
: ccb(ccb), pass(pass), castLoc(ccb->getLoc()), func(ccb->getFunction()),
successBB(ccb->getSuccessBB()), failureBB(ccb->getFailureBB()),
origSuccessVal(successBB->getArgument(0)),
origFailureVal(failureBB->getArgument(0)),
termBuilder(pass.getTermBuilder(ccb)),
successBuilder(pass.getBuilder(successBB->begin())),
failureBuilder(pass.getBuilder(failureBB->begin())) {}

/// Rewrite checked_cast_br with opaque source/target operands to
/// checked_cast_addr_br
void rewrite() {
auto srcAddr =
getAddressForCastEntity(ccb->getOperand(), /* needsInit */ true);
auto destAddr =
getAddressForCastEntity(origSuccessVal, /* needsInit */ false);

// getReusedStorageOperand() ensured we do not allocate a separate address
// for failure block arg. Set the storage address of the failure block arg
// to be source address here.
if (origFailureVal->getType().isAddressOnly(*func)) {
pass.valueStorageMap.setStorageAddress(origFailureVal, srcAddr);
}

termBuilder.createCheckedCastAddrBranch(
castLoc, CastConsumptionKind::TakeOnSuccess, srcAddr,
ccb->getSourceFormalType(), destAddr, ccb->getTargetFormalType(),
successBB, failureBB, ccb->getTrueBBCount(), ccb->getFalseBBCount());

replaceBlockArg(origSuccessVal, destAddr);
replaceBlockArg(origFailureVal, srcAddr);

pass.deleter.forceDelete(ccb);
}

private:
/// Return the storageAddress if \p value is opaque, otherwise create and
/// return a stack temporary.
SILValue getAddressForCastEntity(SILValue value, bool needsInit) {
if (value->getType().isAddressOnly(*func))
return pass.valueStorageMap.getStorage(value).storageAddress;

// Create a stack temporary for a loadable value
auto *addr = termBuilder.createAllocStack(castLoc, value->getType());
if (needsInit) {
termBuilder.createStore(castLoc, value, addr,
value->getType().isTrivial(*func)
? StoreOwnershipQualifier::Trivial
: StoreOwnershipQualifier::Init);
}
successBuilder.createDeallocStack(castLoc, addr);
failureBuilder.createDeallocStack(castLoc, addr);
return addr;
}

void replaceBlockArg(SILArgument *blockArg, SILValue addr) {
// Replace all uses of the opaque block arg with a load from its
// storage address.
auto load =
pass.getBuilder(blockArg->getParent()->begin())
.createTrivialLoadOr(castLoc, addr, LoadOwnershipQualifier::Take);
blockArg->replaceAllUsesWith(load);

blockArg->getParent()->eraseArgument(blockArg->getIndex());

if (blockArg->getType().isAddressOnly(*func)) {
// In case of opaque block arg, replace the block arg with the dummy load
// in the valueStorageMap. DefRewriter::visitLoadInst will then rewrite
// the dummy load to copy_addr.
pass.valueStorageMap.replaceValue(blockArg, load);
}
}
};

//===----------------------------------------------------------------------===//
// ReturnRewriter
//
Expand Down Expand Up @@ -2811,87 +2918,8 @@ void UseRewriter::visitSwitchEnumInst(SwitchEnumInst * switchEnum) {
defaultCounter);
}

void UseRewriter::visitCheckedCastBranchInst(
CheckedCastBranchInst *checkedCastBranch) {
auto loc = checkedCastBranch->getLoc();
auto *func = checkedCastBranch->getFunction();
auto *successBB = checkedCastBranch->getSuccessBB();
auto *failureBB = checkedCastBranch->getFailureBB();
auto *oldSuccessVal = successBB->getArgument(0);
auto *oldFailureVal = failureBB->getArgument(0);
auto termBuilder = pass.getTermBuilder(checkedCastBranch);
auto successBuilder = pass.getBuilder(successBB->begin());
auto failureBuilder = pass.getBuilder(failureBB->begin());
bool isAddressOnlyTarget = oldSuccessVal->getType().isAddressOnly(*func);

auto srcAddr = pass.valueStorageMap.getStorage(use->get()).storageAddress;

if (isAddressOnlyTarget) {
// If target is opaque, use the storage address mapped to success
// block's argument as the destination for checked_cast_addr_br.
SILValue destAddr =
pass.valueStorageMap.getStorage(oldSuccessVal).storageAddress;

termBuilder.createCheckedCastAddrBranch(
loc, CastConsumptionKind::TakeOnSuccess, srcAddr,
checkedCastBranch->getSourceFormalType(), destAddr,
checkedCastBranch->getTargetFormalType(), successBB, failureBB,
checkedCastBranch->getTrueBBCount(),
checkedCastBranch->getFalseBBCount());

// In this case, since both success and failure block's args are opaque,
// create dummy loads from their storage addresses that will later be
// rewritten to copy_addr in DefRewriter::visitLoadInst
auto newSuccessVal = successBuilder.createTrivialLoadOr(
loc, destAddr, LoadOwnershipQualifier::Take);
oldSuccessVal->replaceAllUsesWith(newSuccessVal);
successBB->eraseArgument(0);

pass.valueStorageMap.replaceValue(oldSuccessVal, newSuccessVal);

auto newFailureVal = failureBuilder.createTrivialLoadOr(
loc, srcAddr, LoadOwnershipQualifier::Take);
oldFailureVal->replaceAllUsesWith(newFailureVal);
failureBB->eraseArgument(0);

pass.valueStorageMap.replaceValue(oldFailureVal, newFailureVal);
markRewritten(newFailureVal, srcAddr);
} else {
// If the target is loadable, create a stack temporary to be used as the
// destination for checked_cast_addr_br.
SILValue destAddr = termBuilder.createAllocStack(
loc, checkedCastBranch->getTargetLoweredType());

termBuilder.createCheckedCastAddrBranch(
loc, CastConsumptionKind::TakeOnSuccess, srcAddr,
checkedCastBranch->getSourceFormalType(), destAddr,
checkedCastBranch->getTargetFormalType(), successBB, failureBB,
checkedCastBranch->getTrueBBCount(),
checkedCastBranch->getFalseBBCount());

// Replace the success block arg with loaded value from destAddr, and delete
// the success block arg.
auto newSuccessVal = successBuilder.createTrivialLoadOr(
loc, destAddr, LoadOwnershipQualifier::Take);
oldSuccessVal->replaceAllUsesWith(newSuccessVal);
successBB->eraseArgument(0);

successBuilder.createDeallocStack(loc, destAddr);
failureBuilder.createDeallocStack(loc, destAddr);

// Since failure block arg is opaque, create dummy load from its storage
// address. This will be replaced later with copy_addr in
// DefRewriter::visitLoadInst.
auto newFailureVal = failureBuilder.createTrivialLoadOr(
loc, srcAddr, LoadOwnershipQualifier::Take);
oldFailureVal->replaceAllUsesWith(newFailureVal);
failureBB->eraseArgument(0);

pass.valueStorageMap.replaceValue(oldFailureVal, newFailureVal);
markRewritten(newFailureVal, srcAddr);
}

pass.deleter.forceDelete(checkedCastBranch);
void UseRewriter::visitCheckedCastBranchInst(CheckedCastBranchInst *ccb) {
CheckedCastBrRewriter(ccb, pass).rewrite();
}

void UseRewriter::visitUncheckedEnumDataInst(
Expand Down Expand Up @@ -2989,18 +3017,9 @@ class DefRewriter : SILInstructionVisitor<DefRewriter> {
LLVM_DEBUG(llvm::dbgs() << "REWRITE ARG "; arg->dump());
if (storage.storageAddress)
LLVM_DEBUG(llvm::dbgs() << " STORAGE "; storage.storageAddress->dump());

storage.storageAddress = addrMat.materializeAddress(arg);
}

void setStorageAddress(SILValue oldValue, SILValue addr) {
auto &storage = pass.valueStorageMap.getStorage(oldValue);
// getReusedStorageOperand() ensures that oldValue does not already have
// separate storage. So there's no need to delete its alloc_stack.
assert(!storage.storageAddress || storage.storageAddress == addr);
storage.storageAddress = addr;
}

void beforeVisit(SILInstruction *inst) {
LLVM_DEBUG(llvm::dbgs() << "REWRITE DEF "; inst->dump());
if (storage.storageAddress)
Expand Down Expand Up @@ -3068,7 +3087,7 @@ class DefRewriter : SILInstructionVisitor<DefRewriter> {
openExistentialBoxValue->getType().getAddressType());

openExistentialBoxValue->replaceAllTypeDependentUsesWith(openAddr);
setStorageAddress(openExistentialBoxValue, openAddr);
pass.valueStorageMap.setStorageAddress(openExistentialBoxValue, openAddr);
}

// Load an opaque value.
Expand Down Expand Up @@ -3133,7 +3152,7 @@ class DefRewriter : SILInstructionVisitor<DefRewriter> {
// Rewrite Opaque Values
//===----------------------------------------------------------------------===//

// Rewrite applies with indirect paramters or results of loadable types which
// Rewrite applies with indirect parameters or results of loadable types which
// were not visited during opaque value rewritting.
static void rewriteIndirectApply(FullApplySite apply,
AddressLoweringState &pass) {
Expand Down Expand Up @@ -3186,6 +3205,13 @@ static void rewriteFunction(AddressLoweringState &pass) {
rewriteIndirectApply(optionalApply.getValue(), pass);
}
}

// Rewrite all checked_cast_br instructions with loadable source type and
// opaque target type now
for (auto *ccb : pass.opaqueResultCCBs) {
CheckedCastBrRewriter(ccb, pass).rewrite();
}

// Rewrite this function's return value now that all opaque values within the
// function are rewritten. This still depends on a valid ValueStorage
// projection operands.
Expand Down
6 changes: 6 additions & 0 deletions lib/SILOptimizer/Mandatory/AddressLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,12 @@ class ValueStorageMap {
return getNonEnumBaseStorage(getStorage(value));
}

void setStorageAddress(SILValue value, SILValue addr) {
auto &storage = getStorage(value);
assert(!storage.storageAddress || storage.storageAddress == addr);
storage.storageAddress = addr;
}

/// Insert a value in the map, creating a ValueStorage object for it. This
/// must be called in RPO order.
void insertValue(SILValue value, SILValue storageAddress);
Expand Down
41 changes: 41 additions & 0 deletions test/SILOptimizer/address_lowering.sil
Original file line number Diff line number Diff line change
Expand Up @@ -1245,6 +1245,47 @@ bb3:
return %31 : $()
}

sil @use_Any : $@convention(thin) (@in Any) -> ()

// CHECK-LABEL: sil [ossa] @test_checked_cast_br3 : $@convention(method) (@owned C) -> () {
// CHECK: bb0(%0 : @owned $C):
// CHECK: [[DST:%.*]] = alloc_stack $Any
// CHECK: [[SRC_TMP:%.*]] = alloc_stack $C
// CHECK: store %0 to [init] [[SRC_TMP]] : $*C
// CHECK: checked_cast_addr_br take_on_success C in [[SRC_TMP]] : $*C to Any in [[DST]] : $*Any, bb2, bb1
// CHECK: bb1:
// CHECK: [[LD:%.*]] = load [take] [[SRC_TMP]] : $*C
// CHECK: dealloc_stack [[SRC_TMP]] : $*C
// CHECK: destroy_value [[LD]] : $C
// CHECK: br bb3
// CHECK: bb2:
// CHECK: dealloc_stack [[SRC_TMP]] : $*C
// CHECK: [[FUNC:%.*]] = function_ref @use_Any : $@convention(thin) (@in Any) -> ()
// CHECK: apply [[FUNC]]([[DST]]) : $@convention(thin) (@in Any) -> ()
// CHECK: br bb3
// CHECK: bb3:
// CHECK: [[RES:%.*]] = tuple ()
// CHECK: dealloc_stack [[DST]] : $*Any
// CHECK: return [[RES]] : $()
// CHECK: }
sil [ossa] @test_checked_cast_br3 : $@convention(method) (@owned C) -> () {
bb0(%0 : @owned $C):
checked_cast_br %0 : $C to Any, bb1, bb2

bb1(%3 : @owned $Any):
%f = function_ref @use_Any : $@convention(thin) (@in Any) -> ()
%call = apply %f(%3) : $@convention(thin) (@in Any) -> ()
br bb3

bb2(%4 : @owned $C):
destroy_value %4 : $C
br bb3

bb3:
%31 = tuple ()
return %31 : $()
}

// CHECK-LABEL: sil hidden [ossa] @test_unchecked_bitwise_cast :
// CHECK: bb0(%0 : $*U, %1 : $*T, %2 : $@thick U.Type):
// CHECK: [[STK:%.*]] = alloc_stack $T
Expand Down