Skip to content

Commit e334da4

Browse files
author
Joe Shajrawi
committed
Resolves an issue with large loadable types wherein functions types inside classes misbehaved
1 parent a05b35c commit e334da4

File tree

3 files changed

+189
-63
lines changed

3 files changed

+189
-63
lines changed

lib/IRGen/LoadableByAddress.cpp

Lines changed: 165 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -211,12 +211,64 @@ static SILType getNewSILFunctionType(GenericEnvironment *GenericEnv,
211211
return newSILType;
212212
}
213213

214+
// Get the funciton type or the optional function type
215+
static SILFunctionType *getInnerFunctionType(SILType storageType) {
216+
CanType currCanType = storageType.getSwiftRValueType();
217+
if (SILFunctionType *currSILFunctionType =
218+
dyn_cast<SILFunctionType>(currCanType.getPointer())) {
219+
return currSILFunctionType;
220+
}
221+
OptionalTypeKind optKind;
222+
if (auto optionalType = currCanType.getAnyOptionalObjectType(optKind)) {
223+
assert(optKind != OptionalTypeKind::OTK_None &&
224+
"Expected Real Optional Type");
225+
if (auto *currSILFunctionType =
226+
dyn_cast<SILFunctionType>(optionalType.getPointer())) {
227+
return currSILFunctionType;
228+
}
229+
}
230+
return nullptr;
231+
}
232+
233+
static SILType getNewOptionalFunctionType(GenericEnvironment *GenericEnv,
234+
SILType storageType,
235+
irgen::IRGenModule &Mod) {
236+
SILType newSILType = storageType;
237+
CanType currCanType = storageType.getSwiftRValueType();
238+
OptionalTypeKind optKind;
239+
if (auto optionalType = currCanType.getAnyOptionalObjectType(optKind)) {
240+
assert(optKind != OptionalTypeKind::OTK_None &&
241+
"Expected Real Optional Type");
242+
if (auto *currSILFunctionType =
243+
dyn_cast<SILFunctionType>(optionalType.getPointer())) {
244+
if (containsLargeLoadable(GenericEnv,
245+
currSILFunctionType->getParameters(), Mod)) {
246+
newSILType =
247+
getNewSILFunctionType(GenericEnv, currSILFunctionType, Mod);
248+
currCanType = newSILType.getSwiftRValueType();
249+
auto newType = OptionalType::get(optKind, currCanType);
250+
CanType newCanType = newType->getCanonicalType();
251+
newSILType = SILType::getPrimitiveObjectType(newCanType);
252+
}
253+
}
254+
}
255+
return newSILType;
256+
}
257+
214258
static SmallVector<SILParameterInfo, 4>
215259
getNewArgTys(GenericEnvironment *GenericEnv, ArrayRef<SILParameterInfo> params,
216260
irgen::IRGenModule &Mod) {
217261
SmallVector<SILParameterInfo, 4> newArgTys;
218262
for (SILParameterInfo param : params) {
219263
SILType storageType = param.getSILStorageType();
264+
SILType newOptFuncType =
265+
getNewOptionalFunctionType(GenericEnv, storageType, Mod);
266+
if (newOptFuncType != storageType) {
267+
auto newParam = SILParameterInfo(newOptFuncType.getSwiftRValueType(),
268+
param.getConvention());
269+
newArgTys.push_back(newParam);
270+
continue;
271+
}
220272
CanType currCanType = storageType.getSwiftRValueType();
221273
if (SILFunctionType *currSILFunctionType =
222274
dyn_cast<SILFunctionType>(currCanType.getPointer())) {
@@ -245,26 +297,11 @@ getNewArgTys(GenericEnvironment *GenericEnv, ArrayRef<SILParameterInfo> params,
245297

246298
static SILType getNewSILType(GenericEnvironment *GenericEnv,
247299
SILType storageType, irgen::IRGenModule &Mod) {
248-
SILType newSILType = storageType;
249-
CanType currCanType = storageType.getSwiftRValueType();
250-
OptionalTypeKind optKind;
251-
if (auto optionalType = currCanType.getAnyOptionalObjectType(optKind)) {
252-
assert(optKind != OptionalTypeKind::OTK_None &&
253-
"Expected Real Optional Type");
254-
if (SILFunctionType *currSILFunctionType =
255-
dyn_cast<SILFunctionType>(optionalType.getPointer())) {
256-
if (containsLargeLoadable(GenericEnv,
257-
currSILFunctionType->getParameters(), Mod)) {
258-
newSILType =
259-
getNewSILFunctionType(GenericEnv, currSILFunctionType, Mod);
260-
currCanType = newSILType.getSwiftRValueType();
261-
auto newType = OptionalType::get(optKind, currCanType);
262-
CanType newCanType = newType->getCanonicalType();
263-
newSILType = SILType::getPrimitiveObjectType(newCanType);
264-
return newSILType;
265-
}
266-
}
300+
SILType newSILType = getNewOptionalFunctionType(GenericEnv, storageType, Mod);
301+
if (newSILType != storageType) {
302+
return newSILType;
267303
}
304+
CanType currCanType = storageType.getSwiftRValueType();
268305
if (auto *currSILBlockType =
269306
dyn_cast<SILBlockStorageType>(currCanType.getPointer())) {
270307
return storageType;
@@ -309,6 +346,10 @@ struct StructLoweringState {
309346
SmallVector<StructExtractInst *, 16> structExtractInstsToMod;
310347
// All tuple instructions for which the return type is a function type
311348
SmallVector<SILInstruction *, 8> tupleInstsToMod;
349+
// All allock stack instructions to modify
350+
SmallVector<AllocStackInst *, 8> allocStackInstsToMod;
351+
// All pointer to address instructions to modify
352+
SmallVector<PointerToAddressInst *, 8> pointerToAddrkInstsToMod;
312353
// All Retain and release instrs should be replaced with _addr version
313354
SmallVector<RetainValueInst *, 16> retainInstsToMod;
314355
SmallVector<ReleaseValueInst *, 16> releaseInstsToMod;
@@ -353,6 +394,8 @@ class LargeValueVisitor {
353394
void visitResultTyInst(SILInstruction *instr);
354395
void visitDebugValueInst(DebugValueInst *instr);
355396
void visitTupleInst(SILInstruction *instr);
397+
void visitAllocStackInst(AllocStackInst *instr);
398+
void visitPointerToAddressInst(PointerToAddressInst *instr);
356399
void visitInstr(SILInstruction *instr);
357400
};
358401
} // end anonymous namespace
@@ -419,6 +462,16 @@ void LargeValueVisitor::mapValueStorage() {
419462
visitTupleInst(currIns);
420463
break;
421464
}
465+
case ValueKind::AllocStackInst: {
466+
auto *ASI = dyn_cast<AllocStackInst>(currIns);
467+
visitAllocStackInst(ASI);
468+
break;
469+
}
470+
case ValueKind::PointerToAddressInst: {
471+
auto *PTA = dyn_cast<PointerToAddressInst>(currIns);
472+
visitPointerToAddressInst(PTA);
473+
break;
474+
}
422475
default: {
423476
assert(!ApplySite::isa(currIns) && "Did not expect an ApplySite");
424477
assert(!dyn_cast<MethodInst>(currIns) && "Unhandled Method Inst");
@@ -633,6 +686,20 @@ void LargeValueVisitor::visitTupleInst(SILInstruction *instr) {
633686
visitInstr(instr);
634687
}
635688

689+
void LargeValueVisitor::visitAllocStackInst(AllocStackInst *instr) {
690+
SILType currSILType = instr->getType().getObjectType();
691+
if (auto *fType = getInnerFunctionType(currSILType)) {
692+
pass.allocStackInstsToMod.push_back(instr);
693+
}
694+
}
695+
696+
void LargeValueVisitor::visitPointerToAddressInst(PointerToAddressInst *instr) {
697+
SILType currSILType = instr->getType().getObjectType();
698+
if (auto *fType = getInnerFunctionType(currSILType)) {
699+
pass.pointerToAddrkInstsToMod.push_back(instr);
700+
}
701+
}
702+
636703
void LargeValueVisitor::visitInstr(SILInstruction *instr) {
637704
for (Operand &operand : instr->getAllOperands()) {
638705
if (std::find(pass.largeLoadableArgs.begin(), pass.largeLoadableArgs.end(),
@@ -1080,13 +1147,16 @@ class LoadableByAddress : public SILModuleTransform {
10801147
void recreateConvInstrs();
10811148
void recreateLoadInstrs();
10821149
void recreateUncheckedEnumDataInstrs();
1150+
void recreateUncheckedTakeEnumDataAddrInst();
10831151
void fixStoreToBlockStorageInstrs();
10841152

10851153
private:
10861154
llvm::SetVector<SILFunction *> modFuncs;
10871155
llvm::SetVector<SILInstruction *> conversionInstrs;
10881156
llvm::SetVector<LoadInst *> loadInstrsOfFunc;
10891157
llvm::SetVector<UncheckedEnumDataInst *> uncheckedEnumDataOfFunc;
1158+
llvm::SetVector<UncheckedTakeEnumDataAddrInst *>
1159+
uncheckedTakeEnumDataAddrOfFunc;
10901160
llvm::SetVector<StoreInst *> storeToBlockStorageInstrs;
10911161
llvm::DenseSet<SILInstruction *> modApplies;
10921162
};
@@ -1414,6 +1484,28 @@ static void rewriteFunction(StructLoweringState &pass,
14141484
castTupleInstr(instr, pass.Mod);
14151485
}
14161486

1487+
while (!pass.allocStackInstsToMod.empty()) {
1488+
auto *instr = pass.allocStackInstsToMod.pop_back_val();
1489+
SILBuilder allocBuilder(instr);
1490+
SILType currSILType = instr->getType();
1491+
SILType newSILType = getNewSILType(genEnv, currSILType, pass.Mod);
1492+
auto *newInstr = allocBuilder.createAllocStack(instr->getLoc(), newSILType);
1493+
instr->replaceAllUsesWith(newInstr);
1494+
instr->getParent()->erase(instr);
1495+
}
1496+
1497+
while (!pass.pointerToAddrkInstsToMod.empty()) {
1498+
auto *instr = pass.pointerToAddrkInstsToMod.pop_back_val();
1499+
SILBuilder pointerBuilder(instr);
1500+
SILType currSILType = instr->getType();
1501+
SILType newSILType = getNewSILType(genEnv, currSILType, pass.Mod);
1502+
auto *newInstr = pointerBuilder.createPointerToAddress(
1503+
instr->getLoc(), instr->getOperand(), newSILType.getAddressType(),
1504+
instr->isStrict());
1505+
instr->replaceAllUsesWith(newInstr);
1506+
instr->getParent()->erase(instr);
1507+
}
1508+
14171509
for (SILInstruction *instr : pass.debugInstsToMod) {
14181510
assert(instr->getAllOperands().size() == 1 &&
14191511
"Debug instructions have one operand");
@@ -1583,37 +1675,21 @@ static bool rewriteFunctionReturn(StructLoweringState &pass) {
15831675
}
15841676
SILFunction *F = pass.F;
15851677
SILType resultTy = loweredTy->getAllResultsType();
1586-
CanType resultCanTy = resultTy.getSwiftRValueType();
1587-
if (SILFunctionType *currSILFunctionType =
1588-
dyn_cast<SILFunctionType>(resultCanTy.getPointer())) {
1589-
if (containsLargeLoadable(genEnv, currSILFunctionType->getParameters(),
1590-
pass.Mod)) {
1591-
assert(F->getLoweredFunctionType()->getNumResults() == 1 &&
1592-
"Expected a single result");
1593-
SILResultInfo origResultInfo = loweredTy->getSingleResult();
1594-
SmallVector<SILParameterInfo, 4> newArgTys =
1595-
getNewArgTys(genEnv, currSILFunctionType->getParameters(), pass.Mod);
1596-
SILFunctionType *newSILFunctionType =
1597-
SILFunctionType::get(currSILFunctionType->getGenericSignature(),
1598-
currSILFunctionType->getExtInfo(),
1599-
currSILFunctionType->getCalleeConvention(),
1600-
newArgTys, currSILFunctionType->getResults(),
1601-
currSILFunctionType->getOptionalErrorResult(),
1602-
currSILFunctionType->getASTContext());
1603-
SILType newSILType = SILType::getPrimitiveObjectType(
1604-
newSILFunctionType->getCanonicalType());
1605-
SILResultInfo newSILResultInfo(newSILType.getSwiftRValueType(),
1606-
origResultInfo.getConvention());
1607-
// change the caller's SIL function type
1608-
SILFunctionType *OrigFTI = F->getLoweredFunctionType();
1609-
auto NewTy = SILFunctionType::get(
1610-
OrigFTI->getGenericSignature(), OrigFTI->getExtInfo(),
1611-
OrigFTI->getCalleeConvention(), OrigFTI->getParameters(),
1612-
newSILResultInfo, OrigFTI->getOptionalErrorResult(),
1613-
F->getModule().getASTContext());
1614-
F->rewriteLoweredTypeUnsafe(NewTy);
1615-
return true;
1616-
}
1678+
SILType newSILType = getNewSILType(genEnv, resultTy, pass.Mod);
1679+
// We (currently) only care about function signatures
1680+
if (!isLargeLoadableType(genEnv, resultTy, pass.Mod) &&
1681+
(newSILType != resultTy)) {
1682+
assert(loweredTy->getNumResults() == 1 && "Expected a single result");
1683+
SILResultInfo origResultInfo = loweredTy->getSingleResult();
1684+
SILResultInfo newSILResultInfo(newSILType.getSwiftRValueType(),
1685+
origResultInfo.getConvention());
1686+
auto NewTy = SILFunctionType::get(
1687+
loweredTy->getGenericSignature(), loweredTy->getExtInfo(),
1688+
loweredTy->getCalleeConvention(), loweredTy->getParameters(),
1689+
newSILResultInfo, loweredTy->getOptionalErrorResult(),
1690+
F->getModule().getASTContext());
1691+
F->rewriteLoweredTypeUnsafe(NewTy);
1692+
return true;
16171693
}
16181694
return false;
16191695
}
@@ -1796,6 +1872,29 @@ void LoadableByAddress::recreateUncheckedEnumDataInstrs() {
17961872
}
17971873
}
17981874

1875+
void LoadableByAddress::recreateUncheckedTakeEnumDataAddrInst() {
1876+
for (auto *enumInstr : uncheckedTakeEnumDataAddrOfFunc) {
1877+
SILBuilder enumBuilder(enumInstr);
1878+
SILFunction *F = enumInstr->getFunction();
1879+
CanSILFunctionType funcType = F->getLoweredFunctionType();
1880+
IRGenModule *currIRMod = getIRGenModule()->IRGen.getGenModule(F);
1881+
Lowering::GenericContextScope GenericScope(getModule()->Types,
1882+
funcType->getGenericSignature());
1883+
SILType origType = enumInstr->getType();
1884+
GenericEnvironment *genEnv = F->getGenericEnvironment();
1885+
auto loweredTy = F->getLoweredFunctionType();
1886+
if (!genEnv && loweredTy->isPolymorphic()) {
1887+
genEnv = getGenericEnvironment(F->getModule(), loweredTy);
1888+
}
1889+
SILType newType = getNewSILType(genEnv, origType, *currIRMod);
1890+
auto *newInstr = enumBuilder.createUncheckedTakeEnumDataAddr(
1891+
enumInstr->getLoc(), enumInstr->getOperand(), enumInstr->getElement(),
1892+
newType.getAddressType());
1893+
enumInstr->replaceAllUsesWith(newInstr);
1894+
enumInstr->getParent()->erase(enumInstr);
1895+
}
1896+
}
1897+
17991898
void LoadableByAddress::fixStoreToBlockStorageInstrs() {
18001899
for (auto *instr : storeToBlockStorageInstrs) {
18011900
auto dest = instr->getDest();
@@ -1934,24 +2033,28 @@ void LoadableByAddress::run() {
19342033
}
19352034
} else if (auto *LI = dyn_cast<LoadInst>(&I)) {
19362035
SILType currType = LI->getType();
1937-
CanType currCanType = currType.getSwiftRValueType();
1938-
if (auto *fType =
1939-
dyn_cast<SILFunctionType>(currCanType.getPointer())) {
2036+
if (auto *fType = getInnerFunctionType(currType)) {
19402037
if (modifiableFunction(CanSILFunctionType(fType))) {
19412038
// need to re-create these loads: re-write type cache
19422039
loadInstrsOfFunc.insert(LI);
19432040
}
19442041
}
19452042
} else if (auto *UED = dyn_cast<UncheckedEnumDataInst>(&I)) {
19462043
SILType currType = UED->getType();
1947-
CanType currCanType = currType.getSwiftRValueType();
1948-
if (auto *fType =
1949-
dyn_cast<SILFunctionType>(currCanType.getPointer())) {
2044+
if (auto *fType = getInnerFunctionType(currType)) {
19502045
if (modifiableFunction(CanSILFunctionType(fType))) {
19512046
// need to re-create these loads: re-write type cache
19522047
uncheckedEnumDataOfFunc.insert(UED);
19532048
}
19542049
}
2050+
} else if (auto *UED = dyn_cast<UncheckedTakeEnumDataAddrInst>(&I)) {
2051+
SILType currType = UED->getType();
2052+
if (auto *fType = getInnerFunctionType(currType)) {
2053+
if (modifiableFunction(CanSILFunctionType(fType))) {
2054+
// need to re-create these loads: re-write type cache
2055+
uncheckedTakeEnumDataAddrOfFunc.insert(UED);
2056+
}
2057+
}
19552058
} else if (auto *SI = dyn_cast<StoreInst>(&I)) {
19562059
auto dest = SI->getDest();
19572060
if (isa<ProjectBlockStorageInst>(dest)) {
@@ -1983,12 +2086,15 @@ void LoadableByAddress::run() {
19832086
// Re-create all conversions for which we modified the FunctionRef
19842087
recreateConvInstrs();
19852088

1986-
// Re-create all load instrs of function pointers
1987-
recreateLoadInstrs();
1988-
19892089
// Re-create all unchecked enum data instrs of function pointers
19902090
recreateUncheckedEnumDataInstrs();
19912091

2092+
// Same for data addr
2093+
recreateUncheckedTakeEnumDataAddrInst();
2094+
2095+
// Re-create all load instrs of function pointers
2096+
recreateLoadInstrs();
2097+
19922098
// Re-create all applies that we modified in the module
19932099
recreateApplies();
19942100

lib/SIL/SILVerifier.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2102,10 +2102,12 @@ class SILVerifier : public SILVerifierBase<SILVerifier> {
21022102
require(EI->getField()->getDeclContext() == cd,
21032103
"ref_element_addr field must be a member of the class");
21042104

2105-
SILType loweredFieldTy = operandTy.getFieldType(EI->getField(),
2106-
F.getModule());
2107-
require(loweredFieldTy == EI->getType(),
2108-
"result of ref_element_addr does not match type of field");
2105+
if (EI->getModule().getStage() != SILStage::Lowered) {
2106+
SILType loweredFieldTy =
2107+
operandTy.getFieldType(EI->getField(), F.getModule());
2108+
require(loweredFieldTy == EI->getType(),
2109+
"result of ref_element_addr does not match type of field");
2110+
}
21092111
EI->getFieldNo(); // Make sure we can access the field without crashing.
21102112
}
21112113

test/IRGen/big_types_corner_cases.swift

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,24 @@ public func f4_tuple_use_of_f2() {
4646
// CHECK: call swiftcc void [[CAST_EXTRACT]](%T22big_types_corner_cases9BigStructV* noalias nocapture sret %call.aggresult1, %T22big_types_corner_cases9BigStructV* noalias nocapture dereferenceable
4747
// CHECK: ret void
4848

49+
public class BigClass {
50+
public init() {
51+
}
52+
53+
public var optVar: ((BigStruct)-> Void)? = nil
54+
55+
func useBigStruct(bigStruct: BigStruct) {
56+
optVar!(bigStruct)
57+
}
58+
}
59+
60+
// CHECK-LABEL define{{( protected)?}} hidden swiftcc void @_T022big_types_corner_cases8BigClassC03useE6StructyAA0eH0V0aH0_tF(%T22big_types_corner_cases9BigStructV* noalias nocapture dereferenceable({{.*}}), %T22big_types_corner_cases8BigClassC* swiftself) #0 {
61+
// CHECK: getelementptr inbounds %T22big_types_corner_cases8BigClassC, %T22big_types_corner_cases8BigClassC*
62+
// CHECK: call void @_T0SqWy
63+
// CHECK: [[BITCAST:%.*]] = bitcast i8* {{.*}} to void (%T22big_types_corner_cases9BigStructV*, %swift.refcounted*)*
64+
// CHECK: call swiftcc void [[BITCAST]](%T22big_types_corner_cases9BigStructV* noalias nocapture dereferenceable({{.*}}) %0, %swift.refcounted* swiftself
65+
// CHECK: ret void
66+
4967
public struct MyStruct {
5068
public let a: Int
5169
public let b: String?

0 commit comments

Comments
 (0)