diff --git a/llvm/lib/Target/WebAssembly/CMakeLists.txt b/llvm/lib/Target/WebAssembly/CMakeLists.txt --- a/llvm/lib/Target/WebAssembly/CMakeLists.txt +++ b/llvm/lib/Target/WebAssembly/CMakeLists.txt @@ -48,6 +48,7 @@ WebAssemblyRuntimeLibcallSignatures.cpp WebAssemblySelectionDAGInfo.cpp WebAssemblySetP2AlignOperands.cpp + WebAssemblySortRegion.cpp WebAssemblyMemIntrinsicResults.cpp WebAssemblySubtarget.cpp WebAssemblyTargetMachine.cpp diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyCFGSort.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyCFGSort.cpp --- a/llvm/lib/Target/WebAssembly/WebAssemblyCFGSort.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyCFGSort.cpp @@ -19,6 +19,7 @@ #include "MCTargetDesc/WebAssemblyMCTargetDesc.h" #include "WebAssembly.h" #include "WebAssemblyExceptionInfo.h" +#include "WebAssemblySortRegion.h" #include "WebAssemblySubtarget.h" #include "WebAssemblyUtilities.h" #include "llvm/ADT/PriorityQueue.h" @@ -31,6 +32,8 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" using namespace llvm; +using WebAssembly::SortRegion; +using WebAssembly::SortRegionInfo; #define DEBUG_TYPE "wasm-cfg-sort" @@ -44,78 +47,6 @@ namespace { -// Wrapper for loops and exceptions -class Region { -public: - virtual ~Region() = default; - virtual MachineBasicBlock *getHeader() const = 0; - virtual bool contains(const MachineBasicBlock *MBB) const = 0; - virtual unsigned getNumBlocks() const = 0; - using block_iterator = typename ArrayRef::const_iterator; - virtual iterator_range blocks() const = 0; - virtual bool isLoop() const = 0; -}; - -template class ConcreteRegion : public Region { - const T *Region; - -public: - ConcreteRegion(const T *Region) : Region(Region) {} - MachineBasicBlock *getHeader() const override { return Region->getHeader(); } - bool contains(const MachineBasicBlock *MBB) const override { - return Region->contains(MBB); - } - unsigned getNumBlocks() const override { return Region->getNumBlocks(); } - iterator_range 
blocks() const override { - return Region->blocks(); - } - bool isLoop() const override { return false; } -}; - -template <> bool ConcreteRegion::isLoop() const { return true; } - -// This class has information of nested Regions; this is analogous to what -// LoopInfo is for loops. -class RegionInfo { - const MachineLoopInfo &MLI; - const WebAssemblyExceptionInfo &WEI; - DenseMap> LoopMap; - DenseMap> ExceptionMap; - -public: - RegionInfo(const MachineLoopInfo &MLI, const WebAssemblyExceptionInfo &WEI) - : MLI(MLI), WEI(WEI) {} - - // Returns a smallest loop or exception that contains MBB - const Region *getRegionFor(const MachineBasicBlock *MBB) { - const auto *ML = MLI.getLoopFor(MBB); - const auto *WE = WEI.getExceptionFor(MBB); - if (!ML && !WE) - return nullptr; - // We determine subregion relationship by domination of their headers, i.e., - // if region A's header dominates region B's header, B is a subregion of A. - // WebAssemblyException contains BBs in all its subregions (loops or - // exceptions), but MachineLoop may not, because MachineLoop does not contain - // BBs that don't have a path to its header even if they are dominated by - // its header. So here we should use WE->contains(ML->getHeader()), but not - // ML->contains(WE->getHeader()). 
- if ((ML && !WE) || (ML && WE && WE->contains(ML->getHeader()))) { - // If the smallest region containing MBB is a loop - if (LoopMap.count(ML)) - return LoopMap[ML].get(); - LoopMap[ML] = std::make_unique>(ML); - return LoopMap[ML].get(); - } else { - // If the smallest region containing MBB is an exception - if (ExceptionMap.count(WE)) - return ExceptionMap[WE].get(); - ExceptionMap[WE] = - std::make_unique>(WE); - return ExceptionMap[WE].get(); - } - } -}; - class WebAssemblyCFGSort final : public MachineFunctionPass { StringRef getPassName() const override { return "WebAssembly CFG Sort"; } @@ -236,14 +167,14 @@ /// Bookkeeping for a region to help ensure that we don't mix blocks not /// dominated by the its header among its blocks. struct Entry { - const Region *TheRegion; + const SortRegion *TheRegion; unsigned NumBlocksLeft; /// List of blocks not dominated by Loop's header that are deferred until /// after all of Loop's blocks have been seen. std::vector Deferred; - explicit Entry(const class Region *R) + explicit Entry(const SortRegion *R) : TheRegion(R), NumBlocksLeft(R->getNumBlocks()) {} }; } // end anonymous namespace @@ -287,10 +218,10 @@ CompareBlockNumbersBackwards> Ready; - RegionInfo RI(MLI, WEI); + SortRegionInfo SRI(MLI, WEI); SmallVector Entries; for (MachineBasicBlock *MBB = &MF.front();;) { - const Region *R = RI.getRegionFor(MBB); + const SortRegion *R = SRI.getRegionFor(MBB); if (R) { // If MBB is a region header, add it to the active region list. 
We can't // put any blocks that it doesn't dominate until we see the end of the @@ -373,7 +304,7 @@ MF.RenumberBlocks(); #ifndef NDEBUG - SmallSetVector OnStack; + SmallSetVector OnStack; // Insert a sentinel representing the degenerate loop that starts at the // function entry block and includes the entire function as a "loop" that @@ -382,7 +313,7 @@ for (auto &MBB : MF) { assert(MBB.getNumber() >= 0 && "Renumbered blocks should be non-negative."); - const Region *Region = RI.getRegionFor(&MBB); + const SortRegion *Region = SRI.getRegionFor(&MBB); if (Region && &MBB == Region->getHeader()) { // Region header. @@ -408,10 +339,10 @@ for (auto Pred : MBB.predecessors()) assert(Pred->getNumber() < MBB.getNumber() && "Non-loop-header predecessors should be topologically sorted"); - assert(OnStack.count(RI.getRegionFor(&MBB)) && + assert(OnStack.count(SRI.getRegionFor(&MBB)) && "Blocks must be nested in their regions"); } - while (OnStack.size() > 1 && &MBB == WebAssembly::getBottom(OnStack.back())) + while (OnStack.size() > 1 && &MBB == SRI.getBottom(OnStack.back())) OnStack.pop_back(); } assert(OnStack.pop_back_val() == nullptr && diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyCFGStackify.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyCFGStackify.cpp --- a/llvm/lib/Target/WebAssembly/WebAssemblyCFGStackify.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyCFGStackify.cpp @@ -24,6 +24,7 @@ #include "WebAssembly.h" #include "WebAssemblyExceptionInfo.h" #include "WebAssemblyMachineFunctionInfo.h" +#include "WebAssemblySortRegion.h" #include "WebAssemblySubtarget.h" #include "WebAssemblyUtilities.h" #include "llvm/ADT/Statistic.h" @@ -33,6 +34,7 @@ #include "llvm/MC/MCAsmInfo.h" #include "llvm/Target/TargetMachine.h" using namespace llvm; +using WebAssembly::SortRegionInfo; #define DEBUG_TYPE "wasm-cfg-stackify" @@ -382,6 +384,8 @@ void WebAssemblyCFGStackify::placeLoopMarker(MachineBasicBlock &MBB) { MachineFunction &MF = *MBB.getParent(); const auto &MLI = 
getAnalysis(); + const auto &WEI = getAnalysis(); + SortRegionInfo SRI(MLI, WEI); const auto &TII = *MF.getSubtarget().getInstrInfo(); MachineLoop *Loop = MLI.getLoopFor(&MBB); @@ -390,7 +394,7 @@ // The operand of a LOOP is the first block after the loop. If the loop is the // bottom of the function, insert a dummy block at the end. - MachineBasicBlock *Bottom = WebAssembly::getBottom(Loop); + MachineBasicBlock *Bottom = SRI.getBottom(Loop); auto Iter = std::next(Bottom->getIterator()); if (Iter == MF.end()) { getAppendixBlock(MF); @@ -450,7 +454,9 @@ MachineFunction &MF = *MBB.getParent(); auto &MDT = getAnalysis(); const auto &TII = *MF.getSubtarget().getInstrInfo(); + const auto &MLI = getAnalysis(); const auto &WEI = getAnalysis(); + SortRegionInfo SRI(MLI, WEI); const auto &MFI = *MF.getInfo(); // Compute the nearest common dominator of all unwind predecessors @@ -470,7 +476,7 @@ // end. WebAssemblyException *WE = WEI.getExceptionFor(&MBB); assert(WE); - MachineBasicBlock *Bottom = WebAssembly::getBottom(WE); + MachineBasicBlock *Bottom = SRI.getBottom(WE); auto Iter = std::next(Bottom->getIterator()); if (Iter == MF.end()) { diff --git a/llvm/lib/Target/WebAssembly/WebAssemblySortRegion.h b/llvm/lib/Target/WebAssembly/WebAssemblySortRegion.h new file mode 100644 --- /dev/null +++ b/llvm/lib/Target/WebAssembly/WebAssemblySortRegion.h @@ -0,0 +1,91 @@ +//===-- WebAssemblySortRegion.h - WebAssembly Sort Region --------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// \brief This file implements regions used in CFGSort and CFGStackify.
+/// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIB_TARGET_WEBASSEMBLY_WEBASSEMBLYSORTREGION_H +#define LLVM_LIB_TARGET_WEBASSEMBLY_WEBASSEMBLYSORTREGION_H + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/iterator_range.h" + +namespace llvm { + +class MachineBasicBlock; +class MachineLoop; +class MachineLoopInfo; +class WebAssemblyException; +class WebAssemblyExceptionInfo; + +namespace WebAssembly { + +// Abstract, read-only view of a region: either a loop or an exception. +class SortRegion { +public: + virtual ~SortRegion() = default; + virtual MachineBasicBlock *getHeader() const = 0; + virtual bool contains(const MachineBasicBlock *MBB) const = 0; + virtual unsigned getNumBlocks() const = 0; + using block_iterator = typename ArrayRef<MachineBasicBlock *>::const_iterator; + virtual iterator_range<block_iterator> blocks() const = 0; + virtual bool isLoop() const = 0; +}; + +template <typename T> class ConcreteSortRegion : public SortRegion { + const T *Unit; + +public: + explicit ConcreteSortRegion(const T *Unit) : Unit(Unit) {} + MachineBasicBlock *getHeader() const override { return Unit->getHeader(); } + bool contains(const MachineBasicBlock *MBB) const override { + return Unit->contains(MBB); + } + unsigned getNumBlocks() const override { return Unit->getNumBlocks(); } + iterator_range<block_iterator> blocks() const override { + return Unit->blocks(); + } + bool isLoop() const override { return false; } +}; + +// This class holds information about nested SortRegions; this is analogous to +// what LoopInfo is for loops.
+class SortRegionInfo { + friend class ConcreteSortRegion<MachineLoop>; + friend class ConcreteSortRegion<WebAssemblyException>; + + const MachineLoopInfo &MLI; + const WebAssemblyExceptionInfo &WEI; + DenseMap<const MachineLoop *, std::unique_ptr<SortRegion>> LoopMap; + DenseMap<const WebAssemblyException *, std::unique_ptr<SortRegion>> + ExceptionMap; + +public: + SortRegionInfo(const MachineLoopInfo &MLI, + const WebAssemblyExceptionInfo &WEI) + : MLI(MLI), WEI(WEI) {} + + // Returns the smallest loop or exception that contains MBB, or null if none + const SortRegion *getRegionFor(const MachineBasicBlock *MBB); + + // Return the "bottom" block among all blocks dominated by the region + // (MachineLoop or WebAssemblyException) header. This works when the entity is + // discontiguous. + MachineBasicBlock *getBottom(const SortRegion *R); + MachineBasicBlock *getBottom(const MachineLoop *ML); + MachineBasicBlock *getBottom(const WebAssemblyException *WE); +}; + +} // end namespace WebAssembly + +} // end namespace llvm + +#endif diff --git a/llvm/lib/Target/WebAssembly/WebAssemblySortRegion.cpp b/llvm/lib/Target/WebAssembly/WebAssemblySortRegion.cpp new file mode 100644 --- /dev/null +++ b/llvm/lib/Target/WebAssembly/WebAssemblySortRegion.cpp @@ -0,0 +1,79 @@ +#include "WebAssemblySortRegion.h" +#include "WebAssemblyExceptionInfo.h" +#include "llvm/CodeGen/MachineLoopInfo.h" + +using namespace llvm; +using namespace WebAssembly; + +namespace llvm { +namespace WebAssembly { +// Specialize inside the template's own namespaces instead of relying on the +// using-directives above; some older GCC releases reject the latter form. +template <> bool ConcreteSortRegion<MachineLoop>::isLoop() const { + return true; +} +} // end namespace WebAssembly +} // end namespace llvm + +const SortRegion *SortRegionInfo::getRegionFor(const MachineBasicBlock *MBB) { + const auto *ML = MLI.getLoopFor(MBB); + const auto *WE = WEI.getExceptionFor(MBB); + if (!ML && !WE) + return nullptr; + // We determine subregion relationship by domination of their headers, i.e., + // if region A's header dominates region B's header, B is a subregion of A. + // WebAssemblyException contains BBs in all its subregions (loops or + // exceptions), but MachineLoop may not, because MachineLoop does not + // contain BBs that don't have a path to its header even if they are + // dominated by its header.
So here we should use + // WE->contains(ML->getHeader()), but not ML->contains(WE->getHeader()). + if (ML && (!WE || WE->contains(ML->getHeader()))) { + // The smallest region containing MBB is a loop; wrap it once and cache it. + std::unique_ptr<SortRegion> &LoopRegion = LoopMap[ML]; + if (!LoopRegion) + LoopRegion = std::make_unique<ConcreteSortRegion<MachineLoop>>(ML); + return LoopRegion.get(); + } else { + // The smallest region containing MBB is an exception; wrap and cache it. + std::unique_ptr<SortRegion> &ExceptionRegion = ExceptionMap[WE]; + if (!ExceptionRegion) + ExceptionRegion = + std::make_unique<ConcreteSortRegion<WebAssemblyException>>(WE); + return ExceptionRegion.get(); + } +} + +MachineBasicBlock *SortRegionInfo::getBottom(const SortRegion *R) { + if (R->isLoop()) + return getBottom(MLI.getLoopFor(R->getHeader())); + else + return getBottom(WEI.getExceptionFor(R->getHeader())); +} + +MachineBasicBlock *SortRegionInfo::getBottom(const MachineLoop *ML) { + MachineBasicBlock *Bottom = ML->getHeader(); + for (MachineBasicBlock *MBB : ML->blocks()) { + if (MBB->getNumber() > Bottom->getNumber()) + Bottom = MBB; + // MachineLoop does not contain all BBs dominated by its header. BBs that + // don't have a path back to the loop header aren't included. But for the + // purpose of CFG sorting and stackification, we need a bottom BB among all + // BBs that are dominated by the loop header. So we check if there is any + // WebAssemblyException contained in this loop, and compute the bottommost + // BB of them all.
+ if (MBB->isEHPad()) { + MachineBasicBlock *ExBottom = getBottom(WEI.getExceptionFor(MBB)); + if (ExBottom->getNumber() > Bottom->getNumber()) + Bottom = ExBottom; + } + } + return Bottom; +} + +MachineBasicBlock *SortRegionInfo::getBottom(const WebAssemblyException *WE) { + MachineBasicBlock *Bottom = WE->getHeader(); + for (MachineBasicBlock *MBB : WE->blocks()) + if (MBB->getNumber() > Bottom->getNumber()) + Bottom = MBB; + return Bottom; +} diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyUtilities.h b/llvm/lib/Target/WebAssembly/WebAssemblyUtilities.h --- a/llvm/lib/Target/WebAssembly/WebAssemblyUtilities.h +++ b/llvm/lib/Target/WebAssembly/WebAssemblyUtilities.h @@ -16,6 +16,7 @@ #define LLVM_LIB_TARGET_WEBASSEMBLY_WEBASSEMBLYUTILITIES_H #include "llvm/CodeGen/MachineBasicBlock.h" +#include "llvm/CodeGen/MachineFunction.h" namespace llvm { @@ -33,17 +34,6 @@ extern const char *const StdTerminateFn; extern const char *const PersonalityWrapperFn; -/// Return the "bottom" block of an entity, which can be either a MachineLoop or -/// WebAssemblyException. This differs from MachineLoop::getBottomBlock in that -/// it works even if the entity is discontiguous. -template MachineBasicBlock *getBottom(const T *Unit) { - MachineBasicBlock *Bottom = Unit->getHeader(); - for (MachineBasicBlock *MBB : Unit->blocks()) - if (MBB->getNumber() > Bottom->getNumber()) - Bottom = MBB; - return Bottom; -} - /// Returns the operand number of a callee, assuming the argument is a call /// instruction. const MachineOperand &getCalleeOp(const MachineInstr &MI); diff --git a/llvm/test/CodeGen/WebAssembly/cfg-stackify-eh.ll b/llvm/test/CodeGen/WebAssembly/cfg-stackify-eh.ll --- a/llvm/test/CodeGen/WebAssembly/cfg-stackify-eh.ll +++ b/llvm/test/CodeGen/WebAssembly/cfg-stackify-eh.ll @@ -975,6 +975,54 @@ ret void } +; Here an exception is semantically contained in a loop. 
'ehcleanup' BB belongs +; to the exception, but does not belong to the loop (because it does not have a +; path back to the loop header), and is placed after the loop latch block +; 'invoke.cont' intentionally. This tests if 'end_loop' marker is placed +; correctly not right after 'invoke.cont' part but after 'ehcleanup' part. +; NOSORT-LABEL: test18 +; NOSORT: loop +; NOSORT: try +; NOSORT: end_try +; NOSORT: end_loop +define void @test18(i32 %n) personality i8* bitcast (i32 (...)* @__gxx_wasm_personality_v0 to i8*) { +entry: + br label %while.cond + +while.cond: ; preds = %invoke.cont, %entry + %n.addr.0 = phi i32 [ %n, %entry ], [ %dec, %invoke.cont ] + %tobool = icmp ne i32 %n.addr.0, 0 + br i1 %tobool, label %while.body, label %while.end + +while.body: ; preds = %while.cond + %dec = add nsw i32 %n.addr.0, -1 + invoke void @foo() + to label %while.end unwind label %catch.dispatch + +catch.dispatch: ; preds = %while.body + %0 = catchswitch within none [label %catch.start] unwind to caller + +catch.start: ; preds = %catch.dispatch + %1 = catchpad within %0 [i8* null] + %2 = call i8* @llvm.wasm.get.exception(token %1) + %3 = call i32 @llvm.wasm.get.ehselector(token %1) + %4 = call i8* @__cxa_begin_catch(i8* %2) [ "funclet"(token %1) ] + invoke void @__cxa_end_catch() [ "funclet"(token %1) ] + to label %invoke.cont unwind label %ehcleanup + +invoke.cont: ; preds = %catch.start + catchret from %1 to label %while.cond + +ehcleanup: ; preds = %catch.start + %5 = cleanuppad within %1 [] + %6 = call i8* @llvm.wasm.get.exception(token %5) + call void @__clang_call_terminate(i8* %6) [ "funclet"(token %5) ] + unreachable + +while.end: ; preds = %while.body, %while.cond + ret void +} + ; Check if the unwind destination mismatch stats are correct ; NOSORT-STAT: 17 wasm-cfg-stackify - Number of EH pad unwind mismatches found