diff --git a/lld/ELF/Driver.cpp b/lld/ELF/Driver.cpp --- a/lld/ELF/Driver.cpp +++ b/lld/ELF/Driver.cpp @@ -1522,7 +1522,7 @@ // Update pointers in input files. parallelForEach(ObjectFiles, [&](InputFile *File) { - std::vector<Symbol *> &Syms = File->getMutableSymbols(); + MutableArrayRef<Symbol *> Syms = File->getMutableSymbols(); for (size_t I = 0, E = Syms.size(); I != E; ++I) if (Symbol *S = Map.lookup(Syms[I])) Syms[I] = S; diff --git a/lld/ELF/InputFiles.h b/lld/ELF/InputFiles.h --- a/lld/ELF/InputFiles.h +++ b/lld/ELF/InputFiles.h @@ -90,7 +90,7 @@ // function on files of other types. ArrayRef<Symbol *> getSymbols() { return getMutableSymbols(); } - std::vector<Symbol *> &getMutableSymbols() { + MutableArrayRef<Symbol *> getMutableSymbols() { assert(FileKind == BinaryKind || FileKind == ObjKind || FileKind == BitcodeKind); return Symbols; diff --git a/lld/include/lld/Common/LLVM.h b/lld/include/lld/Common/LLVM.h --- a/lld/include/lld/Common/LLVM.h +++ b/lld/include/lld/Common/LLVM.h @@ -29,6 +29,7 @@ class MemoryBuffer; class MemoryBufferRef; template <typename T> class ArrayRef; +template <typename T> class MutableArrayRef; template <unsigned InternalLen> class SmallString; template <typename T, unsigned N> class SmallVector; template <typename T> class ErrorOr; @@ -62,6 +63,7 @@ // ADT's.
using llvm::ArrayRef; +using llvm::MutableArrayRef; using llvm::Error; using llvm::ErrorOr; using llvm::Expected; diff --git a/lld/test/wasm/wrap.ll b/lld/test/wasm/wrap.ll new file mode 100644 --- /dev/null +++ b/lld/test/wasm/wrap.ll @@ -0,0 +1,32 @@ +; RUN: llc -filetype=obj %s -o %t.o +; RUN: wasm-ld -wrap nosuchsym -o %t.wasm %t.o +; RUN: wasm-ld -emit-relocs -wrap foo -o %t.wasm %t.o +; RUN: obj2yaml %t.wasm | FileCheck %s + +target triple = "wasm32-unknown-unknown" + +define i32 @foo() { + ret i32 1 +} + +define void @_start() { +entry: + call i32 @foo() + ret void +} + +define i32 @__wrap_foo() { + ret i32 2 +} + +; CHECK: - Type: CODE +; CHECK-NEXT: Relocations: +; CHECK-NEXT: - Type: R_WASM_FUNCTION_INDEX_LEB +; CHECK-NEXT: Index: 1 +; CHECK-NEXT: Offset: 0x00000004 + +; CHECK: FunctionNames: +; CHECK-NEXT: - Index: 0 +; CHECK-NEXT: Name: _start +; CHECK-NEXT: - Index: 1 +; CHECK-NEXT: Name: __wrap_foo diff --git a/lld/wasm/Driver.cpp b/lld/wasm/Driver.cpp --- a/lld/wasm/Driver.cpp +++ b/lld/wasm/Driver.cpp @@ -535,6 +535,83 @@ return Data.str(); } +// The --wrap option is a feature to rename symbols so that you can write +// wrappers for existing functions. If you pass `-wrap=foo`, all +// occurrences of symbol `foo` are resolved to `__wrap_foo` (so, you are +// expected to write `__wrap_foo` function as a wrapper). The original +// symbol becomes accessible as `__real_foo`, so you can call that from your +// wrapper. +// +// This data structure is instantiated for each -wrap option. +struct WrappedSymbol { + Symbol *Sym; + Symbol *Real; + Symbol *Wrap; +}; + +static Symbol *addUndefined(StringRef Name) { + return Symtab->addUndefinedFunction(Name, "", "", 0, nullptr, nullptr); +} + +// Handles -wrap option. +// +// This function instantiates wrapper symbols. At this point, they seem +// like they are not being used at all, so we explicitly set some flags so +// that LTO won't eliminate them.
+static std::vector<WrappedSymbol> addWrappedSymbols(opt::InputArgList &Args) { + std::vector<WrappedSymbol> V; + DenseSet<StringRef> Seen; + + for (auto *Arg : Args.filtered(OPT_wrap)) { + StringRef Name = Arg->getValue(); + if (!Seen.insert(Name).second) + continue; + + Symbol *Sym = Symtab->find(Name); + if (!Sym) + continue; + + Symbol *Real = addUndefined(Saver.save("__real_" + Name)); + Symbol *Wrap = addUndefined(Saver.save("__wrap_" + Name)); + V.push_back({Sym, Real, Wrap}); + + // We want to tell LTO not to inline symbols to be overwritten + // because LTO doesn't know the final symbol contents after renaming. + Real->CanInline = false; + Sym->CanInline = false; + + // Tell LTO not to eliminate these symbols. + Sym->IsUsedInRegularObj = true; + Wrap->IsUsedInRegularObj = true; + } + return V; +} + +// Do renaming for -wrap by updating pointers to symbols. +// +// When this function is executed, only InputFiles and symbol table +// contain pointers to symbol objects. We visit them to replace pointers, +// so that wrapped symbols are swapped as instructed by the command line. +static void wrapSymbols(ArrayRef<WrappedSymbol> Wrapped) { + DenseMap<Symbol *, Symbol *> Map; + for (const WrappedSymbol &W : Wrapped) { + Map[W.Sym] = W.Wrap; + Map[W.Real] = W.Sym; + } + + // Update pointers in input files. + parallelForEach(Symtab->ObjectFiles, [&](InputFile *File) { + MutableArrayRef<Symbol *> Syms = File->getMutableSymbols(); + for (size_t I = 0, E = Syms.size(); I != E; ++I) + if (Symbol *S = Map.lookup(Syms[I])) + Syms[I] = S; + }); + + // Update pointers in the symbol table. + for (const WrappedSymbol &W : Wrapped) + Symtab->wrap(W.Sym, W.Real, W.Wrap); +} + void LinkerDriver::link(ArrayRef<const char *> ArgsArr) { WasmOptTable Parser; opt::InputArgList Args = Parser.parse(ArgsArr.slice(1)); @@ -628,6 +705,9 @@ for (auto *Arg : Args.filtered(OPT_export)) handleUndefined(Arg->getValue()); + // Create wrapped symbols for -wrap option. + std::vector<WrappedSymbol> Wrapped = addWrappedSymbols(Args); + // Do link-time optimization if given files are LLVM bitcode files.
// This compiles bitcode files into real object files. Symtab->addCombinedLTOObject(); @@ -640,6 +720,10 @@ if (errorCount()) return; + // Apply symbol renames for -wrap. + if (!Wrapped.empty()) + wrapSymbols(Wrapped); + for (auto *Arg : Args.filtered(OPT_export)) { Symbol *Sym = Symtab->find(Arg->getValue()); if (Sym && Sym->isDefined()) diff --git a/lld/wasm/InputFiles.h b/lld/wasm/InputFiles.h --- a/lld/wasm/InputFiles.h +++ b/lld/wasm/InputFiles.h @@ -61,6 +61,8 @@ ArrayRef getSymbols() const { return Symbols; } + MutableArrayRef getMutableSymbols() { return Symbols; } + protected: InputFile(Kind K, MemoryBufferRef M) : MB(M), FileKind(K) {} MemoryBufferRef MB; diff --git a/lld/wasm/LTO.cpp b/lld/wasm/LTO.cpp --- a/lld/wasm/LTO.cpp +++ b/lld/wasm/LTO.cpp @@ -108,6 +108,11 @@ (R.Prevailing && Sym->isExported()); if (R.Prevailing) undefine(Sym); + + // We tell LTO to not apply interprocedural optimization for wrapped + // (with --wrap) symbols because otherwise LTO would inline them while + // their values are still not final. + R.LinkerRedefined = !Sym->CanInline; } checkError(LTOObj->add(std::move(F.Obj), Resols)); } diff --git a/lld/wasm/Options.td b/lld/wasm/Options.td --- a/lld/wasm/Options.td +++ b/lld/wasm/Options.td @@ -112,6 +112,9 @@ def z: JoinedOrSeparate<["-"], "z">, MetaVarName<"