Skip to content

Commit 4a1b98c

Browse files
authored
Merge pull request #22757 from gottesmm/pr-6e89f4298efb94406a3ad6b55617ae741dfb65e0
[const-expr] Teach ConstExpr how to handle switch_enum/unchecked_enum…
2 parents 1a66c77 + c0988d1 commit 4a1b98c

File tree

2 files changed

+105
-10
lines changed

2 files changed

+105
-10
lines changed

lib/SILOptimizer/Utils/ConstExpr.cpp

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "swift/AST/ProtocolConformance.h"
1616
#include "swift/AST/SubstitutionMap.h"
1717
#include "swift/Basic/Defer.h"
18+
#include "swift/Basic/NullablePtr.h"
1819
#include "swift/Demangling/Demangle.h"
1920
#include "swift/SIL/ApplySite.h"
2021
#include "swift/SIL/FormalLinkage.h"
@@ -213,6 +214,19 @@ SymbolicValue ConstExprFunctionState::computeConstantValue(SILValue value) {
213214
return val;
214215
}
215216

217+
// If this is an unchecked_enum_data from a fragile type, then we can return
218+
// the enum case value.
219+
if (auto *uedi = dyn_cast<UncheckedEnumDataInst>(value)) {
220+
auto aggValue = uedi->getOperand();
221+
auto val = getConstantValue(aggValue);
222+
if (val.isConstant()) {
223+
assert(val.getKind() == SymbolicValue::EnumWithPayload);
224+
return val.getEnumPayloadValue();
225+
}
226+
// Not a const.
227+
return val;
228+
}
229+
216230
// If this is a destructure_result, then we can return the element being
217231
// extracted.
218232
if (isa<DestructureStructResult>(value) ||
@@ -1359,20 +1373,34 @@ static llvm::Optional<SymbolicValue> evaluateAndCacheCall(
13591373
}
13601374
if (!value.isConstant())
13611375
return value;
1376+
13621377
assert(value.getKind() == SymbolicValue::Enum ||
13631378
value.getKind() == SymbolicValue::EnumWithPayload);
1364-
// Set up basic block arguments.
1379+
13651380
auto *caseBB = switchInst->getCaseDestination(value.getEnumValue());
1366-
if (caseBB->getNumArguments() > 0) {
1367-
assert(value.getKind() == SymbolicValue::EnumWithPayload);
1368-
// When there are multiple payload components, they form a single
1369-
// tuple-typed argument.
1370-
assert(caseBB->getNumArguments() == 1);
1371-
auto argument = value.getEnumPayloadValue();
1372-
assert(argument.isConstant());
1373-
state.setValue(caseBB->getArgument(0), argument);
1374-
}
1381+
1382+
// Prepare to subsequently visit the case blocks instructions.
13751383
nextInst = caseBB->begin();
1384+
// Then set up the arguments.
1385+
if (caseBB->getParent()->hasOwnership() &&
1386+
switchInst->getDefaultBBOrNull() == caseBB) {
1387+
// If we are visiting the default block and we are in ossa, then we may
1388+
// have uses of the failure parameter. That means we need to map the
1389+
// original value to the argument.
1390+
state.setValue(caseBB->getArgument(0), value);
1391+
continue;
1392+
}
1393+
1394+
if (caseBB->getNumArguments() == 0)
1395+
continue;
1396+
1397+
assert(value.getKind() == SymbolicValue::EnumWithPayload);
1398+
// When there are multiple payload components, they form a single
1399+
// tuple-typed argument.
1400+
assert(caseBB->getNumArguments() == 1);
1401+
auto argument = value.getEnumPayloadValue();
1402+
assert(argument.isConstant());
1403+
state.setValue(caseBB->getArgument(0), argument);
13761404
continue;
13771405
}
13781406

test/SILOptimizer/pound_assert_ossa.sil

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,3 +535,70 @@ bb0(%arg : $BiggerStruct):
535535
%ret = tuple ()
536536
return %ret : $()
537537
}
538+
539+
enum Either {
540+
case left(Builtin.Int64)
541+
case right(Builtin.Int64)
542+
}
543+
544+
// Make sure that we properly handle failure default cases.
545+
sil [ossa] @switch_enum_test_callee_1 : $@convention(thin) () -> Builtin.Int64 {
546+
bb0:
547+
%0 = integer_literal $Builtin.Int64, 0
548+
%1 = enum $Either, #Either.left!enumelt.1, %0 : $Builtin.Int64
549+
switch_enum %1 : $Either, case #Either.left!enumelt.1: bb1, default bb2
550+
551+
bb1(%2 : $Builtin.Int64):
552+
br bb3(%2 : $Builtin.Int64)
553+
554+
bb2(%3 : $Either):
555+
%4 = integer_literal $Builtin.Int64, 1
556+
br bb3(%4 : $Builtin.Int64)
557+
558+
bb3(%5 : $Builtin.Int64):
559+
return %5 : $Builtin.Int64
560+
}
561+
562+
sil [ossa] @switch_enum_test_callee_2 : $@convention(thin) () -> Builtin.Int64 {
563+
bb0:
564+
%0 = integer_literal $Builtin.Int64, 0
565+
%1 = enum $Either, #Either.left!enumelt.1, %0 : $Builtin.Int64
566+
// Make sure we go down the bad path.
567+
switch_enum %1 : $Either, case #Either.right!enumelt.1: bb4, default bb5
568+
569+
bb4(%7 : $Builtin.Int64):
570+
br bb6(%7 : $Builtin.Int64)
571+
572+
bb5(%8 : $Either):
573+
%9 = unchecked_enum_data %8 : $Either, #Either.right!enumelt.1
574+
br bb6(%9 : $Builtin.Int64)
575+
576+
bb6(%10 : $Builtin.Int64):
577+
return %10 : $Builtin.Int64
578+
}
579+
580+
sil [ossa] @switch_enum_test_caller : $@convention(thin) () -> () {
581+
bb0:
582+
%0 = function_ref @switch_enum_test_callee_1 : $@convention(thin) () -> Builtin.Int64
583+
%0a = function_ref @switch_enum_test_callee_2 : $@convention(thin) () -> Builtin.Int64
584+
%2 = apply %0() : $@convention(thin) () -> Builtin.Int64
585+
%3 = apply %0a() : $@convention(thin) () -> Builtin.Int64
586+
%str = string_literal utf8 ""
587+
%resultPositive = integer_literal $Builtin.Int64, 0
588+
%resultNegative = integer_literal $Builtin.Int64, 1
589+
%cmp1Positive = builtin "cmp_eq_Int64"(%2 : $Builtin.Int64, %resultPositive : $Builtin.Int64) : $Builtin.Int1
590+
builtin "poundAssert"(%cmp1Positive : $Builtin.Int1, %str : $Builtin.RawPointer) : $()
591+
// Make sure we simplified down the bb1 path.
592+
%cmp2Positive = builtin "cmp_eq_Int64"(%3 : $Builtin.Int64, %resultPositive : $Builtin.Int64) : $Builtin.Int1
593+
builtin "poundAssert"(%cmp2Positive : $Builtin.Int1, %str : $Builtin.RawPointer) : $()
594+
595+
%cmp1Negative = builtin "cmp_eq_Int64"(%2 : $Builtin.Int64, %resultNegative : $Builtin.Int64) : $Builtin.Int1
596+
// expected-error @+1 {{assertion failed}}
597+
builtin "poundAssert"(%cmp1Negative : $Builtin.Int1, %str : $Builtin.RawPointer) : $()
598+
%cmp2Negative = builtin "cmp_eq_Int64"(%3 : $Builtin.Int64, %resultNegative : $Builtin.Int64) : $Builtin.Int1
599+
// expected-error @+1 {{assertion failed}}
600+
builtin "poundAssert"(%cmp2Negative : $Builtin.Int1, %str : $Builtin.RawPointer) : $()
601+
602+
%9999 = tuple()
603+
return %9999 : $()
604+
}

0 commit comments

Comments
 (0)