Skip to content

[WebAssembly] Support type checker for new EH #111069

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Oct 7, 2024
Merged

Conversation

aheejin
Copy link
Member

@aheejin aheejin commented Oct 3, 2024

This adds supports for the new EH instructions (try_table and throw_ref) to the type checker.

One thing I'd like to improve on is the locations in the errors for catch_*** clauses. Currently they just point to the starting column of try_table instruction itself. But to figure out where catch clauses start you need to traverse OperandVector and check WebAssemblyOperand::isCatchList on them to see which one is the catch list operand, but WebAssemblyOperand class is in AsmParser and AsmTypeCheck does not have access to it:

namespace {
/// WebAssemblyOperand - Instances of this class represent the operands in a
/// parsed Wasm machine instruction.
struct WebAssemblyOperand : public MCParsedAsmOperand {
enum KindTy { Token, Integer, Float, Symbol, BrList, CatchList } Kind;
SMLoc StartLoc, EndLoc;
struct TokOp {
StringRef Tok;
};
struct IntOp {
int64_t Val;
};
struct FltOp {
double Val;
};
struct SymOp {
const MCExpr *Exp;
};
struct BrLOp {
std::vector<unsigned> List;
};
struct CaLOpElem {
uint8_t Opcode;
const MCExpr *Tag;
unsigned Dest;
};
struct CaLOp {
std::vector<CaLOpElem> List;
};
union {
struct TokOp Tok;
struct IntOp Int;
struct FltOp Flt;
struct SymOp Sym;
struct BrLOp BrL;
struct CaLOp CaL;
};
WebAssemblyOperand(SMLoc Start, SMLoc End, TokOp T)
: Kind(Token), StartLoc(Start), EndLoc(End), Tok(T) {}
WebAssemblyOperand(SMLoc Start, SMLoc End, IntOp I)
: Kind(Integer), StartLoc(Start), EndLoc(End), Int(I) {}
WebAssemblyOperand(SMLoc Start, SMLoc End, FltOp F)
: Kind(Float), StartLoc(Start), EndLoc(End), Flt(F) {}
WebAssemblyOperand(SMLoc Start, SMLoc End, SymOp S)
: Kind(Symbol), StartLoc(Start), EndLoc(End), Sym(S) {}
WebAssemblyOperand(SMLoc Start, SMLoc End, BrLOp B)
: Kind(BrList), StartLoc(Start), EndLoc(End), BrL(B) {}
WebAssemblyOperand(SMLoc Start, SMLoc End, CaLOp C)
: Kind(CatchList), StartLoc(Start), EndLoc(End), CaL(C) {}
~WebAssemblyOperand() {
if (isBrList())
BrL.~BrLOp();
if (isCatchList())
CaL.~CaLOp();
}
bool isToken() const override { return Kind == Token; }
bool isImm() const override { return Kind == Integer || Kind == Symbol; }
bool isFPImm() const { return Kind == Float; }
bool isMem() const override { return false; }
bool isReg() const override { return false; }
bool isBrList() const { return Kind == BrList; }
bool isCatchList() const { return Kind == CatchList; }
MCRegister getReg() const override {
llvm_unreachable("Assembly inspects a register operand");
return 0;
}
StringRef getToken() const {
assert(isToken());
return Tok.Tok;
}
SMLoc getStartLoc() const override { return StartLoc; }
SMLoc getEndLoc() const override { return EndLoc; }
void addRegOperands(MCInst &, unsigned) const {
// Required by the assembly matcher.
llvm_unreachable("Assembly matcher creates register operands");
}
void addImmOperands(MCInst &Inst, unsigned N) const {
assert(N == 1 && "Invalid number of operands!");
if (Kind == Integer)
Inst.addOperand(MCOperand::createImm(Int.Val));
else if (Kind == Symbol)
Inst.addOperand(MCOperand::createExpr(Sym.Exp));
else
llvm_unreachable("Should be integer immediate or symbol!");
}
void addFPImmf32Operands(MCInst &Inst, unsigned N) const {
assert(N == 1 && "Invalid number of operands!");
if (Kind == Float)
Inst.addOperand(
MCOperand::createSFPImm(bit_cast<uint32_t>(float(Flt.Val))));
else
llvm_unreachable("Should be float immediate!");
}
void addFPImmf64Operands(MCInst &Inst, unsigned N) const {
assert(N == 1 && "Invalid number of operands!");
if (Kind == Float)
Inst.addOperand(MCOperand::createDFPImm(bit_cast<uint64_t>(Flt.Val)));
else
llvm_unreachable("Should be float immediate!");
}
void addBrListOperands(MCInst &Inst, unsigned N) const {
assert(N == 1 && isBrList() && "Invalid BrList!");
for (auto Br : BrL.List)
Inst.addOperand(MCOperand::createImm(Br));
}
void addCatchListOperands(MCInst &Inst, unsigned N) const {
assert(N == 1 && isCatchList() && "Invalid CatchList!");
Inst.addOperand(MCOperand::createImm(CaL.List.size()));
for (auto Ca : CaL.List) {
Inst.addOperand(MCOperand::createImm(Ca.Opcode));
if (Ca.Opcode == wasm::WASM_OPCODE_CATCH ||
Ca.Opcode == wasm::WASM_OPCODE_CATCH_REF)
Inst.addOperand(MCOperand::createExpr(Ca.Tag));
Inst.addOperand(MCOperand::createImm(Ca.Dest));
}
}
void print(raw_ostream &OS) const override {
switch (Kind) {
case Token:
OS << "Tok:" << Tok.Tok;
break;
case Integer:
OS << "Int:" << Int.Val;
break;
case Float:
OS << "Flt:" << Flt.Val;
break;
case Symbol:
OS << "Sym:" << Sym.Exp;
break;
case BrList:
OS << "BrList:" << BrL.List.size();
break;
case CatchList:
OS << "CaList:" << CaL.List.size();
break;
}
}
};
And even if AsmTypeCheck has access to it, currently it treats the list of catch clauses as a single WebAssemblyOperand so there is no way to get the starting location of each catch_*** clause in the current structure.

This also renames valTypeToStackType to valTypesToStackTypes, given that it takes two type lists.

This adds supports for the new EH instructions (`try_table` and
`throw_ref`) to the type checker.

One thing I'd like to improve on is the locations in the errors for
`catch_***` clauses. Currently they just point to the starting column of
`try_table` instruction itself. But to figure out where catch clauses
start you need to traverse `OperandVector` and check
`WebAssemblyOperand::isCatchList` on them to see which one is the catch
list operand, but `WebAssemblyOperand` class is in AsmParser and
AsmTypeCheck does not have access to it:
https://github.com/llvm/llvm-project/blob/cdfdc857cbab0418b7e5116fd4255eb5566588bd/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmParser.cpp#L43-L204
And even if AsmTypeCheck has access to it, currently it treats the list
of catch clauses as a single `WebAssemblyOperand` so there is no way to
get the starting location of each `catch_***` clause in the current
structure.

This also renames `valTypeToStackType` to `valTypesToStackTypes`, given
that it takes two type lists.
@aheejin aheejin requested a review from dschuff October 3, 2024 22:28
@llvmbot llvmbot added backend:WebAssembly mc Machine (object) code labels Oct 3, 2024
@llvmbot
Copy link
Member

llvmbot commented Oct 3, 2024

@llvm/pr-subscribers-mc

@llvm/pr-subscribers-backend-webassembly

Author: Heejin Ahn (aheejin)

Changes

This adds supports for the new EH instructions (try_table and throw_ref) to the type checker.

One thing I'd like to improve on is the locations in the errors for catch_*** clauses. Currently they just point to the starting column of try_table instruction itself. But to figure out where catch clauses start you need to traverse OperandVector and check WebAssemblyOperand::isCatchList on them to see which one is the catch list operand, but WebAssemblyOperand class is in AsmParser and AsmTypeCheck does not have access to it:

namespace {
/// WebAssemblyOperand - Instances of this class represent the operands in a
/// parsed Wasm machine instruction.
struct WebAssemblyOperand : public MCParsedAsmOperand {
enum KindTy { Token, Integer, Float, Symbol, BrList, CatchList } Kind;
SMLoc StartLoc, EndLoc;
struct TokOp {
StringRef Tok;
};
struct IntOp {
int64_t Val;
};
struct FltOp {
double Val;
};
struct SymOp {
const MCExpr *Exp;
};
struct BrLOp {
std::vector<unsigned> List;
};
struct CaLOpElem {
uint8_t Opcode;
const MCExpr *Tag;
unsigned Dest;
};
struct CaLOp {
std::vector<CaLOpElem> List;
};
union {
struct TokOp Tok;
struct IntOp Int;
struct FltOp Flt;
struct SymOp Sym;
struct BrLOp BrL;
struct CaLOp CaL;
};
WebAssemblyOperand(SMLoc Start, SMLoc End, TokOp T)
: Kind(Token), StartLoc(Start), EndLoc(End), Tok(T) {}
WebAssemblyOperand(SMLoc Start, SMLoc End, IntOp I)
: Kind(Integer), StartLoc(Start), EndLoc(End), Int(I) {}
WebAssemblyOperand(SMLoc Start, SMLoc End, FltOp F)
: Kind(Float), StartLoc(Start), EndLoc(End), Flt(F) {}
WebAssemblyOperand(SMLoc Start, SMLoc End, SymOp S)
: Kind(Symbol), StartLoc(Start), EndLoc(End), Sym(S) {}
WebAssemblyOperand(SMLoc Start, SMLoc End, BrLOp B)
: Kind(BrList), StartLoc(Start), EndLoc(End), BrL(B) {}
WebAssemblyOperand(SMLoc Start, SMLoc End, CaLOp C)
: Kind(CatchList), StartLoc(Start), EndLoc(End), CaL(C) {}
~WebAssemblyOperand() {
if (isBrList())
BrL.~BrLOp();
if (isCatchList())
CaL.~CaLOp();
}
bool isToken() const override { return Kind == Token; }
bool isImm() const override { return Kind == Integer || Kind == Symbol; }
bool isFPImm() const { return Kind == Float; }
bool isMem() const override { return false; }
bool isReg() const override { return false; }
bool isBrList() const { return Kind == BrList; }
bool isCatchList() const { return Kind == CatchList; }
MCRegister getReg() const override {
llvm_unreachable("Assembly inspects a register operand");
return 0;
}
StringRef getToken() const {
assert(isToken());
return Tok.Tok;
}
SMLoc getStartLoc() const override { return StartLoc; }
SMLoc getEndLoc() const override { return EndLoc; }
void addRegOperands(MCInst &, unsigned) const {
// Required by the assembly matcher.
llvm_unreachable("Assembly matcher creates register operands");
}
void addImmOperands(MCInst &Inst, unsigned N) const {
assert(N == 1 && "Invalid number of operands!");
if (Kind == Integer)
Inst.addOperand(MCOperand::createImm(Int.Val));
else if (Kind == Symbol)
Inst.addOperand(MCOperand::createExpr(Sym.Exp));
else
llvm_unreachable("Should be integer immediate or symbol!");
}
void addFPImmf32Operands(MCInst &Inst, unsigned N) const {
assert(N == 1 && "Invalid number of operands!");
if (Kind == Float)
Inst.addOperand(
MCOperand::createSFPImm(bit_cast<uint32_t>(float(Flt.Val))));
else
llvm_unreachable("Should be float immediate!");
}
void addFPImmf64Operands(MCInst &Inst, unsigned N) const {
assert(N == 1 && "Invalid number of operands!");
if (Kind == Float)
Inst.addOperand(MCOperand::createDFPImm(bit_cast<uint64_t>(Flt.Val)));
else
llvm_unreachable("Should be float immediate!");
}
void addBrListOperands(MCInst &Inst, unsigned N) const {
assert(N == 1 && isBrList() && "Invalid BrList!");
for (auto Br : BrL.List)
Inst.addOperand(MCOperand::createImm(Br));
}
void addCatchListOperands(MCInst &Inst, unsigned N) const {
assert(N == 1 && isCatchList() && "Invalid CatchList!");
Inst.addOperand(MCOperand::createImm(CaL.List.size()));
for (auto Ca : CaL.List) {
Inst.addOperand(MCOperand::createImm(Ca.Opcode));
if (Ca.Opcode == wasm::WASM_OPCODE_CATCH ||
Ca.Opcode == wasm::WASM_OPCODE_CATCH_REF)
Inst.addOperand(MCOperand::createExpr(Ca.Tag));
Inst.addOperand(MCOperand::createImm(Ca.Dest));
}
}
void print(raw_ostream &OS) const override {
switch (Kind) {
case Token:
OS << "Tok:" << Tok.Tok;
break;
case Integer:
OS << "Int:" << Int.Val;
break;
case Float:
OS << "Flt:" << Flt.Val;
break;
case Symbol:
OS << "Sym:" << Sym.Exp;
break;
case BrList:
OS << "BrList:" << BrL.List.size();
break;
case CatchList:
OS << "CaList:" << CaL.List.size();
break;
}
}
};
And even if AsmTypeCheck has access to it, currently it treats the list of catch clauses as a single WebAssemblyOperand so there is no way to get the starting location of each catch_*** clause in the current structure.

This also renames valTypeToStackType to valTypesToStackTypes, given that it takes two type lists.


Full diff: https://github.com/llvm/llvm-project/pull/111069.diff

4 Files Affected:

