diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyExceptionInfo.h b/llvm/lib/Target/WebAssembly/WebAssemblyExceptionInfo.h --- a/llvm/lib/Target/WebAssembly/WebAssemblyExceptionInfo.h +++ b/llvm/lib/Target/WebAssembly/WebAssemblyExceptionInfo.h @@ -45,7 +45,7 @@ WebAssemblyException *ParentException = nullptr; std::vector> SubExceptions; std::vector Blocks; - SmallPtrSet BlockSet; + SmallPtrSet BlockSet; public: WebAssemblyException(MachineBasicBlock *EHPad) : EHPad(EHPad) {} @@ -68,9 +68,12 @@ return BlockSet.count(MBB); } + void addToBlocksSet(MachineBasicBlock *MBB) { BlockSet.insert(MBB); } + void removeFromBlocksSet(MachineBasicBlock *MBB) { BlockSet.erase(MBB); } + void addToBlocksVector(MachineBasicBlock *MBB) { Blocks.push_back(MBB); } void addBlock(MachineBasicBlock *MBB) { - Blocks.push_back(MBB); BlockSet.insert(MBB); + Blocks.push_back(MBB); } ArrayRef getBlocks() const { return Blocks; } using block_iterator = typename ArrayRef::const_iterator; @@ -81,8 +84,10 @@ } unsigned getNumBlocks() const { return Blocks.size(); } std::vector &getBlocksVector() { return Blocks; } + SmallPtrSetImpl &getBlocksSet() { return BlockSet; } - const std::vector> &getSubExceptions() const { + const std::vector> & + getSubExceptions() const { return SubExceptions; } std::vector> &getSubExceptions() { @@ -149,7 +154,8 @@ return BBMap.lookup(MBB); } - void changeExceptionFor(MachineBasicBlock *MBB, WebAssemblyException *WE) { + void changeExceptionFor(const MachineBasicBlock *MBB, + WebAssemblyException *WE) { if (!WE) { BBMap.erase(MBB); return; diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyExceptionInfo.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyExceptionInfo.cpp --- a/llvm/lib/Target/WebAssembly/WebAssemblyExceptionInfo.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyExceptionInfo.cpp @@ -19,6 +19,8 @@ #include "llvm/CodeGen/MachineDominators.h" #include "llvm/CodeGen/WasmEHFuncInfo.h" #include "llvm/InitializePasses.h" +#include "llvm/MC/MCAsmInfo.h" +#include
"llvm/Target/TargetMachine.h" using namespace llvm; @@ -38,9 +40,37 @@ "********** Function: " << MF.getName() << '\n'); releaseMemory(); + if (MF.getTarget().getMCAsmInfo()->getExceptionHandlingType() != + ExceptionHandling::Wasm || + !MF.getFunction().hasPersonalityFn()) + return false; auto &MDT = getAnalysis(); auto &MDF = getAnalysis(); recalculate(MF, MDT, MDF); + LLVM_DEBUG(dump()); + return false; +} + +// Check if Dst is reachable from Src using DFS. Search only within BBs +// dominated by Header. Each BB is pushed onto the worklist at most once. +static bool isReachableAmongDominated(const MachineBasicBlock *Src, + const MachineBasicBlock *Dst, + const MachineBasicBlock *Header, + const MachineDominatorTree &MDT) { + assert(MDT.dominates(Header, Dst)); + SmallVector WL; + SmallPtrSet Visited; + WL.push_back(Src); + Visited.insert(Src); + + while (!WL.empty()) { + const auto *MBB = WL.pop_back_val(); + if (MBB == Dst) + return true; + for (auto *Succ : MBB->successors()) + if (MDT.dominates(Header, Succ) && Visited.insert(Succ).second) + WL.push_back(Succ); + } return false; } @@ -83,30 +113,103 @@ // Here we extract those unwind destinations from their (incorrect) parent // exception. Note that the unwind destinations may not be an immediate // children of the parent exception, so we have to traverse the parent chain. + // + // We should traverse BBs in the preorder of the dominator tree, because + // otherwise the result can be incorrect. For example, when there are three + // exceptions A, B, and C and A > B > C (> is subexception relationship here), + // and A's unwind destination is B and B's is C. When we visit B before A, we + // end up extracting C only out of B but not out of A.
const auto *EHInfo = MF.getWasmEHFuncInfo(); - for (auto &MBB : MF) { - if (!MBB.isEHPad()) + SmallVector<std::pair<WebAssemblyException *, WebAssemblyException *>, 4> + UnwindWEVec; // Deterministic: filled in dominator-tree preorder. + for (auto *DomNode : depth_first(&MDT)) { + MachineBasicBlock *EHPad = DomNode->getBlock(); + if (!EHPad->isEHPad()) continue; - auto *EHPad = &MBB; if (!EHInfo->hasUnwindDest(EHPad)) continue; auto *UnwindDest = EHInfo->getUnwindDest(EHPad); auto *WE = getExceptionFor(EHPad); - auto *UnwindDestWE = getExceptionFor(UnwindDest); - if (WE->contains(UnwindDestWE)) { + auto *UnwindWE = getExceptionFor(UnwindDest); + if (WE->contains(UnwindWE)) { + UnwindWEVec.push_back(std::make_pair(WE, UnwindWE)); + LLVM_DEBUG(dbgs() << "ExceptionInfo fix: " + << UnwindWE->getEHPad()->getNumber() << "." + << UnwindWE->getEHPad()->getName() + << "'s exception is taken out of " + << WE->getEHPad()->getNumber() << "." + << WE->getEHPad()->getName() << "'s exception\n"); - if (WE->getParentException()) - UnwindDestWE->setParentException(WE->getParentException()); - else - UnwindDestWE->setParentException(nullptr); + UnwindWE->setParentException(WE->getParentException()); + } + } + + // Add BBs to exceptions' block set first + for (auto *DomNode : post_order(&MDT)) { + MachineBasicBlock *MBB = DomNode->getBlock(); + WebAssemblyException *WE = getExceptionFor(MBB); + for (; WE; WE = WE->getParentException()) + WE->addToBlocksSet(MBB); + } + + // After fixing subexception relationship between unwind destinations above, + // there can still be remaining discrepancies. + // + // For example, suppose Exception A is dominated by EHPad A and Exception B is + // dominated by EHPad B. EHPad A's unwind destination is EHPad B, but because + // EHPad B is dominated by EHPad A, the initial grouping makes Exception B a + // subexception of Exception A, and we fix it by taking Exception B out of + // Exception A above. But there can still be remaining BBs within Exception A + // that are reachable from Exception B.
These BBs semantically don't belong + // to Exception A and were not a part of 'catch' clause or cleanup code in the + // original code, but they just happened to be grouped within Exception A + // because they were dominated by EHPad A. We fix this case by taking those + // BBs out of the incorrect exception and all of its subexceptions that they + // belong to. + for (auto &KV : UnwindWEVec) { + WebAssemblyException *WE = KV.first; + WebAssemblyException *UnwindWE = KV.second; + + // Iterate over a copy, because WE->removeFromBlocksSet() below mutates it. + for (auto *MBB : to_vector<16>(WE->getBlocksSet())) { + if (MBB->isEHPad()) { + // If this assertion is triggered, it would be a violation of scoping + // rules in ll files, because this means an instruction in an outer + // scope tries to unwind to an EH pad in an inner scope. + assert(!isReachableAmongDominated(UnwindWE->getEHPad(), MBB, + WE->getEHPad(), MDT) && + "Outer scope unwinds to inner scope. Bug in scope rules?"); + continue; + } + if (isReachableAmongDominated(UnwindWE->getEHPad(), MBB, WE->getEHPad(), + MDT)) { + LLVM_DEBUG(dbgs() << "Remainder BB: " << MBB->getNumber() << "." + << MBB->getName() << " is "); + WebAssemblyException *InnerWE = getExceptionFor(MBB); + while (InnerWE != WE) { + LLVM_DEBUG(dbgs() + << " removed from " << InnerWE->getEHPad()->getNumber() + << "." << InnerWE->getEHPad()->getName() + << "'s exception\n"); + InnerWE->removeFromBlocksSet(MBB); + InnerWE = InnerWE->getParentException(); + } + WE->removeFromBlocksSet(MBB); + LLVM_DEBUG(dbgs() << " removed from " << WE->getEHPad()->getNumber() + << "."
<< WE->getEHPad()->getName() + << "'s exception\n"); + changeExceptionFor(MBB, WE->getParentException()); + if (WE->getParentException()) + WE->getParentException()->addToBlocksSet(MBB); + } + } } } - // Add BBs to exceptions + // Add BBs to exceptions' block vector for (auto DomNode : post_order(&MDT)) { MachineBasicBlock *MBB = DomNode->getBlock(); WebAssemblyException *WE = getExceptionFor(MBB); for (; WE; WE = WE->getParentException()) - WE->addBlock(MBB); + WE->addToBlocksVector(MBB); } SmallVector ExceptionPointers; diff --git a/llvm/test/CodeGen/WebAssembly/cfg-stackify-eh.ll b/llvm/test/CodeGen/WebAssembly/cfg-stackify-eh.ll --- a/llvm/test/CodeGen/WebAssembly/cfg-stackify-eh.ll +++ b/llvm/test/CodeGen/WebAssembly/cfg-stackify-eh.ll @@ -1228,7 +1228,7 @@ br i1 undef, label %if.then, label %if.end12 if.then: ; preds = %entry - invoke void @__cxa_throw() #1 + invoke void @__cxa_throw(i8* null, i8* null, i8* null) #1 to label %unreachable unwind label %catch.dispatch catch.dispatch: ; preds = %if.then @@ -1245,7 +1245,7 @@ to label %invoke.cont unwind label %catch.dispatch4 invoke.cont: ; preds = %catchret.dest - invoke void @__cxa_throw() #1 + invoke void @__cxa_throw(i8* null, i8* null, i8* null) #1 to label %unreachable unwind label %catch.dispatch4 catch.dispatch4: ; preds = %invoke.cont, %catchret.dest @@ -1294,7 +1294,7 @@ %12 = cleanuppad within none [] cleanupret from %12 unwind to caller -unreachable: ; preds = %if.then, %invoke.cont, %rethrow19 +unreachable: ; preds = %rethrow19, %invoke.cont, %if.then unreachable } @@ -1313,7 +1313,7 @@ %1 = catchpad within %0 [i8* bitcast (i8** @_ZTIi to i8*)] %2 = call i8* @llvm.wasm.get.exception(token %1) %3 = call i32 @llvm.wasm.get.ehselector(token %1) - invoke void @__cxa_throw() #1 [ "funclet"(token %1) ] + invoke void @__cxa_throw(i8* null, i8* null, i8* null) #1 [ "funclet"(token %1) ] to label %unreachable unwind label %catch.dispatch2 catch.dispatch2: ; preds = %catch.start @@ -1330,7 +1330,7 @@ to label
%invoke.cont8 unwind label %ehcleanup invoke.cont8: ; preds = %try.cont - invoke void @__cxa_throw() #1 [ "funclet"(token %1) ] + invoke void @__cxa_throw(i8* null, i8* null, i8* null) #1 [ "funclet"(token %1) ] to label %unreachable unwind label %catch.dispatch11 catch.dispatch11: ; preds = %invoke.cont8 @@ -1345,7 +1345,7 @@ invoke.cont: ; preds = %entry unreachable -ehcleanup: ; preds = %try.cont, %catch.dispatch11, %catch.dispatch2 +ehcleanup: ; preds = %catch.dispatch11, %try.cont, %catch.dispatch2 %12 = cleanuppad within %1 [] cleanupret from %12 unwind label %ehcleanup22 @@ -1353,12 +1353,196 @@ %13 = cleanuppad within none [] cleanupret from %13 unwind to caller -unreachable: ; preds = %catch.start, %invoke.cont8 +unreachable: ; preds = %invoke.cont8, %catch.start + unreachable +} + +; void test23() { +; try { +; try { +; throw 0; +; } catch (int) { +; } +; } catch (int) { +; } +; } +; +; Regression test for a WebAssemblyException grouping bug. After catchswitches +; are removed, EH pad catch.start2 is dominated by catch.start, but because +; catch.start2 is the unwind destination of catch.start, it should not be +; included in catch.start's exception. Also, after we take catch.start2's +; exception out of catch.start's exception, we have to take try.cont8 out of +; catch.start's exception, because it has a predecessor in catch.start2.
+define void @test23() personality i8* bitcast (i32 (...)* @__gxx_wasm_personality_v0 to i8*) { +entry: + %exception = call i8* @__cxa_allocate_exception(i32 4) #0 + %0 = bitcast i8* %exception to i32* + store i32 0, i32* %0, align 16 + invoke void @__cxa_throw(i8* %exception, i8* bitcast (i8** @_ZTIi to i8*), i8* null) #1 + to label %unreachable unwind label %catch.dispatch + +catch.dispatch: ; preds = %entry + %1 = catchswitch within none [label %catch.start] unwind label %catch.dispatch1 + +catch.start: ; preds = %catch.dispatch + %2 = catchpad within %1 [i8* bitcast (i8** @_ZTIi to i8*)] + %3 = call i8* @llvm.wasm.get.exception(token %2) + %4 = call i32 @llvm.wasm.get.ehselector(token %2) + %5 = call i32 @llvm.eh.typeid.for(i8* bitcast (i8** @_ZTIi to i8*)) #0 + %matches = icmp eq i32 %4, %5 + br i1 %matches, label %catch, label %rethrow + +catch: ; preds = %catch.start + %6 = call i8* @__cxa_begin_catch(i8* %3) #0 [ "funclet"(token %2) ] + %7 = bitcast i8* %6 to i32* + %8 = load i32, i32* %7, align 4 + call void @__cxa_end_catch() #0 [ "funclet"(token %2) ] + catchret from %2 to label %catchret.dest + +catchret.dest: ; preds = %catch + br label %try.cont + +rethrow: ; preds = %catch.start + invoke void @llvm.wasm.rethrow() #1 [ "funclet"(token %2) ] + to label %unreachable unwind label %catch.dispatch1 + +catch.dispatch1: ; preds = %rethrow, %catch.dispatch + %9 = catchswitch within none [label %catch.start2] unwind to caller + +catch.start2: ; preds = %catch.dispatch1 + %10 = catchpad within %9 [i8* bitcast (i8** @_ZTIi to i8*)] + %11 = call i8* @llvm.wasm.get.exception(token %10) + %12 = call i32 @llvm.wasm.get.ehselector(token %10) + %13 = call i32 @llvm.eh.typeid.for(i8* bitcast (i8** @_ZTIi to i8*)) #0 + %matches3 = icmp eq i32 %12, %13 + br i1 %matches3, label %catch5, label %rethrow4 + +catch5: ; preds = %catch.start2 + %14 = call i8* @__cxa_begin_catch(i8* %11) #0 [ "funclet"(token %10) ] + %15 = bitcast i8* %14 to i32* + %16 = load i32, i32* %15,
align 4 + call void @__cxa_end_catch() #0 [ "funclet"(token %10) ] + catchret from %10 to label %catchret.dest7 + +catchret.dest7: ; preds = %catch5 + br label %try.cont8 + +rethrow4: ; preds = %catch.start2 + call void @llvm.wasm.rethrow() #1 [ "funclet"(token %10) ] + unreachable + +try.cont8: ; preds = %try.cont, %catchret.dest7 + ret void + +try.cont: ; preds = %catchret.dest + br label %try.cont8 + +unreachable: ; preds = %rethrow, %entry + unreachable +} + +; Test for WebAssemblyException grouping. This test is hand-modified to generate +; this structure: +; catch.start dominates catch.start4 and catch.start4 dominates catch.start12, +; so after the dominator-based grouping, we end up with: +; catch.start's exception > catch.start4's exception > catch.start12's exception +; (> here represents subexception relationship) +; +; But the unwind destination chain is catch.start -> catch.start4 -> +; catch.start12. So all these subexception relationships should be undone. +; We have to make sure to take catch.start4's exception out of catch.start's +; exception first, before taking catch.start12's exception out of +; catch.start4's exception; otherwise we end up with an incorrect relationship +; of catch.start's exception > catch.start12's exception.
+define void @test24() personality i8* bitcast (i32 (...)* +@__gxx_wasm_personality_v0 to i8*) { +entry: + invoke void @foo() + to label %invoke.cont unwind label %catch.dispatch + +invoke.cont: ; preds = %entry + invoke void @foo() + to label %invoke.cont1 unwind label %catch.dispatch + +invoke.cont1: ; preds = %invoke.cont + invoke void @foo() + to label %try.cont18 unwind label %catch.dispatch + +catch.dispatch11: ; preds = %rethrow6, %catch.dispatch3 + %0 = catchswitch within none [label %catch.start12] unwind to caller + +catch.start12: ; preds = %catch.dispatch11 + %1 = catchpad within %0 [i8* bitcast (i8** @_ZTIi to i8*)] + %2 = call i8* @llvm.wasm.get.exception(token %1) + %3 = call i32 @llvm.wasm.get.ehselector(token %1) + %4 = call i32 @llvm.eh.typeid.for(i8* bitcast (i8** @_ZTIi to i8*)) #0 + %matches13 = icmp eq i32 %3, %4 + br i1 %matches13, label %catch15, label %rethrow14 + +catch15: ; preds = %catch.start12 + %5 = call i8* @__cxa_begin_catch(i8* %2) #0 [ "funclet"(token %1) ] + %6 = bitcast i8* %5 to i32* + %7 = load i32, i32* %6, align 4 + call void @__cxa_end_catch() #0 [ "funclet"(token %1) ] + catchret from %1 to label %try.cont18 + +rethrow14: ; preds = %catch.start12 + call void @llvm.wasm.rethrow() #1 [ "funclet"(token %1) ] + unreachable + +catch.dispatch3: ; preds = %rethrow, %catch.dispatch + %8 = catchswitch within none [label %catch.start4] unwind label %catch.dispatch11 + +catch.start4: ; preds = %catch.dispatch3 + %9 = catchpad within %8 [i8* bitcast (i8** @_ZTIi to i8*)] + %10 = call i8* @llvm.wasm.get.exception(token %9) + %11 = call i32 @llvm.wasm.get.ehselector(token %9) + %12 = call i32 @llvm.eh.typeid.for(i8* bitcast (i8** @_ZTIi to i8*)) #0 + %matches5 = icmp eq i32 %11, %12 + br i1 %matches5, label %catch7, label %rethrow6 + +catch7: ; preds = %catch.start4 + %13 = call i8* @__cxa_begin_catch(i8* %10) #0 [ "funclet"(token %9) ] + %14 = bitcast i8* %13 to i32* + %15 = load i32, i32* %14, align 4 + call void @__cxa_end_catch() #0
[ "funclet"(token %9) ] + catchret from %9 to label %try.cont18 + +rethrow6: ; preds = %catch.start4 + invoke void @llvm.wasm.rethrow() #1 [ "funclet"(token %9) ] + to label %unreachable unwind label %catch.dispatch11 + +catch.dispatch: ; preds = %invoke.cont1, %invoke.cont, %entry + %16 = catchswitch within none [label %catch.start] unwind label %catch.dispatch3 + +catch.start: ; preds = %catch.dispatch + %17 = catchpad within %16 [i8* bitcast (i8** @_ZTIi to i8*)] + %18 = call i8* @llvm.wasm.get.exception(token %17) + %19 = call i32 @llvm.wasm.get.ehselector(token %17) + %20 = call i32 @llvm.eh.typeid.for(i8* bitcast (i8** @_ZTIi to i8*)) #0 + %matches = icmp eq i32 %19, %20 + br i1 %matches, label %catch, label %rethrow + +catch: ; preds = %catch.start + %21 = call i8* @__cxa_begin_catch(i8* %18) #0 [ "funclet"(token %17) ] + %22 = bitcast i8* %21 to i32* + %23 = load i32, i32* %22, align 4 + call void @__cxa_end_catch() #0 [ "funclet"(token %17) ] + catchret from %17 to label %try.cont18 + +rethrow: ; preds = %catch.start + invoke void @llvm.wasm.rethrow() #1 [ "funclet"(token %17) ] + to label %unreachable unwind label %catch.dispatch3 + +try.cont18: ; preds = %catch, %catch7, %catch15, %invoke.cont1 + ret void + +unreachable: ; preds = %rethrow, %rethrow6 unreachable } ; Check if the unwind destination mismatch stats are correct -; NOSORT: 20 wasm-cfg-stackify - Number of call unwind mismatches found +; NOSORT: 23 wasm-cfg-stackify - Number of call unwind mismatches found ; NOSORT: 4 wasm-cfg-stackify - Number of catch unwind mismatches found declare void @foo() @@ -1385,8 +1569,9 @@ declare i8* @llvm.wasm.get.exception(token) #0 ; Function Attrs: nounwind declare i32 @llvm.wasm.get.ehselector(token) #0 +declare i8* @__cxa_allocate_exception(i32) #0 ; Function Attrs: noreturn -declare void @__cxa_throw() #1 +declare void @__cxa_throw(i8*, i8*, i8*) #1 ; Function Attrs: noreturn declare void @llvm.wasm.rethrow() #1 ; Function Attrs: nounwind