Skip to content

Commit 60d0ac6

Browse files
authored
[SIL Opaque Value] Support checked_cast_br in Address Lowering with loadable source type and opaque target type (#58441)
1 parent 94a4d84 commit 60d0ac6

File tree

3 files changed

+165
-92
lines changed

3 files changed

+165
-92
lines changed

lib/SILOptimizer/Mandatory/AddressLowering.cpp

Lines changed: 118 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,11 @@ struct AddressLoweringState {
419419
// parameters are rewritten.
420420
SmallBlotSetVector<FullApplySite, 16> indirectApplies;
421421

422+
// checked_cast_br instructions with loadable source type and opaque target
423+
// type need to be rewritten in a post-pass, once all the uses of the opaque
424+
// target value are rewritten to their address forms.
425+
SmallVector<CheckedCastBranchInst *, 8> opaqueResultCCBs;
426+
422427
// All function-exiting terminators (return or throw instructions).
423428
SmallVector<TermInst *, 8> exitingInsts;
424429

@@ -606,6 +611,15 @@ void OpaqueValueVisitor::mapValueStorage() {
606611
if (auto apply = FullApplySite::isa(&inst))
607612
checkForIndirectApply(apply);
608613

614+
// Collect all checked_cast_br instructions that have a loadable source
615+
// type and opaque target type
616+
if (auto *ccb = dyn_cast<CheckedCastBranchInst>(&inst)) {
617+
if (!ccb->getSourceLoweredType().isAddressOnly(*ccb->getFunction()) &&
618+
ccb->getTargetLoweredType().isAddressOnly(*ccb->getFunction())) {
619+
pass.opaqueResultCCBs.push_back(ccb);
620+
}
621+
}
622+
609623
for (auto result : inst.getResults()) {
610624
if (isPseudoCallResult(result) || isPseudoReturnValue(result))
611625
continue;
@@ -2252,6 +2266,99 @@ void ApplyRewriter::replaceDirectResults(DestructureTupleInst *oldDestructure) {
22522266
}
22532267
}
22542268

2269+
//===----------------------------------------------------------------------===//
2270+
// CheckedCastBrRewriter
2271+
//
2272+
// Utilities for rewriting checked_cast_br with opaque source/target type
2273+
// ===---------------------------------------------------------------------===//
2274+
class CheckedCastBrRewriter {
2275+
CheckedCastBranchInst *ccb;
2276+
AddressLoweringState &pass;
2277+
SILLocation castLoc;
2278+
SILFunction *func;
2279+
SILBasicBlock *successBB;
2280+
SILBasicBlock *failureBB;
2281+
SILArgument *origSuccessVal;
2282+
SILArgument *origFailureVal;
2283+
SILBuilder termBuilder;
2284+
SILBuilder successBuilder;
2285+
SILBuilder failureBuilder;
2286+
2287+
public:
2288+
CheckedCastBrRewriter(CheckedCastBranchInst *ccb, AddressLoweringState &pass)
2289+
: ccb(ccb), pass(pass), castLoc(ccb->getLoc()), func(ccb->getFunction()),
2290+
successBB(ccb->getSuccessBB()), failureBB(ccb->getFailureBB()),
2291+
origSuccessVal(successBB->getArgument(0)),
2292+
origFailureVal(failureBB->getArgument(0)),
2293+
termBuilder(pass.getTermBuilder(ccb)),
2294+
successBuilder(pass.getBuilder(successBB->begin())),
2295+
failureBuilder(pass.getBuilder(failureBB->begin())) {}
2296+
2297+
/// Rewrite checked_cast_br with opaque source/target operands to
2298+
/// checked_cast_addr_br
2299+
void rewrite() {
2300+
auto srcAddr =
2301+
getAddressForCastEntity(ccb->getOperand(), /* needsInit */ true);
2302+
auto destAddr =
2303+
getAddressForCastEntity(origSuccessVal, /* needsInit */ false);
2304+
2305+
// getReusedStorageOperand() ensured we do not allocate a separate address
2306+
// for failure block arg. Set the storage address of the failure block arg
2307+
// to be source address here.
2308+
if (origFailureVal->getType().isAddressOnly(*func)) {
2309+
pass.valueStorageMap.setStorageAddress(origFailureVal, srcAddr);
2310+
}
2311+
2312+
termBuilder.createCheckedCastAddrBranch(
2313+
castLoc, CastConsumptionKind::TakeOnSuccess, srcAddr,
2314+
ccb->getSourceFormalType(), destAddr, ccb->getTargetFormalType(),
2315+
successBB, failureBB, ccb->getTrueBBCount(), ccb->getFalseBBCount());
2316+
2317+
replaceBlockArg(origSuccessVal, destAddr);
2318+
replaceBlockArg(origFailureVal, srcAddr);
2319+
2320+
pass.deleter.forceDelete(ccb);
2321+
}
2322+
2323+
private:
2324+
/// Return the storageAddress if \p value is opaque, otherwise create and
2325+
/// return a stack temporary.
2326+
SILValue getAddressForCastEntity(SILValue value, bool needsInit) {
2327+
if (value->getType().isAddressOnly(*func))
2328+
return pass.valueStorageMap.getStorage(value).storageAddress;
2329+
2330+
// Create a stack temporary for a loadable value
2331+
auto *addr = termBuilder.createAllocStack(castLoc, value->getType());
2332+
if (needsInit) {
2333+
termBuilder.createStore(castLoc, value, addr,
2334+
value->getType().isTrivial(*func)
2335+
? StoreOwnershipQualifier::Trivial
2336+
: StoreOwnershipQualifier::Init);
2337+
}
2338+
successBuilder.createDeallocStack(castLoc, addr);
2339+
failureBuilder.createDeallocStack(castLoc, addr);
2340+
return addr;
2341+
}
2342+
2343+
void replaceBlockArg(SILArgument *blockArg, SILValue addr) {
2344+
// Replace all uses of the opaque block arg with a load from its
2345+
// storage address.
2346+
auto load =
2347+
pass.getBuilder(blockArg->getParent()->begin())
2348+
.createTrivialLoadOr(castLoc, addr, LoadOwnershipQualifier::Take);
2349+
blockArg->replaceAllUsesWith(load);
2350+
2351+
blockArg->getParent()->eraseArgument(blockArg->getIndex());
2352+
2353+
if (blockArg->getType().isAddressOnly(*func)) {
2354+
// In case of opaque block arg, replace the block arg with the dummy load
2355+
// in the valueStorageMap. DefRewriter::visitLoadInst will then rewrite
2356+
// the dummy load to copy_addr.
2357+
pass.valueStorageMap.replaceValue(blockArg, load);
2358+
}
2359+
}
2360+
};
2361+
22552362
//===----------------------------------------------------------------------===//
22562363
// ReturnRewriter
22572364
//
@@ -2811,87 +2918,8 @@ void UseRewriter::visitSwitchEnumInst(SwitchEnumInst * switchEnum) {
28112918
defaultCounter);
28122919
}
28132920

2814-
void UseRewriter::visitCheckedCastBranchInst(
2815-
CheckedCastBranchInst *checkedCastBranch) {
2816-
auto loc = checkedCastBranch->getLoc();
2817-
auto *func = checkedCastBranch->getFunction();
2818-
auto *successBB = checkedCastBranch->getSuccessBB();
2819-
auto *failureBB = checkedCastBranch->getFailureBB();
2820-
auto *oldSuccessVal = successBB->getArgument(0);
2821-
auto *oldFailureVal = failureBB->getArgument(0);
2822-
auto termBuilder = pass.getTermBuilder(checkedCastBranch);
2823-
auto successBuilder = pass.getBuilder(successBB->begin());
2824-
auto failureBuilder = pass.getBuilder(failureBB->begin());
2825-
bool isAddressOnlyTarget = oldSuccessVal->getType().isAddressOnly(*func);
2826-
2827-
auto srcAddr = pass.valueStorageMap.getStorage(use->get()).storageAddress;
2828-
2829-
if (isAddressOnlyTarget) {
2830-
// If target is opaque, use the storage address mapped to success
2831-
// block's argument as the destination for checked_cast_addr_br.
2832-
SILValue destAddr =
2833-
pass.valueStorageMap.getStorage(oldSuccessVal).storageAddress;
2834-
2835-
termBuilder.createCheckedCastAddrBranch(
2836-
loc, CastConsumptionKind::TakeOnSuccess, srcAddr,
2837-
checkedCastBranch->getSourceFormalType(), destAddr,
2838-
checkedCastBranch->getTargetFormalType(), successBB, failureBB,
2839-
checkedCastBranch->getTrueBBCount(),
2840-
checkedCastBranch->getFalseBBCount());
2841-
2842-
// In this case, since both success and failure block's args are opaque,
2843-
// create dummy loads from their storage addresses that will later be
2844-
// rewritten to copy_addr in DefRewriter::visitLoadInst
2845-
auto newSuccessVal = successBuilder.createTrivialLoadOr(
2846-
loc, destAddr, LoadOwnershipQualifier::Take);
2847-
oldSuccessVal->replaceAllUsesWith(newSuccessVal);
2848-
successBB->eraseArgument(0);
2849-
2850-
pass.valueStorageMap.replaceValue(oldSuccessVal, newSuccessVal);
2851-
2852-
auto newFailureVal = failureBuilder.createTrivialLoadOr(
2853-
loc, srcAddr, LoadOwnershipQualifier::Take);
2854-
oldFailureVal->replaceAllUsesWith(newFailureVal);
2855-
failureBB->eraseArgument(0);
2856-
2857-
pass.valueStorageMap.replaceValue(oldFailureVal, newFailureVal);
2858-
markRewritten(newFailureVal, srcAddr);
2859-
} else {
2860-
// If the target is loadable, create a stack temporary to be used as the
2861-
// destination for checked_cast_addr_br.
2862-
SILValue destAddr = termBuilder.createAllocStack(
2863-
loc, checkedCastBranch->getTargetLoweredType());
2864-
2865-
termBuilder.createCheckedCastAddrBranch(
2866-
loc, CastConsumptionKind::TakeOnSuccess, srcAddr,
2867-
checkedCastBranch->getSourceFormalType(), destAddr,
2868-
checkedCastBranch->getTargetFormalType(), successBB, failureBB,
2869-
checkedCastBranch->getTrueBBCount(),
2870-
checkedCastBranch->getFalseBBCount());
2871-
2872-
// Replace the success block arg with loaded value from destAddr, and delete
2873-
// the success block arg.
2874-
auto newSuccessVal = successBuilder.createTrivialLoadOr(
2875-
loc, destAddr, LoadOwnershipQualifier::Take);
2876-
oldSuccessVal->replaceAllUsesWith(newSuccessVal);
2877-
successBB->eraseArgument(0);
2878-
2879-
successBuilder.createDeallocStack(loc, destAddr);
2880-
failureBuilder.createDeallocStack(loc, destAddr);
2881-
2882-
// Since failure block arg is opaque, create dummy load from its storage
2883-
// address. This will be replaced later with copy_addr in
2884-
// DefRewriter::visitLoadInst.
2885-
auto newFailureVal = failureBuilder.createTrivialLoadOr(
2886-
loc, srcAddr, LoadOwnershipQualifier::Take);
2887-
oldFailureVal->replaceAllUsesWith(newFailureVal);
2888-
failureBB->eraseArgument(0);
2889-
2890-
pass.valueStorageMap.replaceValue(oldFailureVal, newFailureVal);
2891-
markRewritten(newFailureVal, srcAddr);
2892-
}
2893-
2894-
pass.deleter.forceDelete(checkedCastBranch);
2921+
void UseRewriter::visitCheckedCastBranchInst(CheckedCastBranchInst *ccb) {
2922+
CheckedCastBrRewriter(ccb, pass).rewrite();
28952923
}
28962924

28972925
void UseRewriter::visitUncheckedEnumDataInst(
@@ -2989,18 +3017,9 @@ class DefRewriter : SILInstructionVisitor<DefRewriter> {
29893017
LLVM_DEBUG(llvm::dbgs() << "REWRITE ARG "; arg->dump());
29903018
if (storage.storageAddress)
29913019
LLVM_DEBUG(llvm::dbgs() << " STORAGE "; storage.storageAddress->dump());
2992-
29933020
storage.storageAddress = addrMat.materializeAddress(arg);
29943021
}
29953022

2996-
void setStorageAddress(SILValue oldValue, SILValue addr) {
2997-
auto &storage = pass.valueStorageMap.getStorage(oldValue);
2998-
// getReusedStorageOperand() ensures that oldValue does not already have
2999-
// separate storage. So there's no need to delete its alloc_stack.
3000-
assert(!storage.storageAddress || storage.storageAddress == addr);
3001-
storage.storageAddress = addr;
3002-
}
3003-
30043023
void beforeVisit(SILInstruction *inst) {
30053024
LLVM_DEBUG(llvm::dbgs() << "REWRITE DEF "; inst->dump());
30063025
if (storage.storageAddress)
@@ -3068,7 +3087,7 @@ class DefRewriter : SILInstructionVisitor<DefRewriter> {
30683087
openExistentialBoxValue->getType().getAddressType());
30693088

30703089
openExistentialBoxValue->replaceAllTypeDependentUsesWith(openAddr);
3071-
setStorageAddress(openExistentialBoxValue, openAddr);
3090+
pass.valueStorageMap.setStorageAddress(openExistentialBoxValue, openAddr);
30723091
}
30733092

30743093
// Load an opaque value.
@@ -3133,7 +3152,7 @@ class DefRewriter : SILInstructionVisitor<DefRewriter> {
31333152
// Rewrite Opaque Values
31343153
//===----------------------------------------------------------------------===//
31353154

3136-
// Rewrite applies with indirect paramters or results of loadable types which
3155+
// Rewrite applies with indirect parameters or results of loadable types which
31373156
// were not visited during opaque value rewritting.
31383157
static void rewriteIndirectApply(FullApplySite apply,
31393158
AddressLoweringState &pass) {
@@ -3186,6 +3205,13 @@ static void rewriteFunction(AddressLoweringState &pass) {
31863205
rewriteIndirectApply(optionalApply.getValue(), pass);
31873206
}
31883207
}
3208+
3209+
// Rewrite all checked_cast_br instructions with loadable source type and
3210+
// opaque target type now
3211+
for (auto *ccb : pass.opaqueResultCCBs) {
3212+
CheckedCastBrRewriter(ccb, pass).rewrite();
3213+
}
3214+
31893215
// Rewrite this function's return value now that all opaque values within the
31903216
// function are rewritten. This still depends on a valid ValueStorage
31913217
// projection operands.

lib/SILOptimizer/Mandatory/AddressLowering.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,12 @@ class ValueStorageMap {
256256
return getNonEnumBaseStorage(getStorage(value));
257257
}
258258

259+
void setStorageAddress(SILValue value, SILValue addr) {
260+
auto &storage = getStorage(value);
261+
assert(!storage.storageAddress || storage.storageAddress == addr);
262+
storage.storageAddress = addr;
263+
}
264+
259265
/// Insert a value in the map, creating a ValueStorage object for it. This
260266
/// must be called in RPO order.
261267
void insertValue(SILValue value, SILValue storageAddress);

test/SILOptimizer/address_lowering.sil

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1245,6 +1245,47 @@ bb3:
12451245
return %31 : $()
12461246
}
12471247

1248+
sil @use_Any : $@convention(thin) (@in Any) -> ()
1249+
1250+
// CHECK-LABEL: sil [ossa] @test_checked_cast_br3 : $@convention(method) (@owned C) -> () {
1251+
// CHECK: bb0(%0 : @owned $C):
1252+
// CHECK: [[DST:%.*]] = alloc_stack $Any
1253+
// CHECK: [[SRC_TMP:%.*]] = alloc_stack $C
1254+
// CHECK: store %0 to [init] [[SRC_TMP]] : $*C
1255+
// CHECK: checked_cast_addr_br take_on_success C in [[SRC_TMP]] : $*C to Any in [[DST]] : $*Any, bb2, bb1
1256+
// CHECK: bb1:
1257+
// CHECK: [[LD:%.*]] = load [take] [[SRC_TMP]] : $*C
1258+
// CHECK: dealloc_stack [[SRC_TMP]] : $*C
1259+
// CHECK: destroy_value [[LD]] : $C
1260+
// CHECK: br bb3
1261+
// CHECK: bb2:
1262+
// CHECK: dealloc_stack [[SRC_TMP]] : $*C
1263+
// CHECK: [[FUNC:%.*]] = function_ref @use_Any : $@convention(thin) (@in Any) -> ()
1264+
// CHECK: apply [[FUNC]]([[DST]]) : $@convention(thin) (@in Any) -> ()
1265+
// CHECK: br bb3
1266+
// CHECK: bb3:
1267+
// CHECK: [[RES:%.*]] = tuple ()
1268+
// CHECK: dealloc_stack [[DST]] : $*Any
1269+
// CHECK: return [[RES]] : $()
1270+
// CHECK: }
1271+
sil [ossa] @test_checked_cast_br3 : $@convention(method) (@owned C) -> () {
1272+
bb0(%0 : @owned $C):
1273+
checked_cast_br %0 : $C to Any, bb1, bb2
1274+
1275+
bb1(%3 : @owned $Any):
1276+
%f = function_ref @use_Any : $@convention(thin) (@in Any) -> ()
1277+
%call = apply %f(%3) : $@convention(thin) (@in Any) -> ()
1278+
br bb3
1279+
1280+
bb2(%4 : @owned $C):
1281+
destroy_value %4 : $C
1282+
br bb3
1283+
1284+
bb3:
1285+
%31 = tuple ()
1286+
return %31 : $()
1287+
}
1288+
12481289
// CHECK-LABEL: sil hidden [ossa] @test_unchecked_bitwise_cast :
12491290
// CHECK: bb0(%0 : $*U, %1 : $*T, %2 : $@thick U.Type):
12501291
// CHECK: [[STK:%.*]] = alloc_stack $T

0 commit comments

Comments
 (0)