diff --git a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h
--- a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h
+++ b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h
@@ -16,14 +16,47 @@
 #ifndef LLVM_LIB_TARGET_WEBASSEMBLY_ASMPARSER_TYPECHECK_H
 #define LLVM_LIB_TARGET_WEBASSEMBLY_ASMPARSER_TYPECHECK_H
 
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/BinaryFormat/Wasm.h"
 #include "llvm/MC/MCInstrInfo.h"
 #include "llvm/MC/MCParser/MCAsmParser.h"
 #include "llvm/MC/MCParser/MCTargetAsmParser.h"
 #include "llvm/MC/MCSymbol.h"
+#include "llvm/Support/SMLoc.h"
 
 namespace llvm {
 
+/// Type-checker state for one open control-flow construct (block, loop, try
+/// or if): its kind, the operand stack when it was entered, and its block
+/// signature.
+class WebAssemblyAsmOperandFrame final {
+public:
+  using OperandStack = SmallVector<wasm::ValType, 16>;
+  enum class FrameType {
+    Loop,
+    Block,
+    Try,
+    If,
+  };
+  WebAssemblyAsmOperandFrame(FrameType Kind,
+                             const SmallVectorImpl<wasm::ValType> &Stack,
+                             const wasm::WasmSignature &Signature)
+      : Type(Kind), EnterStack(Stack.begin(), Stack.end()), Sig(Signature) {}
+  /// The operand stack as it was when the construct was entered.
+  const OperandStack &getEnterStack() const { return EnterStack; }
+  /// Types a branch to this frame must supply: the params for a loop, the
+  /// results for every other kind.
+  OperandStack getResultType() const;
+  /// The operand stack after the construct completes: the enter stack with
+  /// the params replaced by the results.
+  OperandStack applySignature() const;
+
+private:
+  const FrameType Type;
+  OperandStack EnterStack;
+  wasm::WasmSignature Sig;
+};
+
 class WebAssemblyAsmTypeCheck final {
   MCAsmParser &Parser;
   const MCInstrInfo &MII;
@@ -31,6 +64,8 @@
   SmallVector<wasm::ValType, 8> Stack;
   SmallVector<wasm::ValType, 16> LocalTypes;
   SmallVector<wasm::ValType, 4> ReturnTypes;
+  /// Open control-flow frames; the innermost one is at the back.
+  SmallVector<WebAssemblyAsmOperandFrame, 8> BrStack;
   wasm::WasmSignature LastSig;
   bool TypeErrorThisFunction = false;
   bool Unreachable = false;
@@ -39,9 +74,14 @@
   void dumpTypeStack(Twine Msg);
   bool typeError(SMLoc ErrorLoc, const Twine &Msg);
   bool popType(SMLoc ErrorLoc, std::optional<wasm::ValType> EVT);
+  bool popType(SmallVectorImpl<wasm::ValType> &PopStack, SMLoc ErrorLoc,
+               std::optional<wasm::ValType> EVT);
   bool popRefType(SMLoc ErrorLoc);
   bool getLocal(SMLoc ErrorLoc, const MCInst &Inst, wasm::ValType &Type);
-  bool checkEnd(SMLoc ErrorLoc, bool PopVals = false);
+  bool checkEnd(SMLoc ErrorLoc, bool RecoverEnter = false);
+  bool checkBr(SMLoc ErrorLoc, const MCInst &Inst);
+  bool checkBlockIn(WebAssemblyAsmOperandFrame::FrameType Type, SMLoc ErrorLoc,
+                    const MCInst &Inst);
   bool checkSig(SMLoc ErrorLoc, const wasm::WasmSignature &Sig);
   bool getSymRef(SMLoc ErrorLoc, const MCInst &Inst,
                  const MCSymbolRefExpr *&SymRef);
@@ -49,7 +89,8 @@
   bool getTable(SMLoc ErrorLoc, const MCInst &Inst, wasm::ValType &Type);
 
 public:
-  WebAssemblyAsmTypeCheck(MCAsmParser &Parser, const MCInstrInfo &MII, bool is64);
+  WebAssemblyAsmTypeCheck(MCAsmParser &Parser, const MCInstrInfo &MII,
+                          bool is64);
 
   void funcDecl(const wasm::WasmSignature &Sig);
   void localDecl(const SmallVectorImpl<wasm::ValType> &Locals);
@@ -61,6 +102,7 @@
     Stack.clear();
     LocalTypes.clear();
     ReturnTypes.clear();
+    BrStack.clear();
     TypeErrorThisFunction = false;
     Unreachable = false;
   }
@@ -68,4 +110,4 @@
 
 } // end namespace llvm
 
