/*
 * [WebAssembly] Type-check `br` in the assembly type checker.
 *
 * What the hunks below do:
 *  - WebAssemblyAsmTypeCheck gains a BrStack member. It holds one entry per
 *    open branch target: the value types a `br` to that target must find
 *    on top of the type stack. Clear() empties it.
 *  - funcDecl() pushes the function's result types as the outermost entry.
 *  - `block`, `if` and `try` push the result types of their signature
 *    (LastSig.Returns). `loop` pushes the parameter types of its signature
 *    (LastSig.Params). `if` also pops its i32 operand first.
 *  - checkEnd() pops one BrStack entry when called with PopVals == false.
 *  - checkBr(ErrorLoc, Level) selects the target
 *    BrStack[BrStack.size() - Level - 1], so depth 0 is the most recently
 *    pushed entry. It reports:
 *      - "br: invalid depth N" when Level >= BrStack.size();
 *      - "br: insufficient values on the type stack" when the stack holds
 *        fewer values than the target expects;
 *      - otherwise "br got X, expected Y" for the first mismatching type.
 *    It does not modify the type stack.
 *  - checkStackTop() factors out the comparison that checkEnd() used to
 *    inline. It returns "got X, expected Y" for the first mismatch, or
 *    std::nullopt when the top of Got matches ExpectedStackTop.
 *    It indexes Got from its end, so it assumes
 *    Got.size() >= ExpectedStackTop.size(); both visible callers check
 *    this before calling it.
 *  - A `br` whose operand 0 is not an immediate is accepted unchecked.
 *  - New cases in type-checker-errors.s cover mismatched branch types
 *    inside loop / block / if-else / try / catch / catch_all, and an
 *    out-of-range depth (`br 10` under three blocks).
 *
 * NOTE(review): this copy of the patch appears to have lost its line
 * breaks and all text from each '<' up to the next '>'. Examples:
 * template arguments of SmallVector / SmallVectorImpl / std::optional /
 * static_cast, and the header name after `+#include`. It will not apply
 * as-is; regenerate it from the source tree.
 * NOTE(review): checkEnd() calls BrStack.pop_back() without checking
 * that BrStack is non-empty. It presumably relies on the parser rejecting
 * unbalanced end instructions before the type checker runs; confirm.
 * NOTE(review): the `br` immediate is cast to the size_t parameter of
 * checkBr(), so a negative depth would presumably be reported as a very
 * large number. Consider rejecting it explicitly.
 * NOTE(review): only `br` is checked against BrStack in these hunks;
 * `br_if` / `br_table` are not handled here.
 */
