Skip to content

Commit f3773da

Browse files
authored
Merge pull request #21869 from apple/marcrasi-const-evaluator-enums
2 parents 015b159 + fb45802 commit f3773da

File tree

4 files changed

+247
-0
lines changed

4 files changed

+247
-0
lines changed

include/swift/SIL/SILConstants.h

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,12 @@ class SymbolicValue {
103103
/// "aggregate" member of the value union.
104104
RK_Aggregate,
105105

106+
/// This value is an enum with no payload.
107+
RK_Enum,
108+
109+
/// This value is an enum with a payload.
110+
RK_EnumWithPayload,
111+
106112
/// This represents the address of a memory object.
107113
RK_DirectAddress,
108114

@@ -136,6 +142,14 @@ class SymbolicValue {
136142
/// information about the array elements and count.
137143
const SymbolicValue *aggregate;
138144

145+
/// When this SymbolicValue is of "Enum" kind, this pointer stores
146+
/// information about the enum case type.
147+
EnumElementDecl *enumVal;
148+
149+
/// When this SymbolicValue is of "EnumWithPayload" kind, this pointer
150+
/// stores information about the enum case type and its payload.
151+
EnumWithPayloadSymbolicValue *enumValWithPayload;
152+
139153
/// When the representationKind is "DirectAddress", this pointer is the
140154
/// memory object referenced.
141155
SymbolicValueMemoryObject *directAddress;
@@ -186,6 +200,12 @@ class SymbolicValue {
186200
/// This can be an array, struct, tuple, etc.
187201
Aggregate,
188202

203+
/// This is an enum without payload.
204+
Enum,
205+
206+
/// This is an enum with payload (formally known as "associated value").
207+
EnumWithPayload,
208+
189209
/// This value represents the address of, or into, a memory object.
190210
Address,
191211

@@ -271,6 +291,25 @@ class SymbolicValue {
271291

272292
ArrayRef<SymbolicValue> getAggregateValue() const;
273293

294+
/// This returns a constant Symbolic value for the enum case in `decl`, which
295+
/// must not have an associated value.
296+
static SymbolicValue getEnum(EnumElementDecl *decl) {
297+
assert(decl);
298+
SymbolicValue result;
299+
result.representationKind = RK_Enum;
300+
result.value.enumVal = decl;
301+
return result;
302+
}
303+
304+
/// `payload` must be a constant.
305+
static SymbolicValue getEnumWithPayload(EnumElementDecl *decl,
306+
SymbolicValue payload,
307+
ASTContext &astContext);
308+
309+
EnumElementDecl *getEnumValue() const;
310+
311+
SymbolicValue getEnumPayloadValue() const;
312+
274313
/// Return a symbolic value that represents the address of a memory object.
275314
static SymbolicValue getAddress(SymbolicValueMemoryObject *memoryObject) {
276315
SymbolicValue result;

lib/SIL/SILConstants.cpp

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,20 @@ void SymbolicValue::print(llvm::raw_ostream &os, unsigned indent) const {
8282
return;
8383
}
8484
}
85+
case RK_Enum: {
86+
auto *decl = getEnumValue();
87+
os << "enum: ";
88+
decl->print(os);
89+
return;
90+
}
91+
case RK_EnumWithPayload: {
92+
auto *decl = getEnumValue();
93+
os << "enum: ";
94+
decl->print(os);
95+
os << ", payload: ";
96+
getEnumPayloadValue().print(os, indent);
97+
return;
98+
}
8599
case RK_DirectAddress:
86100
case RK_DerivedAddress: {
87101
SmallVector<unsigned, 4> accessPath;
@@ -111,6 +125,10 @@ SymbolicValue::Kind SymbolicValue::getKind() const {
111125
return Function;
112126
case RK_Aggregate:
113127
return Aggregate;
128+
case RK_Enum:
129+
return Enum;
130+
case RK_EnumWithPayload:
131+
return EnumWithPayload;
114132
case RK_Integer:
115133
case RK_IntegerInline:
116134
return Integer;
@@ -133,6 +151,9 @@ SymbolicValue::cloneInto(ASTContext &astContext) const {
133151
case RK_Metatype:
134152
case RK_Function:
135153
assert(0 && "cloning this representation kind is not supported");
154+
case RK_Enum:
155+
// These have trivial inline storage, just return a copy.
156+
return *this;
136157
case RK_IntegerInline:
137158
case RK_Integer:
138159
return SymbolicValue::getInteger(getIntegerValue(), astContext);
@@ -146,6 +167,8 @@ SymbolicValue::cloneInto(ASTContext &astContext) const {
146167
results.push_back(elt.cloneInto(astContext));
147168
return getAggregate(results, astContext);
148169
}
170+
case RK_EnumWithPayload:
171+
return getEnumWithPayload(getEnumValue(), getEnumPayloadValue(), astContext);
149172
case RK_DirectAddress:
150173
case RK_DerivedAddress: {
151174
SmallVector<unsigned, 4> accessPath;
@@ -354,6 +377,56 @@ UnknownReason SymbolicValue::getUnknownReason() const {
354377
return value.unknown->reason;
355378
}
356379

380+
//===----------------------------------------------------------------------===//
381+
// Enums
382+
//===----------------------------------------------------------------------===//
383+
384+
namespace swift {
385+
386+
/// This is the representation of a constant enum value with payload.
387+
struct EnumWithPayloadSymbolicValue final {
388+
/// The enum case.
389+
EnumElementDecl *enumDecl;
390+
SymbolicValue payload;
391+
392+
EnumWithPayloadSymbolicValue(EnumElementDecl *decl, SymbolicValue payload)
393+
: enumDecl(decl), payload(payload) {}
394+
395+
private:
396+
EnumWithPayloadSymbolicValue() = delete;
397+
EnumWithPayloadSymbolicValue(const EnumWithPayloadSymbolicValue &) = delete;
398+
};
399+
} // end namespace swift
400+
401+
/// This returns a constant Symbolic value for the enum case in `decl` with a
402+
/// payload.
403+
SymbolicValue
404+
SymbolicValue::getEnumWithPayload(EnumElementDecl *decl, SymbolicValue payload,
405+
ASTContext &astContext) {
406+
assert(decl && payload.isConstant());
407+
auto rawMem = astContext.Allocate(sizeof(EnumWithPayloadSymbolicValue),
408+
alignof(EnumWithPayloadSymbolicValue));
409+
auto enumVal = ::new (rawMem) EnumWithPayloadSymbolicValue(decl, payload);
410+
411+
SymbolicValue result;
412+
result.representationKind = RK_EnumWithPayload;
413+
result.value.enumValWithPayload = enumVal;
414+
return result;
415+
}
416+
417+
EnumElementDecl *SymbolicValue::getEnumValue() const {
418+
if (representationKind == RK_Enum)
419+
return value.enumVal;
420+
421+
assert(representationKind == RK_EnumWithPayload);
422+
return value.enumValWithPayload->enumDecl;
423+
}
424+
425+
SymbolicValue SymbolicValue::getEnumPayloadValue() const {
426+
assert(representationKind == RK_EnumWithPayload);
427+
return value.enumValWithPayload->payload;
428+
}
429+
357430
//===----------------------------------------------------------------------===//
358431
// Addresses
359432
//===----------------------------------------------------------------------===//

lib/SILOptimizer/Utils/ConstExpr.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,25 @@ SymbolicValue ConstExprFunctionState::computeConstantValue(SILValue value) {
287287
return calculatedValues[apply];
288288
}
289289

290+
if (auto *enumVal = dyn_cast<EnumInst>(value)) {
291+
if (!enumVal->hasOperand())
292+
return SymbolicValue::getEnum(enumVal->getElement());
293+
294+
auto payload = getConstantValue(enumVal->getOperand());
295+
if (!payload.isConstant())
296+
return payload;
297+
return SymbolicValue::getEnumWithPayload(enumVal->getElement(), payload,
298+
evaluator.getASTContext());
299+
}
300+
301+
// This one returns the address of its enum payload.
302+
if (auto *dai = dyn_cast<UncheckedTakeEnumDataAddrInst>(value)) {
303+
auto enumVal = getConstAddrAndLoadResult(dai->getOperand());
304+
if (!enumVal.isConstant())
305+
return enumVal;
306+
return createMemoryObject(value, enumVal.getEnumPayloadValue());
307+
}
308+
290309
// This instruction is a marker that returns its first operand.
291310
if (auto *bai = dyn_cast<BeginAccessInst>(value))
292311
return getConstantValue(bai->getOperand());
@@ -1244,6 +1263,34 @@ static llvm::Optional<SymbolicValue> evaluateAndCacheCall(
12441263
continue;
12451264
}
12461265

1266+
if (isa<SwitchEnumAddrInst>(inst) || isa<SwitchEnumInst>(inst)) {
1267+
SymbolicValue value;
1268+
SwitchEnumInstBase *switchInst = dyn_cast<SwitchEnumInst>(inst);
1269+
if (switchInst) {
1270+
value = state.getConstantValue(switchInst->getOperand());
1271+
} else {
1272+
switchInst = cast<SwitchEnumAddrInst>(inst);
1273+
value = state.getConstAddrAndLoadResult(switchInst->getOperand());
1274+
}
1275+
if (!value.isConstant())
1276+
return value;
1277+
assert(value.getKind() == SymbolicValue::Enum ||
1278+
value.getKind() == SymbolicValue::EnumWithPayload);
1279+
// Set up basic block arguments.
1280+
auto *caseBB = switchInst->getCaseDestination(value.getEnumValue());
1281+
if (caseBB->getNumArguments() > 0) {
1282+
assert(value.getKind() == SymbolicValue::EnumWithPayload);
1283+
// When there are multiple payload components, they form a single
1284+
// tuple-typed argument.
1285+
assert(caseBB->getNumArguments() == 1);
1286+
auto argument = value.getEnumPayloadValue();
1287+
assert(argument.isConstant());
1288+
state.setValue(caseBB->getArgument(0), argument);
1289+
}
1290+
nextInst = caseBB->begin();
1291+
continue;
1292+
}
1293+
12471294
LLVM_DEBUG(llvm::dbgs()
12481295
<< "ConstExpr: Unknown Terminator: " << *inst << "\n");
12491296

test/SILOptimizer/pound_assert.swift

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -497,3 +497,91 @@ func stringInitNonEmptyFlowSensitive() -> ContainsString {
497497
func invokeStringInitNonEmptyFlowSensitive() {
498498
#assert(stringInitNonEmptyFlowSensitive().x == 1)
499499
}
500+
501+
//===----------------------------------------------------------------------===//
502+
// Enums and optionals.
503+
//===----------------------------------------------------------------------===//
504+
func isNil(_ x: Int?) -> Bool {
505+
return x == nil
506+
}
507+
508+
#assert(isNil(nil))
509+
#assert(!isNil(3))
510+
511+
public enum Pet {
512+
case bird
513+
case cat(Int)
514+
case dog(Int, Int)
515+
case fish
516+
}
517+
518+
public func weighPet(pet: Pet) -> Int {
519+
switch pet {
520+
case .bird: return 3
521+
case let .cat(weight): return weight
522+
case let .dog(w1, w2): return w1+w2
523+
default: return 1
524+
}
525+
}
526+
527+
#assert(weighPet(pet: .bird) == 3)
528+
#assert(weighPet(pet: .fish) == 1)
529+
#assert(weighPet(pet: .cat(2)) == 2)
530+
// expected-error @+1 {{assertion failed}}
531+
#assert(weighPet(pet: .cat(2)) == 3)
532+
#assert(weighPet(pet: .dog(9, 10)) == 19)
533+
534+
// Test indirect enums.
535+
indirect enum IntExpr {
536+
case int(_ value: Int)
537+
case add(_ lhs: IntExpr, _ rhs: IntExpr)
538+
case multiply(_ lhs: IntExpr, _ rhs: IntExpr)
539+
}
540+
541+
func evaluate(intExpr: IntExpr) -> Int {
542+
switch intExpr {
543+
case .int(let value):
544+
return value
545+
case .add(let lhs, let rhs):
546+
return evaluate(intExpr: lhs) + evaluate(intExpr: rhs)
547+
case .multiply(let lhs, let rhs):
548+
return evaluate(intExpr: lhs) * evaluate(intExpr: rhs)
549+
}
550+
}
551+
552+
// TODO: The constant evaluator can't handle indirect enums yet.
553+
// expected-error @+2 {{#assert condition not constant}}
554+
// expected-note @+1 {{could not fold operation}}
555+
#assert(evaluate(intExpr: .int(5)) == 5)
556+
// expected-error @+2 {{#assert condition not constant}}
557+
// expected-note @+1 {{could not fold operation}}
558+
#assert(evaluate(intExpr: .add(.int(5), .int(6))) == 11)
559+
// expected-error @+2 {{#assert condition not constant}}
560+
// expected-note @+1 {{could not fold operation}}
561+
#assert(evaluate(intExpr: .add(.multiply(.int(2), .int(2)), .int(3))) == 7)
562+
563+
// Test address-only enums.
564+
protocol IntContainerProtocol {
565+
var value: Int { get }
566+
}
567+
568+
struct IntContainer : IntContainerProtocol {
569+
let value: Int
570+
}
571+
572+
enum AddressOnlyEnum<T: IntContainerProtocol> {
573+
case double(_ value: T)
574+
case triple(_ value: T)
575+
}
576+
577+
func evaluate<T>(addressOnlyEnum: AddressOnlyEnum<T>) -> Int {
578+
switch addressOnlyEnum {
579+
case .double(let value):
580+
return 2 * value.value
581+
case .triple(let value):
582+
return 3 * value.value
583+
}
584+
}
585+
586+
#assert(evaluate(addressOnlyEnum: .double(IntContainer(value: 1))) == 2)
587+
#assert(evaluate(addressOnlyEnum: .triple(IntContainer(value: 1))) == 3)

0 commit comments

Comments
 (0)