Skip to content

Commit 3defefd

Browse files
feat: support more types in switch statements (#2926)
1 parent 4e5fe9c commit 3defefd

File tree

4 files changed

+8096
-126
lines changed

4 files changed

+8096
-126
lines changed

src/compiler.ts

Lines changed: 67 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2816,19 +2816,18 @@ export class Compiler extends DiagnosticEmitter {
28162816
let numCases = cases.length;
28172817

28182818
// Compile the condition (always executes)
2819-
let condExpr = this.compileExpression(statement.condition, Type.u32,
2820-
Constraints.ConvImplicit
2821-
);
2822-
2819+
let condExpr = this.compileExpression(statement.condition, Type.auto);
2820+
let condType = this.currentType;
2821+
28232822
// Shortcut if there are no cases
28242823
if (!numCases) return module.drop(condExpr);
28252824

28262825
// Assign the condition to a temporary local as we compare it multiple times
28272826
let outerFlow = this.currentFlow;
2828-
let tempLocal = outerFlow.getTempLocal(Type.u32);
2827+
let tempLocal = outerFlow.getTempLocal(condType);
28292828
let tempLocalIndex = tempLocal.index;
28302829
let breaks = new Array<ExpressionRef>(1 + numCases);
2831-
breaks[0] = module.local_set(tempLocalIndex, condExpr, false); // u32
2830+
breaks[0] = module.local_set(tempLocalIndex, condExpr, condType.isManaged);
28322831

28332832
// Make one br_if per labeled case and leave it to Binaryen to optimize the
28342833
// sequence of br_ifs to a br_table according to optimization levels
@@ -2841,14 +2840,24 @@ export class Compiler extends DiagnosticEmitter {
28412840
defaultIndex = i;
28422841
continue;
28432842
}
2844-
breaks[breakIndex++] = module.br(`case${i}|${label}`,
2845-
module.binary(BinaryOp.EqI32,
2846-
module.local_get(tempLocalIndex, TypeRef.I32),
2847-
this.compileExpression(assert(case_.label), Type.u32,
2848-
Constraints.ConvImplicit
2849-
)
2850-
)
2843+
2844+
// Compile the equality expression for this case
2845+
const left = statement.condition;
2846+
const leftExpr = module.local_get(tempLocalIndex, condType.toRef());
2847+
const leftType = condType;
2848+
const right = case_.label!;
2849+
const rightExpr = this.compileExpression(assert(case_.label), condType, Constraints.ConvImplicit);
2850+
const rightType = this.currentType;
2851+
const equalityExpr = this.compileCommutativeCompareBinaryExpressionFromParts(
2852+
Token.Equals_Equals,
2853+
left, leftExpr, leftType,
2854+
right, rightExpr, rightType,
2855+
condType,
2856+
statement
28512857
);
2858+
2859+
// Add it to the list of breaks
2860+
breaks[breakIndex++] = module.br(`case${i}|${label}`, equalityExpr);
28522861
}
28532862

28542863
// If there is a default case, break to it, otherwise break out of the switch
@@ -3800,32 +3809,53 @@ export class Compiler extends DiagnosticEmitter {
38003809
expression: BinaryExpression,
38013810
contextualType: Type,
38023811
): ExpressionRef {
3803-
let module = this.module;
3804-
let left = expression.left;
3805-
let right = expression.right;
3812+
3813+
const left = expression.left;
3814+
const leftExpr = this.compileExpression(left, contextualType);
3815+
const leftType = this.currentType;
3816+
3817+
const right = expression.right;
3818+
const rightExpr = this.compileExpression(right, leftType);
3819+
const rightType = this.currentType;
3820+
3821+
return this.compileCommutativeCompareBinaryExpressionFromParts(
3822+
expression.operator,
3823+
left, leftExpr, leftType,
3824+
right, rightExpr, rightType,
3825+
contextualType,
3826+
expression
3827+
);
3828+
}
38063829

3807-
let leftExpr: ExpressionRef;
3808-
let leftType: Type;
3809-
let rightExpr: ExpressionRef;
3810-
let rightType: Type;
3811-
let commonType: Type | null;
3830+
/**
3831+
* compile `==` `===` `!=` `!==` BinaryExpression, from previously compiled left and right expressions.
3832+
*
3833+
* This is split from `compileCommutativeCompareBinaryExpression` so that the logic can be reused
3834+
* for switch cases in `compileSwitchStatement`, where the left expression only should be compiled once.
3835+
*/
3836+
private compileCommutativeCompareBinaryExpressionFromParts(
3837+
operator: Token,
3838+
left: Expression,
3839+
leftExpr: ExpressionRef,
3840+
leftType: Type,
3841+
right: Expression,
3842+
rightExpr: ExpressionRef,
3843+
rightType: Type,
3844+
contextualType: Type,
3845+
reportNode: Node
3846+
): ExpressionRef {
38123847

3813-
let operator = expression.operator;
3848+
let module = this.module;
38143849
let operatorString = operatorTokenToString(operator);
3815-
3816-
leftExpr = this.compileExpression(left, contextualType);
3817-
leftType = this.currentType;
3818-
3819-
rightExpr = this.compileExpression(right, leftType);
3820-
rightType = this.currentType;
38213850

38223851
// check operator overload
38233852
const operatorKind = OperatorKind.fromBinaryToken(operator);
38243853
const leftOverload = leftType.lookupOverload(operatorKind, this.program);
38253854
const rightOverload = rightType.lookupOverload(operatorKind, this.program);
38263855
if (leftOverload && rightOverload && leftOverload != rightOverload) {
38273856
this.error(
3828-
DiagnosticCode.Ambiguous_operator_overload_0_conflicting_overloads_1_and_2, expression.range,
3857+
DiagnosticCode.Ambiguous_operator_overload_0_conflicting_overloads_1_and_2,
3858+
reportNode.range,
38293859
operatorString,
38303860
leftOverload.internalName,
38313861
rightOverload.internalName
@@ -3838,23 +3868,23 @@ export class Compiler extends DiagnosticEmitter {
38383868
leftOverload,
38393869
left, leftExpr, leftType,
38403870
right, rightExpr, rightType,
3841-
expression
3871+
reportNode
38423872
);
38433873
}
38443874
if (rightOverload) {
38453875
return this.compileCommutativeBinaryOverload(
38463876
rightOverload,
38473877
right, rightExpr, rightType,
38483878
left, leftExpr, leftType,
3849-
expression
3879+
reportNode
38503880
);
38513881
}
38523882
const signednessIsRelevant = false;
3853-
commonType = Type.commonType(leftType, rightType, contextualType, signednessIsRelevant);
3883+
const commonType = Type.commonType(leftType, rightType, contextualType, signednessIsRelevant);
38543884
if (!commonType) {
38553885
this.error(
38563886
DiagnosticCode.Operator_0_cannot_be_applied_to_types_1_and_2,
3857-
expression.range,
3887+
reportNode.range,
38583888
operatorString,
38593889
leftType.toString(),
38603890
rightType.toString()
@@ -3867,13 +3897,13 @@ export class Compiler extends DiagnosticEmitter {
38673897
if (isConstExpressionNaN(module, rightExpr) || isConstExpressionNaN(module, leftExpr)) {
38683898
this.warning(
38693899
DiagnosticCode._NaN_does_not_compare_equal_to_any_other_value_including_itself_Use_isNaN_x_instead,
3870-
expression.range
3900+
reportNode.range
38713901
);
38723902
}
38733903
if (isConstNegZero(rightExpr) || isConstNegZero(leftExpr)) {
38743904
this.warning(
38753905
DiagnosticCode.Comparison_with_0_0_is_sign_insensitive_Use_Object_is_x_0_0_if_the_sign_matters,
3876-
expression.range
3906+
reportNode.range
38773907
);
38783908
}
38793909
}
@@ -3887,10 +3917,10 @@ export class Compiler extends DiagnosticEmitter {
38873917
switch (operator) {
38883918
case Token.Equals_Equals_Equals:
38893919
case Token.Equals_Equals:
3890-
return this.makeEq(leftExpr, rightExpr, commonType, expression);
3920+
return this.makeEq(leftExpr, rightExpr, commonType, reportNode);
38913921
case Token.Exclamation_Equals_Equals:
38923922
case Token.Exclamation_Equals:
3893-
return this.makeNe(leftExpr, rightExpr, commonType, expression);
3923+
return this.makeNe(leftExpr, rightExpr, commonType, reportNode);
38943924
default:
38953925
assert(false);
38963926
return module.unreachable();

0 commit comments

Comments
 (0)