Skip to content

Commit 068ec2a

Browse files
committed
Transform tuple return types that contain function signatures.
1 parent f821672 commit 068ec2a

File tree

2 files changed

+74
-26
lines changed

2 files changed

+74
-26
lines changed

lib/IRGen/LoadableByAddress.cpp

Lines changed: 46 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -296,9 +296,15 @@ bool LargeSILTypeMapper::shouldTransformResults(GenericEnvironment *genEnv,
296296
if (!modifiableFunction(loweredTy)) {
297297
return false;
298298
}
299+
299300
if (loweredTy->getNumResults() != 1) {
300-
return false;
301+
auto resultType = loweredTy->getAllResultsType();
302+
auto newResultType = getNewSILType(genEnv, resultType, Mod);
303+
bool hasFuncSig = containsFunctionSignature(genEnv, Mod,
304+
resultType, newResultType);
305+
return hasFuncSig;
301306
}
307+
302308
auto singleResult = loweredTy->getSingleResult();
303309
auto resultStorageType = singleResult.getSILStorageType();
304310
auto newResultStorageType = getNewSILType(genEnv, resultStorageType, Mod);
@@ -1287,8 +1293,7 @@ SILArgument *LoadableStorageAllocation::replaceArgType(SILBuilder &argBuilder,
12871293
void LoadableStorageAllocation::insertIndirectReturnArgs() {
12881294
GenericEnvironment *genEnv = pass.F->getGenericEnvironment();
12891295
auto loweredTy = pass.F->getLoweredFunctionType();
1290-
auto singleResult = loweredTy->getSingleResult();
1291-
SILType resultStorageType = singleResult.getSILStorageType();
1296+
SILType resultStorageType = loweredTy->getAllResultsType();
12921297
auto canType = resultStorageType.getASTType();
12931298
if (canType->hasTypeParameter()) {
12941299
assert(genEnv && "Expected a GenericEnv");
@@ -1368,19 +1373,26 @@ void LoadableStorageAllocation::convertApplyResults() {
13681373
pass.Mod)) {
13691374
continue;
13701375
}
1371-
auto singleResult = origSILFunctionType->getSingleResult();
1372-
auto resultStorageType = singleResult.getSILStorageType();
1376+
auto resultStorageType = origSILFunctionType->getAllResultsType();
13731377
if (!isLargeLoadableType(genEnv, resultStorageType, pass.Mod)) {
1374-
// Make sure it is a function type
1375-
if (!resultStorageType.is<SILFunctionType>()) {
1376-
// Check if it is an optional function type
1377-
auto optionalType = resultStorageType.getOptionalObjectType();
1378-
assert(optionalType &&
1379-
"Expected SILFunctionType or Optional for the result type");
1380-
assert(optionalType.is<SILFunctionType>() &&
1381-
"Expected a SILFunctionType inside the optional Type");
1382-
(void)optionalType;
1383-
}
1378+
// Make sure it contains a function type
1379+
auto numFuncTy = llvm::count_if(origSILFunctionType->getResults(),
1380+
[](const SILResultInfo &origResult) {
1381+
auto resultStorageTy = origResult.getSILStorageType();
1382+
// Check if it is a function type
1383+
if (resultStorageTy.is<SILFunctionType>()) {
1384+
return true;
1385+
}
1386+
// Check if it is an optional function type
1387+
auto optionalType = resultStorageTy.getOptionalObjectType();
1388+
if (optionalType && optionalType.is<SILFunctionType>()) {
1389+
return true;
1390+
}
1391+
return false;
1392+
});
1393+
assert(numFuncTy != 0 &&
1394+
"Expected a SILFunctionType inside the result Type");
1395+
(void)numFuncTy;
13841396
continue;
13851397
}
13861398
auto newSILType =
@@ -2213,10 +2225,25 @@ static bool rewriteFunctionReturn(StructLoweringState &pass) {
22132225
return true;
22142226
} else if (containsFunctionSignature(genEnv, pass.Mod, resultTy,
22152227
newSILType) &&
2216-
(resultTy != newSILType) && (loweredTy->getNumResults() == 1)) {
2217-
SILResultInfo origResultInfo = loweredTy->getSingleResult();
2218-
SILResultInfo newSILResultInfo(newSILType.getASTType(),
2219-
origResultInfo.getConvention());
2228+
(resultTy != newSILType)) {
2229+
2230+
llvm::SmallVector<SILResultInfo, 2> newSILResultInfo;
2231+
if (auto tupleType = newSILType.getAs<TupleType>()) {
2232+
auto originalResults = loweredTy->getResults();
2233+
for (unsigned int i = 0; i < originalResults.size(); ++i) {
2234+
auto origResultInfo = originalResults[i];
2235+
auto canElem = tupleType.getElementType(i);
2236+
SILType objectType = SILType::getPrimitiveObjectType(canElem);
2237+
auto newResult = SILResultInfo(objectType.getASTType(), origResultInfo.getConvention());
2238+
newSILResultInfo.push_back(newResult);
2239+
}
2240+
} else {
2241+
assert(loweredTy->getNumResults() == 1 && "Expected a single result");
2242+
auto origResultInfo = loweredTy->getSingleResult();
2243+
auto newResult = SILResultInfo(newSILType.getASTType(), origResultInfo.getConvention());
2244+
newSILResultInfo.push_back(newResult);
2245+
}
2246+
22202247
auto NewTy = SILFunctionType::get(
22212248
loweredTy->getGenericSignature(), loweredTy->getExtInfo(),
22222249
loweredTy->getCoroutineKind(),
@@ -2302,9 +2329,6 @@ getOperandTypeWithCastIfNecessary(SILInstruction *containingInstr, SILValue op,
23022329
if (!genEnv && funcType->isPolymorphic()) {
23032330
genEnv = getGenericEnvironment(funcType);
23042331
}
2305-
if (!Mapper.shouldTransformFunctionType(genEnv, funcType, Mod)) {
2306-
return op;
2307-
}
23082332
auto newFnType = Mapper.getNewSILFunctionType(genEnv, funcType, Mod);
23092333
SILType newSILType = SILType::getPrimitiveObjectType(newFnType);
23102334
if (nonOptionalType.isAddress()) {

test/IRGen/big_types_corner_cases.swift

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -288,17 +288,41 @@ public protocol sr8076_Query {
288288
}
289289

290290
public protocol sr8076_ProtoQueryHandler {
291-
func forceHandle<Q: sr8076_Query>(query: Q) throws -> (Q.Returned, sr8076_Filter?)
291+
func forceHandle_1<Q: sr8076_Query>(query: Q) -> Void
292+
func forceHandle_2<Q: sr8076_Query>(query: Q) -> (Q.Returned, sr8076_BigStruct?)
293+
func forceHandle_3<Q: sr8076_Query>(query: Q) -> (Q.Returned, sr8076_Filter?)
294+
func forceHandle_4<Q: sr8076_Query>(query: Q) throws -> (Q.Returned, sr8076_Filter?)
292295
}
293296

294297
public protocol sr8076_QueryHandler: sr8076_ProtoQueryHandler {
295298
associatedtype Handled: sr8076_Query
296-
func handle(query: Handled) throws -> (Handled.Returned, sr8076_Filter?)
299+
func handle_1(query: Handled) -> Void
300+
func handle_2(query: Handled) -> (Handled.Returned, sr8076_BigStruct?)
301+
func handle_3(query: Handled) -> (Handled.Returned, sr8076_Filter?)
302+
func handle_4(query: Handled) throws -> (Handled.Returned, sr8076_Filter?)
297303
}
298304

299305
public extension sr8076_QueryHandler {
300-
func forceHandle<Q: sr8076_Query>(query: Q) throws -> (Q.Returned, sr8076_Filter?) {
301-
guard let body = handle as? (Q) throws -> (Q.Returned, sr8076_Filter?) else {
306+
func forceHandle_4<Q: sr8076_Query>(query: Q) -> Void {
307+
guard let body = handle_1 as? (Q) -> Void else {
308+
fatalError("handler \(self) is expected to handle query \(query)")
309+
}
310+
body(query)
311+
}
312+
func forceHandle_4<Q: sr8076_Query>(query: Q) -> (Q.Returned, sr8076_BigStruct?) {
313+
guard let body = handle_2 as? (Q) -> (Q.Returned, sr8076_BigStruct?) else {
314+
fatalError("handler \(self) is expected to handle query \(query)")
315+
}
316+
return body(query)
317+
}
318+
func forceHandle_3<Q: sr8076_Query>(query: Q) -> (Q.Returned, sr8076_Filter?) {
319+
guard let body = handle_3 as? (Q) -> (Q.Returned, sr8076_Filter?) else {
320+
fatalError("handler \(self) is expected to handle query \(query)")
321+
}
322+
return body(query)
323+
}
324+
func forceHandle_4<Q: sr8076_Query>(query: Q) throws -> (Q.Returned, sr8076_Filter?) {
325+
guard let body = handle_4 as? (Q) throws -> (Q.Returned, sr8076_Filter?) else {
302326
fatalError("handler \(self) is expected to handle query \(query)")
303327
}
304328
return try body(query)

0 commit comments

Comments
 (0)