diff --git a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h --- a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h +++ b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h @@ -29,6 +29,7 @@ const MCInstrInfo &MII; SmallVector Stack; + SmallVector, 8> BrStack; SmallVector LocalTypes; SmallVector ReturnTypes; wasm::WasmSignature LastSig; @@ -42,6 +43,7 @@ bool popRefType(SMLoc ErrorLoc); bool getLocal(SMLoc ErrorLoc, const MCInst &Inst, wasm::ValType &Type); bool checkEnd(SMLoc ErrorLoc, bool PopVals = false); + bool checkBr(SMLoc ErrorLoc, size_t Level); bool checkSig(SMLoc ErrorLoc, const wasm::WasmSignature &Sig); bool getSymRef(SMLoc ErrorLoc, const MCInst &Inst, const MCSymbolRefExpr *&SymRef); @@ -60,6 +62,7 @@ void Clear() { Stack.clear(); + BrStack.clear(); LocalTypes.clear(); ReturnTypes.clear(); TypeErrorThisFunction = false; diff --git a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp --- a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp +++ b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp @@ -34,6 +34,7 @@ #include "llvm/Support/Compiler.h" #include "llvm/Support/Endian.h" #include "llvm/Support/SourceMgr.h" +#include using namespace llvm; @@ -51,6 +52,7 @@ void WebAssemblyAsmTypeCheck::funcDecl(const wasm::WasmSignature &Sig) { LocalTypes.assign(Sig.Params.begin(), Sig.Params.end()); ReturnTypes.assign(Sig.Returns.begin(), Sig.Returns.end()); + BrStack.emplace_back(Sig.Returns.begin(), Sig.Returns.end()); } void WebAssemblyAsmTypeCheck::localDecl( @@ -122,7 +124,36 @@ return false; } +static std::optional +checkStackTop(const SmallVectorImpl &ExpectedStackTop, + const SmallVectorImpl &Got) { + for (size_t I = 0; I < ExpectedStackTop.size(); I++) { + auto EVT = ExpectedStackTop[I]; + auto PVT = Got[Got.size() - 
ExpectedStackTop.size() + I]; + if (PVT != EVT) + return std::string{"got "} + WebAssembly::typeToString(PVT) + + ", expected " + WebAssembly::typeToString(EVT); + } + return std::nullopt; +} + +bool WebAssemblyAsmTypeCheck::checkBr(SMLoc ErrorLoc, size_t Level) { + if (Level >= BrStack.size()) + return typeError(ErrorLoc, + StringRef("br: invalid depth ") + std::to_string(Level)); + const SmallVector &Expected = + BrStack[BrStack.size() - Level - 1]; + if (Expected.size() > Stack.size()) + return typeError(ErrorLoc, "br: insufficient values on the type stack"); + auto IsStackTopInvalid = checkStackTop(Expected, Stack); + if (IsStackTopInvalid) + return typeError(ErrorLoc, "br " + IsStackTopInvalid.value()); + return false; +} + bool WebAssemblyAsmTypeCheck::checkEnd(SMLoc ErrorLoc, bool PopVals) { + if (!PopVals) + BrStack.pop_back(); if (LastSig.Returns.size() > Stack.size()) return typeError(ErrorLoc, "end: insufficient values on the type stack"); @@ -134,14 +165,9 @@ return false; } - for (size_t i = 0; i < LastSig.Returns.size(); i++) { - auto EVT = LastSig.Returns[i]; - auto PVT = Stack[Stack.size() - LastSig.Returns.size() + i]; - if (PVT != EVT) - return typeError(ErrorLoc, - StringRef("end got ") + WebAssembly::typeToString(PVT) + - ", expected " + WebAssembly::typeToString(EVT)); - } + auto IsStackTopInvalid = checkStackTop(LastSig.Returns, Stack); + if (IsStackTopInvalid) + return typeError(ErrorLoc, "end " + IsStackTopInvalid.value()); return false; } @@ -300,6 +326,14 @@ } else if (Name == "drop") { if (popType(ErrorLoc, {})) return true; + } else if (Name == "try" || Name == "block" || Name == "loop" || + Name == "if") { + if (Name == "if" && popType(ErrorLoc, wasm::ValType::I32)) + return true; + if (Name == "loop") + BrStack.emplace_back(LastSig.Params.begin(), LastSig.Params.end()); + else + BrStack.emplace_back(LastSig.Returns.begin(), LastSig.Returns.end()); } else if (Name == "end_block" || Name == "end_loop" || Name == "end_if" || Name == 
"else" || Name == "end_try" || Name == "catch" || Name == "catch_all" || Name == "delegate") { @@ -321,6 +355,12 @@ // "params" part Stack.insert(Stack.end(), Sig->Params.begin(), Sig->Params.end()); } + } else if (Name == "br") { + const MCOperand &Operand = Inst.getOperand(0); + if (!Operand.isImm()) + return false; + if (checkBr(ErrorLoc, static_cast(Operand.getImm()))) + return true; } else if (Name == "return") { if (endOfFunction(ErrorLoc)) return true; diff --git a/llvm/test/MC/WebAssembly/type-checker-errors.s b/llvm/test/MC/WebAssembly/type-checker-errors.s --- a/llvm/test/MC/WebAssembly/type-checker-errors.s +++ b/llvm/test/MC/WebAssembly/type-checker-errors.s @@ -698,3 +698,85 @@ # CHECK: :[[@LINE+1]]:3: error: empty stack while popping value drop end_function + +br_invalid_type_loop: + .functype br_invalid_type_loop () -> () + i32.const 1 + loop (i32) -> (f32) + drop + f32.const 1.0 +# CHECK: :[[@LINE+1]]:5: error: br got f32, expected i32 + br 0 + end_loop + drop + end_function + +br_invalid_type_block: + .functype br_invalid_type_block () -> () + i32.const 1 + block (i32) -> (f32) +# CHECK: :[[@LINE+1]]:5: error: br got i32, expected f32 + br 0 + f32.const 1.0 + end_block + drop + end_function + +br_invalid_type_if: + .functype br_invalid_type_if () -> () + i32.const 1 + if f32 + f32.const 1.0 + else + i32.const 1 +# CHECK: :[[@LINE+1]]:5: error: br got i32, expected f32 + br 0 + end_if + drop + end_function + +br_invalid_type_try: + .functype br_invalid_type_try () -> () + try f32 + i32.const 1 +# CHECK: :[[@LINE+1]]:5: error: br got i32, expected f32 + br 0 + catch tag_i32 + end_try + drop + end_function + +br_invalid_type_catch: + .functype br_invalid_type_catch () -> () + try f32 + f32.const 1.0 + catch tag_i32 +# CHECK: :[[@LINE+1]]:5: error: br got i32, expected f32 + br 0 + end_try + drop + end_function + +br_invalid_type_catch_all: + .functype br_invalid_type_catch_all () -> () + try f32 + f32.const 1.0 + catch_all + i32.const 1 +# CHECK: 
:[[@LINE+1]]:5: error: br got i32, expected f32 + br 0 + end_try + drop + end_function + +br_invalid_depth: + .functype br_invalid_depth () -> () + block + block + block +# CHECK: :[[@LINE+1]]:5: error: br: invalid depth 10 + br 10 + end_block + end_block + end_block + end_function