-#endif // LLVM_LIB_TARGET_WEBASSEMBLY_ASMPARSER_TYPECHECK_H
+#endif // LLVM_LIB_TARGET_WEBASSEMBLY_ASMPARSER_TYPECHECK_H
diff --git a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp
--- a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp
+++ b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp
@@ -19,6 +19,9 @@
 #include "MCTargetDesc/WebAssemblyTargetStreamer.h"
 #include "TargetInfo/WebAssemblyTargetInfo.h"
 #include "WebAssembly.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/BinaryFormat/Wasm.h"
 #include "llvm/MC/MCContext.h"
 #include "llvm/MC/MCExpr.h"
 #include "llvm/MC/MCInst.h"
@@ -32,8 +35,13 @@
 #include "llvm/MC/MCSymbolWasm.h"
 #include "llvm/MC/TargetRegistry.h"
 #include "llvm/Support/Compiler.h"
+#include "llvm/Support/Debug.h"
 #include "llvm/Support/Endian.h"
+#include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/SourceMgr.h"
+#include <cstdint>
+#include <optional>
+#include <string>
 
 using namespace llvm;
 
@@ -43,11 +51,33 @@
 
 namespace llvm {
 
-WebAssemblyAsmTypeCheck::WebAssemblyAsmTypeCheck(MCAsmParser &Parser,
-                                                 const MCInstrInfo &MII, bool is64)
-    : Parser(Parser), MII(MII), is64(is64) {
+WebAssemblyAsmOperandFrame::OperandStack
+WebAssemblyAsmOperandFrame::getResultType() const {
+  switch (Type) {
+  case FrameType::Loop:
+    return OperandStack{Sig.Params.begin(), Sig.Params.end()};
+  case FrameType::Block:
+  case FrameType::Try:
+  case FrameType::If:
+    return OperandStack{Sig.Returns.begin(), Sig.Returns.end()};
+  }
+  llvm_unreachable("invalid FrameType");
+}
+
+WebAssemblyAsmOperandFrame::OperandStack
+WebAssemblyAsmOperandFrame::applySignature() const {
+  OperandStack Stack{EnterStack};
+  // EnterStack may lack some params if checkBlockIn reported an error.
+  Stack.pop_back_n(std::min<size_t>(Stack.size(), Sig.Params.size()));
+  Stack.insert(Stack.end(), Sig.Returns.begin(), Sig.Returns.end());
+  return Stack;
 }
 
+WebAssemblyAsmTypeCheck::WebAssemblyAsmTypeCheck(MCAsmParser &Parser,
+                                                 const MCInstrInfo &MII,
+                                                 bool is64)
+    : Parser(Parser), MII(MII), is64(is64) {}
+
 void WebAssemblyAsmTypeCheck::funcDecl(const wasm::WasmSignature &Sig) {
   LocalTypes.assign(Sig.Params.begin(), Sig.Params.end());
   ReturnTypes.assign(Sig.Returns.begin(), Sig.Returns.end());
@@ -84,13 +114,21 @@
 
 bool WebAssemblyAsmTypeCheck::popType(SMLoc ErrorLoc,
                                       std::optional<wasm::ValType> EVT) {
-  if (Stack.empty()) {
+  return popType(Stack, ErrorLoc, EVT);
+}
+
+/// Pop one value off \p PopStack, reporting a type error if it is empty or,
+/// when \p EVT is given, if the popped type differs from \p EVT.
+bool WebAssemblyAsmTypeCheck::popType(SmallVectorImpl<wasm::ValType> &PopStack,
+                                      SMLoc ErrorLoc,
+                                      std::optional<wasm::ValType> EVT) {
+  if (PopStack.empty()) {
     return typeError(ErrorLoc,
                      EVT ? StringRef("empty stack while popping ") +
                                WebAssembly::typeToString(*EVT)
                          : StringRef("empty stack while popping value"));
   }
-  auto PVT = Stack.pop_back_val();
+  auto PVT = PopStack.pop_back_val();
   if (EVT && *EVT != PVT) {
     return typeError(ErrorLoc,
                      StringRef("popped ") + WebAssembly::typeToString(PVT) +
@@ -117,38 +155,116 @@
   auto Local = static_cast<size_t>(Inst.getOperand(0).getImm());
   if (Local >= LocalTypes.size())
     return typeError(ErrorLoc, StringRef("no local type specified for index ") +
-                          std::to_string(Local));
+                                   std::to_string(Local));
   Type = LocalTypes[Local];
   return false;
 }
 
-bool WebAssemblyAsmTypeCheck::checkEnd(SMLoc ErrorLoc, bool PopVals) {
-  if (LastSig.Returns.size() > Stack.size())
-    return typeError(ErrorLoc, "end: insufficient values on the type stack");
+/// Compare the top of \p Got with \p ExpectedStackTop (both ordered bottom to
+/// top). Returns a "got X, expected Y" message for the first mismatch, or
+/// std::nullopt when they agree. Callers must ensure \p Got holds at least
+/// as many values as \p ExpectedStackTop.
+static std::optional<std::string>
+checkStackTop(const SmallVectorImpl<wasm::ValType> &ExpectedStackTop,
+              const SmallVectorImpl<wasm::ValType> &Got) {
+  assert(Got.size() >= ExpectedStackTop.size() &&
+         "type stack is shallower than the expected stack top");
+  for (size_t I = 0; I < ExpectedStackTop.size(); I++) {
+    auto EVT = ExpectedStackTop[I];
+    auto PVT = Got[Got.size() - ExpectedStackTop.size() + I];
+    if (PVT != EVT)
+      return std::string{"got "} + WebAssembly::typeToString(PVT) +
+             ", expected " + WebAssembly::typeToString(EVT);
+  }
+  return std::nullopt;
+}
 
-  if (PopVals) {
-    for (auto VT : llvm::reverse(LastSig.Returns)) {
-      if (popType(ErrorLoc, VT))
-        return true;
-    }
+/// Type-check a `br`. Immediate operand 0 is the branch depth: a depth below
+/// BrStack.size() targets an enclosing block frame, and depth BrStack.size()
+/// targets the function body, whose label types are the function results.
+/// The operand stack must end with the target's label types.
+bool WebAssemblyAsmTypeCheck::checkBr(SMLoc ErrorLoc, const MCInst &Inst) {
+  const MCOperand &Operand = Inst.getOperand(0);
+  if (!Operand.isImm())
     return false;
-  }
+  int64_t BrLevel = Operand.getImm();
+  if (BrLevel < 0 || static_cast<size_t>(BrLevel) > BrStack.size())
+    return typeError(ErrorLoc, StringRef("br: invalid depth ") +
+                                   std::to_string(BrLevel) + " (max " +
+                                   std::to_string(BrStack.size()) + ")");
+  WebAssemblyAsmOperandFrame::OperandStack Expected;
+  if (static_cast<size_t>(BrLevel) == BrStack.size())
+    Expected.assign(ReturnTypes.begin(), ReturnTypes.end());
+  else
+    Expected = BrStack[BrStack.size() - BrLevel - 1].getResultType();
+  // checkStackTop indexes from the top of Stack, so it must be deep enough.
+  if (Expected.size() > Stack.size())
+    return typeError(ErrorLoc, "br: insufficient values on the type stack");
+  std::optional<std::string> StackCheckResult = checkStackTop(Expected, Stack);
+  if (StackCheckResult)
+    return typeError(ErrorLoc,
+                     StringRef{"br "} + StringRef{StackCheckResult.value()});
+  return false;
+}
 
-  for (size_t i = 0; i < LastSig.Returns.size(); i++) {
-    auto EVT = LastSig.Returns[i];
-    auto PVT = Stack[Stack.size() - LastSig.Returns.size() + i];
-    if (PVT != EVT)
-      return typeError(
-          ErrorLoc, StringRef("end got ") + WebAssembly::typeToString(PVT) +
-                        ", expected " + WebAssembly::typeToString(EVT));
+/// Push a control frame of kind \p Type for a block-like instruction, using
+/// LastSig as its signature, and check that the operand stack on entry ends
+/// with the signature's parameter types. \p Inst is currently unused.
+bool WebAssemblyAsmTypeCheck::checkBlockIn(
+    WebAssemblyAsmOperandFrame::FrameType Type, SMLoc ErrorLoc,
+    const MCInst &Inst) {
+  BrStack.push_back(WebAssemblyAsmOperandFrame{Type, Stack, LastSig});
+  // Pop the params from a scratch copy: the live Stack keeps them because
+  // they become the block's initial operands.
+  WebAssemblyAsmOperandFrame::OperandStack EnterStack =
+      BrStack.back().getEnterStack();
+  for (wasm::ValType VT : llvm::reverse(LastSig.Params))
+    if (popType(EnterStack, ErrorLoc, VT))
+      return true;
+  return false;
+}
+
+/// Close the innermost control frame (end_block, end_loop, end_if, end_try,
+/// delegate) or, with \p RecoverEnter set (`else`), restart from the enter
+/// stack of the `if`. The operand stack must end with LastSig's results.
+bool WebAssemblyAsmTypeCheck::checkEnd(SMLoc ErrorLoc, bool RecoverEnter) {
+  // A block instruction that bailed out before checkBlockIn leaves no frame;
+  // report that instead of reading from an empty BrStack.
+  if (BrStack.empty())
+    return typeError(ErrorLoc, "end: no matching block to close");
+
+  // If typeError() declines to report (returns false; presumably in
+  // unreachable code), fall through so the frame is still closed.
+  if (LastSig.Returns.size() > Stack.size()) {
+    if (typeError(ErrorLoc, "end: insufficient values on the type stack"))
+      return true;
+  } else if (std::optional<std::string> StackCheckResult =
+                 checkStackTop(LastSig.Returns, Stack)) {
+    if (typeError(ErrorLoc,
+                  StringRef{"end "} + StringRef{StackCheckResult.value()}))
+      return true;
+  }
+
+  if (RecoverEnter)
+    // For a nested `else`, BrStack.back() holds the enter stack of the `if`.
+    Stack = BrStack.back().getEnterStack();
+  else {
+    // After `unreachable`/`br`/`throw` the tracked Stack is meaningless, so
+    // rebuild it from the frame's enter stack and signature.
+    if (Unreachable)
+      Stack = BrStack.pop_back_val().applySignature();
+    else
+      BrStack.pop_back();
   }
+  Unreachable = false;
   return false;
 }
 
 bool WebAssemblyAsmTypeCheck::checkSig(SMLoc ErrorLoc,
-                                       const wasm::WasmSignature& Sig) {
+                                       const wasm::WasmSignature &Sig) {
   for (auto VT : llvm::reverse(Sig.Params))
-    if (popType(ErrorLoc, VT)) return true;
+    if (popType(ErrorLoc, VT))
+      return true;
   Stack.insert(Stack.end(), Sig.Returns.begin(), Sig.Returns.end());
   return false;
 }
@@ -187,7 +303,7 @@
     [[fallthrough]];
   default:
     return typeError(ErrorLoc, StringRef("symbol ") + WasmSym->getName() +
-                                    " missing .globaltype");
+                                   " missing .globaltype");
   }
   return false;
 }
@@ -299,19 +415,36 @@
   } else if (Name == "drop") {
     if (popType(ErrorLoc, {}))
       return true;
+  } else if (Name == "loop") {
+    if (checkBlockIn(WebAssemblyAsmOperandFrame::FrameType::Loop, ErrorLoc,
+                     Inst))
+      return true;
+  } else if (Name == "block") {
+    if (checkBlockIn(WebAssemblyAsmOperandFrame::FrameType::Block, ErrorLoc,
+                     Inst))
+      return true;
+  } else if (Name == "try") {
+    if (checkBlockIn(WebAssemblyAsmOperandFrame::FrameType::Try, ErrorLoc,
+                     Inst))
+      return true;
   } else if (Name == "end_block" || Name == "end_loop" || Name == "end_if" ||
-             Name == "else" || Name == "end_try") {
+             Name == "else" || Name == "end_try" || Name == "delegate") {
     if (checkEnd(ErrorLoc, Name == "else"))
       return true;
-    if (Name == "end_block")
-      Unreachable = false;
+  } else if (Name == "br") {
+    // TODO: br_if, br_table
+    if (checkBr(ErrorLoc, Inst))
+      return true;
+    Unreachable = true;
   } else if (Name == "return") {
     if (endOfFunction(ErrorLoc))
       return true;
   } else if (Name == "call_indirect" || Name == "return_call_indirect") {
     // Function value.
-    if (popType(ErrorLoc, wasm::ValType::I32)) return true;
-    if (checkSig(ErrorLoc, LastSig)) return true;
+    if (popType(ErrorLoc, wasm::ValType::I32))
+      return true;
+    if (checkSig(ErrorLoc, LastSig))
+      return true;
     if (Name == "return_call_indirect" && endOfFunction(ErrorLoc))
       return true;
   } else if (Name == "call" || Name == "return_call") {
@@ -324,10 +457,12 @@
       return typeError(Operands[1]->getStartLoc(), StringRef("symbol ") +
                                                        WasmSym->getName() +
                                                        " missing .functype");
-    if (checkSig(ErrorLoc, *Sig)) return true;
+    if (checkSig(ErrorLoc, *Sig))
+      return true;
     if (Name == "return_call" && endOfFunction(ErrorLoc))
       return true;
   } else if (Name == "catch") {
+    Unreachable = false;
     const MCSymbolRefExpr *SymRef;
     if (getSymRef(Operands[1]->getStartLoc(), Inst, SymRef))
       return true;
@@ -337,10 +472,15 @@
       return typeError(Operands[1]->getStartLoc(), StringRef("symbol ") +
                                                        WasmSym->getName() +
                                                        " missing .tagtype");
-    // catch instruction pushes values whose types are specified in the tag's
-    // "params" part
+    // Restart from the operand stack that was live when the enclosing `try`
+    // was entered; BrStack can be empty if the block nest is malformed.
+    if (BrStack.empty())
+      return typeError(ErrorLoc, "catch: no enclosing try");
+    Stack = BrStack.back().getEnterStack();
+    // catch instruction pushes values whose types are specified in the
+    // tag's "params" part
     Stack.insert(Stack.end(), Sig->Params.begin(), Sig->Params.end());
-  } else if (Name == "unreachable") {
+  } else if (Name == "unreachable" || Name == "throw" || Name == "rethrow") {
     Unreachable = true;
   } else if (Name == "ref.is_null") {
     if (popRefType(ErrorLoc))
@@ -369,6 +509,9 @@
       auto VT = WebAssembly::regClassToValType(Op.RegClass);
       Stack.push_back(VT);
     }
+    if (Name == "if" &&
+        checkBlockIn(WebAssemblyAsmOperandFrame::FrameType::If, ErrorLoc, Inst))
+      return true;
   }
   return false;
 }
diff --git a/llvm/test/MC/WebAssembly/basic-assembly.s b/llvm/test/MC/WebAssembly/basic-assembly.s
--- a/llvm/test/MC/WebAssembly/basic-assembly.s
+++ b/llvm/test/MC/WebAssembly/basic-assembly.s
@@ -81,12 +81,15 @@
     end_block            # default jumps here.
     i32.const   3
     end_block            # "switch" exit.
+    drop
+    drop
+    i32.const   1
+    if                   # void
+    i32.const   1
     if                   # void
-    if          i32
     end_if
     else
     end_if
-    drop
     block       void
     i32.const   2
     return
@@ -224,12 +227,15 @@
 # CHECK-NEXT:      end_block           # label3:
 # CHECK-NEXT:      i32.const   3
 # CHECK-NEXT:      end_block           # label2:
+# CHECK-NEXT:      drop
+# CHECK-NEXT:      drop
+# CHECK-NEXT:      i32.const   1
+# CHECK-NEXT:      if
+# CHECK-NEXT:      i32.const   1
 # CHECK-NEXT:      if
-# CHECK-NEXT:      if          i32
 # CHECK-NEXT:      end_if
 # CHECK-NEXT:      else
 # CHECK-NEXT:      end_if
-# CHECK-NEXT:      drop
 # CHECK-NEXT:      block
 # CHECK-NEXT:      i32.const   2
 # CHECK-NEXT:      return
diff --git a/llvm/test/MC/WebAssembly/type-checker-control-flow.s b/llvm/test/MC/WebAssembly/type-checker-control-flow.s
new file mode 100644
--- /dev/null
+++ b/llvm/test/MC/WebAssembly/type-checker-control-flow.s
@@ -0,0 +1,129 @@
+# RUN: llvm-mc -triple=wasm32 -mattr=+exception-handling %s 2>&1 | FileCheck %s
+
+# Check type checker for control flow instructions
+
+br_block:
+  .functype br_block () -> (i32)
+  block i32
+    i32.const 1
+    br 0
+  end_block
+  end_function
+# CHECK-LABEL: br_block:
+# CHECK-NEXT:  .functype br_block () -> (i32)
+# CHECK-NEXT:  block i32
+# CHECK-NEXT:  i32.const 1
+# CHECK-NEXT:  br 0
+# CHECK-NEXT:  end_block
+# CHECK-NEXT:  end_function
+
+br_loop:
+  .functype br_loop () -> (i32)
+  loop i32
+    br 0
+    i32.const 2
+  end_loop
+  end_function
+# CHECK-LABEL: br_loop:
+# CHECK-NEXT:  .functype br_loop () -> (i32)
+# CHECK-NEXT:  loop i32
+# CHECK-NEXT:  br 0
+# CHECK-NEXT:  i32.const 2
+# CHECK-NEXT:  end_loop
+# CHECK-NEXT:  end_function
+
+br_if_block:
+  .functype br_if_block () -> (i32)
+  i32.const 3
+  if i32
+    i32.const 4
+    br 0
+  else
+    i32.const 5
+    br 0
+  end_if
+  end_function
+# CHECK-LABEL: br_if_block:
+# CHECK-NEXT:  .functype br_if_block () -> (i32)
+# CHECK-NEXT:  i32.const 3
+# CHECK-NEXT:  if i32
+# CHECK-NEXT:  i32.const 4
+# CHECK-NEXT:  br 0
+# CHECK-NEXT:  else
+# CHECK-NEXT:  i32.const 5
+# CHECK-NEXT:  br 0
+# CHECK-NEXT:  end_if
+# CHECK-NEXT:  end_function
+
+br_try:
+  .functype br_try () -> ()
+  .tagtype tag_f32 f32
+  try i32
+    i32.const 1
+    br 0
+  catch tag_f32
+    i32.trunc_f32_s
+    br 0
+  end_try
+  drop
+  end_function
+# CHECK-LABEL: br_try:
+# CHECK-NEXT:  .functype br_try () -> ()
+# CHECK-NEXT:  .tagtype tag_f32 f32
+# CHECK-NEXT:  try i32
+# CHECK-NEXT:  i32.const 1
+# CHECK-NEXT:  br 0
+# CHECK-NEXT:  catch tag_f32
+# CHECK-NEXT:  i32.trunc_f32_s
+# CHECK-NEXT:  br 0
+# CHECK-NEXT:  end_try
+# CHECK-NEXT:  drop
+# CHECK-NEXT:  end_function
+
+delegate:
+  .functype delegate () -> ()
+  .tagtype tag_i32 i32
+  try i32
+    try f32
+      f32.const 1.0
+    delegate 0
+    i32.trunc_f32_s
+    br 0
+  catch tag_i32
+  end_try
+  drop
+  end_function
+# CHECK-LABEL: delegate:
+# CHECK-NEXT:  .functype delegate () -> ()
+# CHECK-NEXT:  .tagtype tag_i32 i32
+# CHECK-NEXT:  try i32
+# CHECK-NEXT:  try f32
+# CHECK-NEXT:  f32.const 0x1p0
+# CHECK-NEXT:  delegate 0
+# CHECK-NEXT:  i32.trunc_f32_s
+# CHECK-NEXT:  br 0
+# CHECK-NEXT:  catch tag_i32
+# CHECK-NEXT:  end_try
+# CHECK-NEXT:  drop
+# CHECK-NEXT:  end_function
+
+recover_stack_after_br:
+  .functype recover_stack_after_br () -> (i32)
+  block i32
+    block i32
+      i32.const 1
+      br 0
+      i32.const 1
+    end_block
+  end_block
+  end_function
+# CHECK-LABEL: recover_stack_after_br:
+# CHECK-NEXT:  .functype recover_stack_after_br () -> (i32)
+# CHECK-NEXT:  block i32
+# CHECK-NEXT:  block i32
+# CHECK-NEXT:  i32.const 1
+# CHECK-NEXT:  br 0
+# CHECK-NEXT:  i32.const 1
+# CHECK-NEXT:  end_block
+# CHECK-NEXT:  end_block
+# CHECK-NEXT:  end_function
diff --git a/llvm/test/MC/WebAssembly/type-checker-errors.s b/llvm/test/MC/WebAssembly/type-checker-errors.s
--- a/llvm/test/MC/WebAssembly/type-checker-errors.s
+++ b/llvm/test/MC/WebAssembly/type-checker-errors.s
@@ -513,3 +513,52 @@
   f32.add
 # CHECK: :[[@LINE+1]]:3: error: 1 superfluous return values
   end_function
+
+br_invalid_type:
+  .functype br_invalid_type () -> (i32)
+  .local i32
+  local.get 0
+  loop (i32) -> (i32)
+    if i32
+      f32.const 0.0
+# CHECK: :[[@LINE+1]]:7: error: br got f32, expected i32
+      br 1
+    else
+      i32.const 0
+    end_if
+  end_loop
+  end_function
+
+br_invalid_depth:
+  .functype br_invalid_depth () -> (i32)
+  .local i32
+  loop i32
+    local.get 0
+    if i32
+# CHECK: :[[@LINE+1]]:7: error: br: invalid depth 10 (max 2)
+      br 10
+    else
+      i32.const 0
+    end_if
+  end_loop
+  end_function
+
+invalid_block_params:
+  .functype invalid_block_params () -> (i32)
+# CHECK: :[[@LINE+1]]:3: error: empty stack while popping i32
+  block (i32) -> ()
+  end_block
+  end_function
+
+br_try:
+  .functype br_try () -> ()
+  .tagtype tag_f32 f32
+  try i32
+    i32.const 1
+    br 0
+  catch tag_f32
+# CHECK: :[[@LINE+1]]:5: error: br got f32, expected i32
+    br 0
+  end_try
+  drop
+  end_function