  • (modified) llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp (+82-9)
  • (modified) llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h (+5-2)
  • (modified) llvm/test/MC/WebAssembly/eh-assembly.s (+2-2)
  • (modified) llvm/test/MC/WebAssembly/type-checker-errors.s (+23)
diff --git a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp
index effc2e65223cad..f01e19962ab9fc 100644
--- a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp
+++ b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp
@@ -59,7 +59,7 @@ void WebAssemblyAsmTypeCheck::localDecl(
 }
 
 void WebAssemblyAsmTypeCheck::dumpTypeStack(Twine Msg) {
-  LLVM_DEBUG({ dbgs() << Msg << getTypesString(Stack, 0) << "\n"; });
+  LLVM_DEBUG({ dbgs() << Msg << getTypesString(Stack) << "\n"; });
 }
 
 bool WebAssemblyAsmTypeCheck::typeError(SMLoc ErrorLoc, const Twine &Msg) {
@@ -116,8 +116,15 @@ std::string WebAssemblyAsmTypeCheck::getTypesString(ArrayRef<StackType> Types,
   return SS.str();
 }
 
+std::string
+WebAssemblyAsmTypeCheck::getTypesString(ArrayRef<wasm::ValType> Types,
+                                        size_t StartPos) {
+  return getTypesString(valTypesToStackTypes(Types), StartPos);
+}
+
 SmallVector<WebAssemblyAsmTypeCheck::StackType, 4>
-WebAssemblyAsmTypeCheck::valTypeToStackType(ArrayRef<wasm::ValType> ValTypes) {
+WebAssemblyAsmTypeCheck::valTypesToStackTypes(
+    ArrayRef<wasm::ValType> ValTypes) {
   SmallVector<StackType, 4> Types(ValTypes.size());
   std::transform(ValTypes.begin(), ValTypes.end(), Types.begin(),
                  [](wasm::ValType Val) -> StackType { return Val; });
@@ -127,7 +134,7 @@ WebAssemblyAsmTypeCheck::valTypeToStackType(ArrayRef<wasm::ValType> ValTypes) {
 bool WebAssemblyAsmTypeCheck::checkTypes(SMLoc ErrorLoc,
                                          ArrayRef<wasm::ValType> ValTypes,
                                          bool ExactMatch) {
-  return checkTypes(ErrorLoc, valTypeToStackType(ValTypes), ExactMatch);
+  return checkTypes(ErrorLoc, valTypesToStackTypes(ValTypes), ExactMatch);
 }
 
 bool WebAssemblyAsmTypeCheck::checkTypes(SMLoc ErrorLoc,
@@ -178,14 +185,14 @@ bool WebAssemblyAsmTypeCheck::checkTypes(SMLoc ErrorLoc,
                            : std::max((int)BlockStackStartPos,
                                       (int)Stack.size() - (int)Types.size());
   return typeError(ErrorLoc, "type mismatch, expected " +
-                                 getTypesString(Types, 0) + " but got " +
+                                 getTypesString(Types) + " but got " +
                                  getTypesString(Stack, StackStartPos));
 }
 
 bool WebAssemblyAsmTypeCheck::popTypes(SMLoc ErrorLoc,
                                        ArrayRef<wasm::ValType> ValTypes,
                                        bool ExactMatch) {
-  return popTypes(ErrorLoc, valTypeToStackType(ValTypes), ExactMatch);
+  return popTypes(ErrorLoc, valTypesToStackTypes(ValTypes), ExactMatch);
 }
 
 bool WebAssemblyAsmTypeCheck::popTypes(SMLoc ErrorLoc,
@@ -215,7 +222,7 @@ bool WebAssemblyAsmTypeCheck::popAnyType(SMLoc ErrorLoc) {
 }
 
 void WebAssemblyAsmTypeCheck::pushTypes(ArrayRef<wasm::ValType> ValTypes) {
-  Stack.append(valTypeToStackType(ValTypes));
+  Stack.append(valTypesToStackTypes(ValTypes));
 }
 
 bool WebAssemblyAsmTypeCheck::getLocal(SMLoc ErrorLoc, const MCOperand &LocalOp,
@@ -322,6 +329,63 @@ bool WebAssemblyAsmTypeCheck::endOfFunction(SMLoc ErrorLoc, bool ExactMatch) {
   return checkTypes(ErrorLoc, FuncInfo.Sig.Returns, ExactMatch);
 }
 
+// Unlike checkTypes() family, this just compare the equivalence of the two
+// ValType vectors
+static bool compareTypes(ArrayRef<wasm::ValType> TypesA,
+                         ArrayRef<wasm::ValType> TypesB) {
+  if (TypesA.size() != TypesB.size())
+    return true;
+  for (size_t I = 0, E = TypesA.size(); I < E; I++)
+    if (TypesA[I] != TypesB[I])
+      return true;
+  return false;
+}
+
+bool WebAssemblyAsmTypeCheck::checkTryTable(SMLoc ErrorLoc,
+                                            const MCInst &Inst) {
+  bool Error = false;
+  unsigned OpIdx = 1; // OpIdx 0 is the block type
+  int64_t NumCatches = Inst.getOperand(OpIdx++).getImm();
+  for (int64_t I = 0; I < NumCatches; I++) {
+    int64_t Opcode = Inst.getOperand(OpIdx++).getImm();
+    std::string ErrorMsgBase =
+        "try_table: catch index " + std::to_string(I) + ": ";
+
+    const wasm::WasmSignature *Sig = nullptr;
+    SmallVector<wasm::ValType> SentTypes;
+    if (Opcode == wasm::WASM_OPCODE_CATCH ||
+        Opcode == wasm::WASM_OPCODE_CATCH_REF) {
+      if (!getSignature(ErrorLoc, Inst.getOperand(OpIdx++),
+                        wasm::WASM_SYMBOL_TYPE_TAG, Sig))
+        SentTypes.insert(SentTypes.end(), Sig->Params.begin(),
+                         Sig->Params.end());
+      else
+        Error = true;
+    }
+    if (Opcode == wasm::WASM_OPCODE_CATCH_REF ||
+        Opcode == wasm::WASM_OPCODE_CATCH_ALL_REF) {
+      SentTypes.push_back(wasm::ValType::EXNREF);
+    }
+
+    unsigned Level = Inst.getOperand(OpIdx++).getImm();
+    if (Level < BlockInfoStack.size()) {
+      const auto &DestBlockInfo =
+          BlockInfoStack[BlockInfoStack.size() - Level - 1];
+      if (compareTypes(SentTypes, DestBlockInfo.Sig.Returns)) {
+        std::string ErrorMsg =
+            ErrorMsgBase + "type mismatch, catch tag type is " +
+            getTypesString(SentTypes) + ", but destination's return type is " +
+            getTypesString(DestBlockInfo.Sig.Returns);
+        Error |= typeError(ErrorLoc, ErrorMsg);
+      }
+    } else {
+      Error = typeError(ErrorLoc, ErrorMsgBase + "invalid depth " +
+                                      std::to_string(Level));
+    }
+  }
+  return Error;
+}
+
 bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
                                         OperandVector &Operands) {
   auto Opc = Inst.getOpcode();
@@ -460,10 +524,13 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
     return popType(ErrorLoc, Any{});
   }
 
-  if (Name == "block" || Name == "loop" || Name == "if" || Name == "try") {
+  if (Name == "block" || Name == "loop" || Name == "if" || Name == "try" ||
+      Name == "try_table") {
     bool Error = Name == "if" && popType(ErrorLoc, wasm::ValType::I32);
     // Pop block input parameters and check their types are correct
     Error |= popTypes(ErrorLoc, LastSig.Params);
+    if (Name == "try_table")
+      Error |= checkTryTable(ErrorLoc, Inst);
     // Push a new block info
     BlockInfoStack.push_back({LastSig, Stack.size(), Name == "loop"});
     // Push back block input parameters
@@ -472,8 +539,8 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
   }
 
   if (Name == "end_block" || Name == "end_loop" || Name == "end_if" ||
-      Name == "end_try" || Name == "delegate" || Name == "else" ||
-      Name == "catch" || Name == "catch_all") {
+      Name == "end_try" || Name == "delegate" || Name == "end_try_table" ||
+      Name == "else" || Name == "catch" || Name == "catch_all") {
     assert(!BlockInfoStack.empty());
     // Check if the types on the stack match with the block return type
     const auto &LastBlockInfo = BlockInfoStack.back();
@@ -586,6 +653,12 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
     return Error;
   }
 
+  if (Name == "throw_ref") {
+    bool Error = popType(ErrorLoc, wasm::ValType::EXNREF);
+    pushType(Polymorphic{});
+    return Error;
+  }
+
   // The current instruction is a stack instruction which doesn't have
   // explicit operands that indicate push/pop types, so we get those from
   // the register version of the same instruction.
diff --git a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h
index 596fb27bce94e6..e6fddf98060265 100644
--- a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h
+++ b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h
@@ -65,9 +65,11 @@ class WebAssemblyAsmTypeCheck final {
   void pushTypes(ArrayRef<wasm::ValType> Types);
   void pushType(StackType Type) { Stack.push_back(Type); }
   bool match(StackType TypeA, StackType TypeB);
-  std::string getTypesString(ArrayRef<StackType> Types, size_t StartPos);
+  std::string getTypesString(ArrayRef<wasm::ValType> Types,
+                             size_t StartPos = 0);
+  std::string getTypesString(ArrayRef<StackType> Types, size_t StartPos = 0);
   SmallVector<StackType, 4>
-  valTypeToStackType(ArrayRef<wasm::ValType> ValTypes);
+  valTypesToStackTypes(ArrayRef<wasm::ValType> ValTypes);
 
   void dumpTypeStack(Twine Msg);
   bool typeError(SMLoc ErrorLoc, const Twine &Msg);
@@ -80,6 +82,7 @@ class WebAssemblyAsmTypeCheck final {
   bool getTable(SMLoc ErrorLoc, const MCOperand &TableOp, wasm::ValType &Type);
   bool getSignature(SMLoc ErrorLoc, const MCOperand &SigOp,
                     wasm::WasmSymbolType Type, const wasm::WasmSignature *&Sig);
+  bool checkTryTable(SMLoc ErrorLoc, const MCInst &Inst);
 
 public:
   WebAssemblyAsmTypeCheck(MCAsmParser &Parser, const MCInstrInfo &MII,
diff --git a/llvm/test/MC/WebAssembly/eh-assembly.s b/llvm/test/MC/WebAssembly/eh-assembly.s
index b4d6b324d96e3e..a03c1b8e1aed14 100644
--- a/llvm/test/MC/WebAssembly/eh-assembly.s
+++ b/llvm/test/MC/WebAssembly/eh-assembly.s
@@ -1,6 +1,6 @@
-# RUN: llvm-mc -triple=wasm32-unknown-unknown -mattr=+exception-handling --no-type-check < %s | FileCheck %s
+# RUN: llvm-mc -triple=wasm32-unknown-unknown -mattr=+exception-handling < %s | FileCheck %s
 # Check that it converts to .o without errors, but don't check any output:
-# RUN: llvm-mc -triple=wasm32-unknown-unknown -filetype=obj -mattr=+exception-handling --no-type-check -o %t.o < %s
+# RUN: llvm-mc -triple=wasm32-unknown-unknown -filetype=obj -mattr=+exception-handling -o %t.o < %s
 
   .tagtype  __cpp_exception i32
   .tagtype  __c_longjmp i32
diff --git a/llvm/test/MC/WebAssembly/type-checker-errors.s b/llvm/test/MC/WebAssembly/type-checker-errors.s
index df537a9ba5d0a0..74ab17fdefdad9 100644
--- a/llvm/test/MC/WebAssembly/type-checker-errors.s
+++ b/llvm/test/MC/WebAssembly/type-checker-errors.s
@@ -944,3 +944,26 @@ block_param_and_return:
 
 # CHECK: :[[@LINE+1]]:3: error: type mismatch, expected [] but got [f32]
   end_function
+
+  .tagtype  __cpp_exception i32
+
+eh_test:
+  .functype eh_test () -> ()
+  block i32
+    block i32
+      block i32
+        block
+# CHECK: :[[@LINE+4]]:11: error: try_table: catch index 0: type mismatch, catch tag type is [i32], but destination's return type is []
+# CHECK: :[[@LINE+3]]:11: error: try_table: catch index 1: type mismatch, catch tag type is [i32, exnref], but destination's return type is [i32]
+# CHECK: :[[@LINE+2]]:11: error: try_table: catch index 2: type mismatch, catch tag type is [], but destination's return type is [i32]
+# CHECK: :[[@LINE+1]]:11: error: try_table: catch index 3: type mismatch, catch tag type is [exnref], but destination's return type is [i32]
+          try_table i32 (catch __cpp_exception 0) (catch_ref __cpp_exception 1) (catch_all 2) (catch_all_ref 3)
+# CHECK: :[[@LINE+1]]:11: error: type mismatch, expected [i32] but got []
+          end_try_table
+# CHECK: :[[@LINE+1]]:9: error: type mismatch, expected [] but got [i32]
+        end_block
+      end_block
+    end_block
+  end_block
+  drop
+  end_function

@@ -24,7 +24,6 @@ eh_test:
return
end_block
throw_ref
drop
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After throw_ref, the stack is polymorphic, so you don't need to drop the i32

@aheejin
Copy link
Member Author

aheejin commented Oct 3, 2024

Just realized that this does not support try_table targeting loops. Will do that as a follow-up.

@aheejin aheejin merged commit 69577b2 into llvm:main Oct 7, 2024
9 checks passed
@aheejin aheejin deleted the eh_typecheck branch October 7, 2024 17:29
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:WebAssembly mc Machine (object) code
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants