diff --git a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h --- a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h +++ b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h @@ -16,14 +16,38 @@ #ifndef LLVM_LIB_TARGET_WEBASSEMBLY_ASMPARSER_TYPECHECK_H #define LLVM_LIB_TARGET_WEBASSEMBLY_ASMPARSER_TYPECHECK_H +#include "llvm/ADT/SmallVector.h" #include "llvm/BinaryFormat/Wasm.h" #include "llvm/MC/MCInstrInfo.h" #include "llvm/MC/MCParser/MCAsmParser.h" #include "llvm/MC/MCParser/MCTargetAsmParser.h" #include "llvm/MC/MCSymbol.h" +#include "llvm/Support/SMLoc.h" namespace llvm { +class WebAssemblyAsmOperandFrame final { +public: + using OperandStack = SmallVector; + enum class FrameType { + Loop, + Block, + Try, + If, + }; + WebAssemblyAsmOperandFrame(FrameType FrameType, + const SmallVectorImpl &Stack, + const wasm::WasmSignature &Signature); + FrameType getType() const { return Type; } + OperandStack const &getEnterStack() const { return EnterStack; } + OperandStack applySignature() const; + +private: + const FrameType Type; + OperandStack EnterStack; + wasm::WasmSignature Sig; +}; + class WebAssemblyAsmTypeCheck final { MCAsmParser &Parser; const MCInstrInfo &MII; @@ -31,6 +55,7 @@ SmallVector Stack; SmallVector LocalTypes; SmallVector ReturnTypes; + SmallVector BrStack; wasm::WasmSignature LastSig; bool TypeErrorThisFunction = false; bool Unreachable = false; @@ -39,9 +64,14 @@ void dumpTypeStack(Twine Msg); bool typeError(SMLoc ErrorLoc, const Twine &Msg); bool popType(SMLoc ErrorLoc, std::optional EVT); + bool popType(SmallVectorImpl &Stack, SMLoc ErrorLoc, + std::optional EVT); bool popRefType(SMLoc ErrorLoc); bool getLocal(SMLoc ErrorLoc, const MCInst &Inst, wasm::ValType &Type); bool checkEnd(SMLoc ErrorLoc, bool PopVals = false); + bool checkBr(SMLoc ErrorLoc, const MCInst &Inst); + bool checkBlockIn(WebAssemblyAsmOperandFrame::FrameType Type, SMLoc ErrorLoc, + const MCInst &Inst); bool checkSig(SMLoc ErrorLoc, const wasm::WasmSignature &Sig); bool getSymRef(SMLoc ErrorLoc, const MCInst &Inst, const MCSymbolRefExpr *&SymRef); @@ -49,7 +79,8 @@ bool getTable(SMLoc ErrorLoc, const MCInst &Inst, wasm::ValType &Type); public: - WebAssemblyAsmTypeCheck(MCAsmParser &Parser, const MCInstrInfo &MII, bool is64); + WebAssemblyAsmTypeCheck(MCAsmParser &Parser, const MCInstrInfo &MII, + bool is64); void funcDecl(const wasm::WasmSignature &Sig); void localDecl(const SmallVectorImpl &Locals); @@ -61,6 +92,7 @@ Stack.clear(); LocalTypes.clear(); ReturnTypes.clear(); + BrStack.clear(); TypeErrorThisFunction = false; Unreachable = false; } @@ -68,4 +100,4 @@ } // end namespace llvm -#endif // LLVM_LIB_TARGET_WEBASSEMBLY_ASMPARSER_TYPECHECK_H +#endif // LLVM_LIB_TARGET_WEBASSEMBLY_ASMPARSER_TYPECHECK_H diff --git a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp --- a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp +++ b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp @@ -19,6 +19,8 @@ #include "MCTargetDesc/WebAssemblyTargetStreamer.h" #include "TargetInfo/WebAssemblyTargetInfo.h" #include "WebAssembly.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/BinaryFormat/Wasm.h" #include "llvm/MC/MCContext.h" #include "llvm/MC/MCExpr.h" #include "llvm/MC/MCInst.h" @@ -32,8 +34,11 @@ #include "llvm/MC/MCSymbolWasm.h" #include "llvm/MC/TargetRegistry.h" #include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/Endian.h" #include "llvm/Support/SourceMgr.h" +#include +#include using namespace llvm; @@ -43,11 +48,24 @@ namespace llvm { -WebAssemblyAsmTypeCheck::WebAssemblyAsmTypeCheck(MCAsmParser &Parser, - const MCInstrInfo &MII, bool is64) - : Parser(Parser), MII(MII), is64(is64) { +WebAssemblyAsmOperandFrame::WebAssemblyAsmOperandFrame( + FrameType FrameType, const SmallVectorImpl &Stack, + const wasm::WasmSignature &Signature) + : Type(FrameType), EnterStack(Stack.begin(), Stack.end()), Sig(Signature) {} + +WebAssemblyAsmOperandFrame::OperandStack +WebAssemblyAsmOperandFrame::applySignature() const { + OperandStack Stack{EnterStack}; + Stack.pop_back_n(Sig.Params.size()); + Stack.insert(Stack.end(), Sig.Returns.begin(), Sig.Returns.end()); + return Stack; } +WebAssemblyAsmTypeCheck::WebAssemblyAsmTypeCheck(MCAsmParser &Parser, + const MCInstrInfo &MII, + bool is64) + : Parser(Parser), MII(MII), is64(is64) {} + void WebAssemblyAsmTypeCheck::funcDecl(const wasm::WasmSignature &Sig) { LocalTypes.assign(Sig.Params.begin(), Sig.Params.end()); ReturnTypes.assign(Sig.Returns.begin(), Sig.Returns.end()); @@ -84,13 +102,18 @@ bool WebAssemblyAsmTypeCheck::popType(SMLoc ErrorLoc, std::optional EVT) { - if (Stack.empty()) { + return popType(Stack, ErrorLoc, EVT); +} +bool WebAssemblyAsmTypeCheck::popType(SmallVectorImpl &PopStack, + SMLoc ErrorLoc, + std::optional EVT) { + if (PopStack.empty()) { return typeError(ErrorLoc, EVT ? StringRef("empty stack while popping ") + WebAssembly::typeToString(*EVT) : StringRef("empty stack while popping value")); } - auto PVT = Stack.pop_back_val(); + auto PVT = PopStack.pop_back_val(); if (EVT && *EVT != PVT) { return typeError(ErrorLoc, StringRef("popped ") + WebAssembly::typeToString(PVT) + @@ -117,38 +140,106 @@ auto Local = static_cast(Inst.getOperand(0).getImm()); if (Local >= LocalTypes.size()) return typeError(ErrorLoc, StringRef("no local type specified for index ") + - std::to_string(Local)); + std::to_string(Local)); Type = LocalTypes[Local]; return false; } +static std::string +printStackDiff(SmallVectorImpl const &Expected, + SmallVectorImpl const &Got) { + std::stringstream Msg; + size_t Index = 0; + for (; Index < Expected.size() && Index < Got.size(); Index++) { + if (Expected[Index] != Got[Index]) { + break; + } + } + Msg << "got [..."; + // got + for (size_t I = Index; I < Got.size(); I++) { + Msg << ", " << WebAssembly::typeToString(Got[I]); + } + Msg << "], expected [..."; + // expected + for (size_t I = Index; I < Expected.size(); I++) { + Msg << ", " << WebAssembly::typeToString(Expected[I]); + } + Msg << "]"; + return Msg.str(); +} + +bool WebAssemblyAsmTypeCheck::checkBr(SMLoc ErrorLoc, const MCInst &Inst) { + const MCOperand &Operand = Inst.getOperand(0); + if (!Operand.isImm()) + return false; + if (BrStack.size() <= static_cast(Operand.getImm())) + return typeError(ErrorLoc, StringRef("br: invalid depth ") + + std::to_string(Operand.getImm()) + " (max " + + std::to_string(BrStack.size()) + ")"); + const llvm::WebAssemblyAsmOperandFrame &Frame = + BrStack[BrStack.size() - Operand.getImm() - 1]; + SmallVector TargetStack; + switch (Frame.getType()) { + case WebAssemblyAsmOperandFrame::FrameType::Loop: + TargetStack = Frame.getEnterStack(); + break; + case WebAssemblyAsmOperandFrame::FrameType::Block: + case WebAssemblyAsmOperandFrame::FrameType::Try: + case WebAssemblyAsmOperandFrame::FrameType::If: + TargetStack = Frame.applySignature(); + break; + } + if (TargetStack != Stack) { + return typeError(ErrorLoc, + StringRef{"br "} + printStackDiff(TargetStack, Stack)); + } + return false; +} + +bool WebAssemblyAsmTypeCheck::checkBlockIn( + WebAssemblyAsmOperandFrame::FrameType Type, SMLoc ErrorLoc, + const MCInst &Inst) { + BrStack.push_back(WebAssemblyAsmOperandFrame{Type, Stack, LastSig}); + // check signature + for (wasm::ValType VT : llvm::reverse(LastSig.Params)) { + WebAssemblyAsmOperandFrame::OperandStack EnterStack = + BrStack.back().getEnterStack(); + if (popType(EnterStack, ErrorLoc, VT)) + return true; + } + return false; +} + bool WebAssemblyAsmTypeCheck::checkEnd(SMLoc ErrorLoc, bool PopVals) { if (LastSig.Returns.size() > Stack.size()) return typeError(ErrorLoc, "end: insufficient values on the type stack"); - + if (PopVals) { for (auto VT : llvm::reverse(LastSig.Returns)) { - if (popType(ErrorLoc, VT)) + if (popType(ErrorLoc, VT)) return true; } return false; } - + for (size_t i = 0; i < LastSig.Returns.size(); i++) { auto EVT = LastSig.Returns[i]; auto PVT = Stack[Stack.size() - LastSig.Returns.size() + i]; if (PVT != EVT) - return typeError( - ErrorLoc, StringRef("end got ") + WebAssembly::typeToString(PVT) + - ", expected " + WebAssembly::typeToString(EVT)); + return typeError(ErrorLoc, + StringRef("end got ") + WebAssembly::typeToString(PVT) + + ", expected " + WebAssembly::typeToString(EVT)); } + BrStack.pop_back(); return false; } bool WebAssemblyAsmTypeCheck::checkSig(SMLoc ErrorLoc, - const wasm::WasmSignature& Sig) { + const wasm::WasmSignature &Sig) { for (auto VT : llvm::reverse(Sig.Params)) - if (popType(ErrorLoc, VT)) return true; + if (popType(ErrorLoc, VT)) + return true; Stack.insert(Stack.end(), Sig.Returns.begin(), Sig.Returns.end()); return false; } @@ -187,7 +278,7 @@ [[fallthrough]]; default: return typeError(ErrorLoc, StringRef("symbol ") + WasmSym->getName() + - " missing .globaltype"); + " missing .globaltype"); } return false; } @@ -275,19 +366,38 @@ } else if (Name == "drop") { if (popType(ErrorLoc, {})) return true; + } else if (Name == "loop") { + if (checkBlockIn(WebAssemblyAsmOperandFrame::FrameType::Loop, ErrorLoc, + Inst)) + return true; + } else if (Name == "block") { + if (checkBlockIn(WebAssemblyAsmOperandFrame::FrameType::Block, ErrorLoc, + Inst)) + return true; + } else if (Name == "try") { + if (checkBlockIn(WebAssemblyAsmOperandFrame::FrameType::Try, ErrorLoc, + Inst)) + return true; } else if (Name == "end_block" || Name == "end_loop" || Name == "end_if" || Name == "else" || Name == "end_try") { if (checkEnd(ErrorLoc, Name == "else")) return true; - if (Name == "end_block") + if (Name == "end_if" || Name == "else") Unreachable = false; + } else if (Name == "br") { + // TODO: br_if + if (checkBr(ErrorLoc, Inst)) + return true; + Unreachable = true; } else if (Name == "return") { if (endOfFunction(ErrorLoc)) return true; } else if (Name == "call_indirect" || Name == "return_call_indirect") { // Function value. - if (popType(ErrorLoc, wasm::ValType::I32)) return true; - if (checkSig(ErrorLoc, LastSig)) return true; + if (popType(ErrorLoc, wasm::ValType::I32)) + return true; + if (checkSig(ErrorLoc, LastSig)) + return true; if (Name == "return_call_indirect" && endOfFunction(ErrorLoc)) return true; } else if (Name == "call" || Name == "return_call") { @@ -300,10 +410,12 @@ return typeError(Operands[1]->getStartLoc(), StringRef("symbol ") + WasmSym->getName() + " missing .functype"); - if (checkSig(ErrorLoc, *Sig)) return true; + if (checkSig(ErrorLoc, *Sig)) + return true; if (Name == "return_call" && endOfFunction(ErrorLoc)) return true; } else if (Name == "catch") { + Unreachable = false; const MCSymbolRefExpr *SymRef; if (getSymRef(Operands[1]->getStartLoc(), Inst, SymRef)) return true; @@ -313,10 +425,12 @@ return typeError(Operands[1]->getStartLoc(), StringRef("symbol ") + WasmSym->getName() + " missing .tagtype"); - // catch instruction pushes values whose types are specified in the tag's - // "params" part + // if nest structure is good, last BrStack is operand stack in try instr; + Stack = BrStack.back().getEnterStack(); + // catch instruction pushes values whose types are specified in the + // tag's "params" part Stack.insert(Stack.end(), Sig->Params.begin(), Sig->Params.end()); - } else if (Name == "unreachable") { + } else if (Name == "unreachable" || Name == "throw") { Unreachable = true; } else if (Name == "ref.is_null") { if (popRefType(ErrorLoc)) @@ -345,6 +459,9 @@ auto VT = WebAssembly::regClassToValType(Op.RegClass); Stack.push_back(VT); } + if (Name == "if" && + checkBlockIn(WebAssemblyAsmOperandFrame::FrameType::If, ErrorLoc, Inst)) + return true; } return false; } diff --git a/llvm/test/MC/WebAssembly/type-checker-control-flow.s b/llvm/test/MC/WebAssembly/type-checker-control-flow.s new file mode 100644 --- /dev/null +++ b/llvm/test/MC/WebAssembly/type-checker-control-flow.s @@ -0,0 +1,82 @@ +# RUN: llvm-mc -triple=wasm32 -mattr=+exception-handling %s 2>&1 + +# Check type checker for control flow instructions + +br_block: + .functype br_block () -> (i32) + block i32 + i32.const 1 + br 0 + end_block + end_function +# CHECK-LABEL: br_block: +# CHECK-NEXT: .functype br_block () -> (i32) +# CHECK-NEXT: .local i32 +# CHECK-NEXT: block i32 +# CHECK-NEXT: i32.const 1 +# CHECK-NEXT: br 0 +# CHECK-NEXT: end_block +# CHECK-NEXT: end_function + +br_loop: + .functype br_loop () -> (i32) + loop i32 + br 0 + i32.const 2 + end_loop + end_function +# CHECK-LABEL: br_loop: +# CHECK-NEXT: .functype br_loop () -> (i32) +# CHECK-NEXT: loop i32 +# CHECK-NEXT: br 0 +# CHECK-NEXT: i32.const 2 +# CHECK-NEXT: end_loop +# CHECK-NEXT: end_function + +br_if_block: + .functype br_if_block () -> (i32) + i32.const 3 + if i32 + i32.const 4 + br 0 + else + i32.const 5 + br 0 + end_if + end_function +# CHECK-LABEL: br_if_block: +# CHECK-NEXT: .functype br_if_block () -> (i32) +# CHECK-NEXT: i32.const 3 +# CHECK-NEXT: if i32 +# CHECK-NEXT: i32.const 4 +# CHECK-NEXT: br 0 +# CHECK-NEXT: else +# CHECK-NEXT: i32.const 5 +# CHECK-NEXT: br 0 +# CHECK-NEXT: end_if +# CHECK-NEXT: end_function + +br_try: + .functype br_try () -> () + .tagtype tag_f32 f32 + try i32 + i32.const 1 + br 0 + catch tag_f32 + i32.trunc_f32_s + br 0 + end_try + drop + end_function +# CHECK-LABEL: br_try: +# CHECK-NEXT: .functype br_try () -> () +# CHECK-NEXT: .tagtype tag_f32 f32 +# CHECK-NEXT: try i32 +# CHECK-NEXT: i32.const 1 +# CHECK-NEXT: br 0 +# CHECK-NEXT: catch tag_f32 +# CHECK-NEXT: i32.trunc_f32_s +# CHECK-NEXT: br 0 +# CHECK-NEXT: end_try +# CHECK-NEXT: drop +# CHECK-NEXT: end_function \ No newline at end of file diff --git a/llvm/test/MC/WebAssembly/type-checker-errors.s b/llvm/test/MC/WebAssembly/type-checker-errors.s --- a/llvm/test/MC/WebAssembly/type-checker-errors.s +++ b/llvm/test/MC/WebAssembly/type-checker-errors.s @@ -513,3 +513,52 @@ f32.add # CHECK: :[[@LINE+1]]:3: error: 1 superfluous return values end_function + +br_invalid_type: + .functype br_invalid_type () -> (i32) + .local i32 + loop i32 + local.get 0 + if i32 + i32.const 0 +# CHECK: :[[@LINE+1]]:7: error: br got [..., i32], expected [...] + br 1 + else + i32.const 0 + end_if + end_loop + end_function + +br_invalid_depth: + .functype br_invalid_depth () -> (i32) + .local i32 + loop i32 + local.get 0 + if i32 +# CHECK: :[[@LINE+1]]:7: error: br: invalid depth 10 (max 2) + br 10 + else + i32.const 0 + end_if + end_loop + end_function + +invalid_block_params: + .functype invalid_block_params () -> (i32) +# CHECK: :[[@LINE+1]]:3: error: empty stack while popping i32 + block (i32) -> () + end_block + end_function + +br_try: + .functype br_try () -> () + .tagtype tag_f32 f32 + try i32 + i32.const 1 + br 0 + catch tag_f32 +# CHECK: :[[@LINE+1]]:5: error: br got [..., f32], expected [..., i32] + br 0 + end_try + drop + end_function