diff --git a/clang/lib/CodeGen/CGCXX.cpp b/clang/lib/CodeGen/CGCXX.cpp
--- a/clang/lib/CodeGen/CGCXX.cpp
+++ b/clang/lib/CodeGen/CGCXX.cpp
@@ -40,6 +40,11 @@
   if (getCodeGenOpts().OptimizationLevel == 0)
     return true;
 
+  // Disable this optimization for ARM64EC. FIXME: This probably should work,
+  // but getting the symbol table correct is complicated.
+  if (getTarget().getTriple().isWindowsArm64EC())
+    return true;
+
   // If sanitizing memory to check for use-after-dtor, do not emit as
   // an alias, unless this class owns no members.
   if (getCodeGenOpts().SanitizeMemoryUseAfterDtor &&
diff --git a/llvm/include/llvm/IR/CallingConv.h b/llvm/include/llvm/IR/CallingConv.h
--- a/llvm/include/llvm/IR/CallingConv.h
+++ b/llvm/include/llvm/IR/CallingConv.h
@@ -245,6 +245,16 @@
   /// placement. Preserves active lane values for input VGPRs.
   AMDGPU_CS_ChainPreserve = 105,
 
+  /// Calling convention used in the ARM64EC ABI to implement calls between
+  /// x64 code and thunks. This is basically the x64 calling convention using
+  /// ARM64 register names. The first parameter is mapped to x9.
+  ARM64EC_Thunk_X64 = 106,
+
+  /// Calling convention used in the ARM64EC ABI to implement calls between
+  /// ARM64 code and thunks. This is just the ARM64 calling convention,
+  /// except that the first parameter is mapped to x9.
+  ARM64EC_Thunk_Native = 107,
+
   /// The highest possible ID. Must be some 2^k - 1.
   MaxID = 1023
 };
diff --git a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
--- a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
+++ b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
@@ -2714,6 +2714,36 @@
       GV->hasAvailableExternallyLinkage())
     return true;
 
+  if (GV->getName() == "llvm.arm64ec.symbolmap") {
+    OutStreamer->switchSection(OutContext.getCOFFSection(
+        ".hybmp$x", COFF::IMAGE_SCN_LNK_INFO, SectionKind::getMetadata()));
+    auto *Arr = cast<ConstantArray>(GV->getInitializer());
+    for (auto &U : Arr->operands()) {
+      auto *C = cast<ConstantStruct>(U);
+      auto *Src = cast<GlobalValue>(C->getOperand(0)->stripPointerCasts());
+      auto *Dst = cast<GlobalValue>(C->getOperand(1)->stripPointerCasts());
+      int Kind = cast<ConstantInt>(C->getOperand(2))->getZExtValue();
+
+      if (Src->hasDLLImportStorageClass()) {
+        // For now, we assume dllimport functions aren't directly called.
+        // (We might change this later to match MSVC.)
+        OutStreamer->emitCOFFSymbolIndex(
+            OutContext.getOrCreateSymbol("__imp_" + Src->getName()));
+        OutStreamer->emitCOFFSymbolIndex(getSymbol(Dst));
+        OutStreamer->emitInt32(Kind);
+      } else {
+        // FIXME: For non-dllimport functions, MSVC emits the same entry
+        // twice, for reasons I don't understand. I have to assume the linker
+        // ignores the redundant entry; there aren't any reasonable semantics
+        // to attach to it.
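+        // Emit a single entry mapping the symbol to its thunk.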
+        OutStreamer->emitCOFFSymbolIndex(getSymbol(Src));
+        OutStreamer->emitCOFFSymbolIndex(getSymbol(Dst));
+        OutStreamer->emitInt32(Kind);
+      }
+    }
+    return true;
+  }
+
   if (!GV->hasAppendingLinkage())
     return false;
 
   assert(GV->hasInitializer() && "Not a special LLVM global!");
diff --git a/llvm/lib/Target/AArch64/AArch64.h b/llvm/lib/Target/AArch64/AArch64.h
--- a/llvm/lib/Target/AArch64/AArch64.h
+++ b/llvm/lib/Target/AArch64/AArch64.h
@@ -71,6 +71,7 @@
 FunctionPass *createAArch64StackTaggingPass(bool IsOptNone);
 FunctionPass *createAArch64StackTaggingPreRAPass();
 ModulePass *createAArch64GlobalsTaggingPass();
+ModulePass *createAArch64Arm64ECCallLoweringPass();
 
 void initializeAArch64A53Fix835769Pass(PassRegistry&);
 void initializeAArch64A57FPLoadBalancingPass(PassRegistry&);
@@ -108,6 +109,7 @@
 void initializeLDTLSCleanupPass(PassRegistry&);
 void initializeSMEABIPass(PassRegistry &);
 void initializeSVEIntrinsicOptsPass(PassRegistry &);
+void initializeAArch64Arm64ECCallLoweringPass(PassRegistry &);
 } // end namespace llvm
 
 #endif
diff --git a/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp b/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp
@@ -0,0 +1,752 @@
+//===-- AArch64Arm64ECCallLowering.cpp - Lower Arm64EC calls ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file contains the IR transform to lower external or indirect calls for
+/// the ARM64EC calling convention. Such calls must go through the runtime, so
+/// we can translate the calling convention for calls into the emulator.
+///
+/// This subsumes Control Flow Guard handling.
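+/// (On ARM64EC targets this pass runs in place of the generic Control Flow
+/// Guard pass, so the indirect-call checks double as the call-translation
+/// points.)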
+///
+//===----------------------------------------------------------------------===//
+
+#include "AArch64.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/IR/CallingConv.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/TargetParser/Triple.h"
+
+using namespace llvm;
+
+using OperandBundleDef = OperandBundleDefT<Value *>;
+
+#define DEBUG_TYPE "arm64eccalllowering"
+
+STATISTIC(Arm64ECCallsLowered, "Number of Arm64EC calls lowered");
+
+static cl::opt<bool> LowerDirectToIndirect("arm64ec-lower-direct-to-indirect",
+                                           cl::Hidden, cl::init(true));
+static cl::opt<bool> GenerateThunks("arm64ec-generate-thunks", cl::Hidden,
+                                    cl::init(true));
+
+namespace {
+
+class AArch64Arm64ECCallLowering : public ModulePass {
+public:
+  static char ID;
+  AArch64Arm64ECCallLowering() : ModulePass(ID) {
+    initializeAArch64Arm64ECCallLoweringPass(*PassRegistry::getPassRegistry());
+  }
+
+  Function *buildExitThunk(FunctionType *FnTy, AttributeList Attrs);
+  Function *buildEntryThunk(Function *F);
+  void lowerCall(CallBase *CB);
+  Function *buildGuestExitThunk(Function *F);
+  bool processFunction(Function &F, SetVector<Function *> &DirectCalledFns);
+  bool runOnModule(Module &M) override;
+
+private:
+  int cfguard_module_flag = 0;
+  FunctionType *GuardFnType = nullptr;
+  PointerType *GuardFnPtrType = nullptr;
+  Constant *GuardFnCFGlobal = nullptr;
+  Constant *GuardFnGlobal = nullptr;
+  Module *M = nullptr;
+
+  Type *I8PtrTy;
+  Type *I64Ty;
+  Type *VoidTy;
+
+  void getThunkType(FunctionType *FT, AttributeList AttrList, bool EntryThunk,
+                    raw_ostream &Out, FunctionType *&Arm64Ty,
+                    FunctionType *&X64Ty);
+  void getThunkRetType(FunctionType *FT, AttributeList AttrList,
+                       raw_ostream &Out, Type *&Arm64RetTy, Type *&X64RetTy,
+                       SmallVectorImpl<Type *> &Arm64ArgTypes,
+                       SmallVectorImpl<Type *> &X64ArgTypes);
+  void getThunkArgTypes(FunctionType *FT, AttributeList AttrList,
+                        raw_ostream &Out,
+                        SmallVectorImpl<Type *> &Arm64ArgTypes,
+                        SmallVectorImpl<Type *> &X64ArgTypes);
+  void canonicalizeThunkType(Type *T, Align Alignment, bool Ret,
+                             uint64_t ArgSizeBytes, raw_ostream &Out,
+                             Type *&Arm64Ty, Type *&X64Ty);
+};
+
+} // end anonymous namespace
+
+void AArch64Arm64ECCallLowering::getThunkType(FunctionType *FT,
+                                              AttributeList AttrList,
+                                              bool EntryThunk, raw_ostream &Out,
+                                              FunctionType *&Arm64Ty,
+                                              FunctionType *&X64Ty) {
+  Out << (EntryThunk ? "$ientry_thunk$cdecl$" : "$iexit_thunk$cdecl$");
+
+  Type *Arm64RetTy;
+  Type *X64RetTy;
+
+  SmallVector<Type *> Arm64ArgTypes;
+  SmallVector<Type *> X64ArgTypes;
+
+  // The first argument to a thunk is the called function, stored in x9.
+  // For exit thunks, we pass the called function down to the emulator;
+  // for entry thunks, we just call the Arm64 function directly.
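+  // For example, the exit thunk for "void f(int)" is mangled as
+  // "$iexit_thunk$cdecl$v$i8", and its Arm64-side type is void(ptr, i64):
+  // the callee in x9 followed by the integer argument widened to i64.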
+  if (!EntryThunk)
+    Arm64ArgTypes.push_back(I8PtrTy);
+  X64ArgTypes.push_back(I8PtrTy);
+
+  getThunkRetType(FT, AttrList, Out, Arm64RetTy, X64RetTy, Arm64ArgTypes,
+                  X64ArgTypes);
+
+  getThunkArgTypes(FT, AttrList, Out, Arm64ArgTypes, X64ArgTypes);
+
+  Arm64Ty = FunctionType::get(Arm64RetTy, Arm64ArgTypes, false);
+  X64Ty = FunctionType::get(X64RetTy, X64ArgTypes, false);
+}
+
+void AArch64Arm64ECCallLowering::getThunkArgTypes(
+    FunctionType *FT, AttributeList AttrList, raw_ostream &Out,
+    SmallVectorImpl<Type *> &Arm64ArgTypes,
+    SmallVectorImpl<Type *> &X64ArgTypes) {
+  bool HasSretPtr = Arm64ArgTypes.size() > 1;
+
+  Out << "$";
+  if (FT->isVarArg()) {
+    // We treat a variadic function's thunk as a normal function
+    // with the following type on the ARM side:
+    //   rettype exitthunk(
+    //     ptr x9, ptr x0, i64 x1, i64 x2, i64 x3, ptr x4, i64 x5)
+    //
+    // which can cover any variadic signature.
+    // x9 holds the called function, as in a normal exit thunk.
+    // x0-x3 are the arguments passed in registers.
+    // x4 is the address of the arguments on the stack.
+    // x5 is the size of the arguments on the stack.
+    //
+    // On the x64 side, it's the same except that x5 isn't set.
+    //
+    // If both the ARM and X64 sides are sret, there are only three
+    // arguments in registers.
+    //
+    // If the X64 side is sret, but the ARM side isn't, we pass an extra value
+    // to/from the X64 side, and let SelectionDAG transform it into a memory
+    // location.
+    Out << "varargs";
+
+    // x0-x3
+    for (int i = HasSretPtr ? 1 : 0; i < 4; i++) {
+      Arm64ArgTypes.push_back(I64Ty);
+      X64ArgTypes.push_back(I64Ty);
+    }
+
+    // x4
+    Arm64ArgTypes.push_back(I8PtrTy);
+    X64ArgTypes.push_back(I8PtrTy);
+    // x5
+    Arm64ArgTypes.push_back(I64Ty);
+    // FIXME: x5 isn't actually passed/used by the x64 side; revisit once we
+    // have proper isel for varargs
+    X64ArgTypes.push_back(I64Ty);
+    return;
+  }
+
+  unsigned I = 0;
+  if (HasSretPtr)
+    I++;
+
+  if (I == FT->getNumParams()) {
+    Out << "v";
+    return;
+  }
+
+  for (unsigned E = FT->getNumParams(); I != E; ++I) {
+    Align ParamAlign = AttrList.getParamAlignment(I).valueOrOne();
+#if 0
+    // FIXME: Need more information about argument size; see
+    // https://reviews.llvm.org/D132926
+    uint64_t ArgSizeBytes = AttrList.getParamArm64ECArgSizeBytes(I);
+#else
+    uint64_t ArgSizeBytes = 0;
+#endif
+    Type *Arm64Ty, *X64Ty;
+    canonicalizeThunkType(FT->getParamType(I), ParamAlign,
+                          /*Ret*/ false, ArgSizeBytes, Out, Arm64Ty, X64Ty);
+    Arm64ArgTypes.push_back(Arm64Ty);
+    X64ArgTypes.push_back(X64Ty);
+  }
+}
+
+void AArch64Arm64ECCallLowering::getThunkRetType(
+    FunctionType *FT, AttributeList AttrList, raw_ostream &Out,
+    Type *&Arm64RetTy, Type *&X64RetTy, SmallVectorImpl<Type *> &Arm64ArgTypes,
+    SmallVectorImpl<Type *> &X64ArgTypes) {
+  Type *T = FT->getReturnType();
+#if 0
+  // FIXME: Need more information about argument size; see
+  // https://reviews.llvm.org/D132926
+  uint64_t ArgSizeBytes = AttrList.getRetArm64ECArgSizeBytes();
+#else
+  uint64_t ArgSizeBytes = 0;
+#endif
+  if (T->isVoidTy()) {
+    if (FT->getNumParams()) {
+      auto SRetAttr = AttrList.getParamAttr(0, Attribute::StructRet);
+      auto InRegAttr = AttrList.getParamAttr(0, Attribute::InReg);
+      if (SRetAttr.isValid() && InRegAttr.isValid()) {
+        // sret+inreg indicates a call that returns a C++ class value. This is
+        // actually equivalent to just passing and returning a void* pointer
+        // as the first argument. Translate it that way, instead of trying
+        // to model "inreg" in the thunk's calling convention, to simplify
+        // the rest of the code.
+        Out << "i8";
+        Arm64RetTy = I64Ty;
+        X64RetTy = I64Ty;
+        return;
+      }
+      if (SRetAttr.isValid()) {
+        Type *SRetType = SRetAttr.getValueAsType();
+        Align SRetAlign = AttrList.getParamAlignment(0).valueOrOne();
+        Type *Arm64Ty, *X64Ty;
+        canonicalizeThunkType(SRetType, SRetAlign, /*Ret*/ true, ArgSizeBytes,
+                              Out, Arm64Ty, X64Ty);
+        Arm64RetTy = VoidTy;
+        X64RetTy = VoidTy;
+        Arm64ArgTypes.push_back(FT->getParamType(0));
+        X64ArgTypes.push_back(FT->getParamType(0));
+        return;
+      }
+    }
+
+    Out << "v";
+    Arm64RetTy = VoidTy;
+    X64RetTy = VoidTy;
+    return;
+  }
+
+  canonicalizeThunkType(T, Align(), /*Ret*/ true, ArgSizeBytes, Out,
+                        Arm64RetTy, X64RetTy);
+  if (X64RetTy->isPointerTy()) {
+    // If the X64 type is canonicalized to a pointer, that means it's
+    // passed/returned indirectly. For a return value, that means it's an
+    // sret pointer.
+    X64ArgTypes.push_back(X64RetTy);
+    X64RetTy = VoidTy;
+  }
+}
+
+void AArch64Arm64ECCallLowering::canonicalizeThunkType(
+    Type *T, Align Alignment, bool Ret, uint64_t ArgSizeBytes,
+    raw_ostream &Out, Type *&Arm64Ty, Type *&X64Ty) {
+  if (T->isFloatTy()) {
+    Out << "f";
+    Arm64Ty = T;
+    X64Ty = T;
+    return;
+  }
+
+  if (T->isDoubleTy()) {
+    Out << "d";
+    Arm64Ty = T;
+    X64Ty = T;
+    return;
+  }
+
+  auto &DL = M->getDataLayout();
+
+  if (auto *StructTy = dyn_cast<StructType>(T))
+    if (StructTy->getNumElements() == 1)
+      T = StructTy->getElementType(0);
+
+  if (T->isArrayTy()) {
+    Type *ElementTy = T->getArrayElementType();
+    uint64_t ElementCnt = T->getArrayNumElements();
+    uint64_t ElementSizePerBytes = DL.getTypeSizeInBits(ElementTy) / 8;
+    uint64_t TotalSizeBytes = ElementCnt * ElementSizePerBytes;
+    if (ElementTy->isFloatTy() || ElementTy->isDoubleTy()) {
+      Out << (ElementTy->isFloatTy() ? "F" : "D") << TotalSizeBytes;
+      if (Alignment.value() >= 8 && !T->isPointerTy())
+        Out << "a" << Alignment.value();
+      Arm64Ty = T;
+      if (TotalSizeBytes <= 8) {
+        // Arm64 returns small structs of float/double in float registers;
+        // X64 uses RAX.
+        X64Ty = llvm::Type::getIntNTy(M->getContext(), TotalSizeBytes * 8);
+      } else {
+        // Struct is passed directly on Arm64, but indirectly on X64.
+        X64Ty = Arm64Ty->getPointerTo(0);
+      }
+      return;
+    }
+  }
+
+  if ((T->isIntegerTy() || T->isPointerTy()) &&
+      DL.getTypeSizeInBits(T) <= 64) {
+    Out << "i8";
+    Arm64Ty = I64Ty;
+    X64Ty = I64Ty;
+    return;
+  }
+
+  unsigned TypeSize = ArgSizeBytes;
+  if (TypeSize == 0)
+    TypeSize = DL.getTypeSizeInBits(T) / 8;
+  Out << "m";
+  if (TypeSize != 4)
+    Out << TypeSize;
+  if (Alignment.value() >= 8 && !T->isPointerTy())
+    Out << "a" << Alignment.value();
+  // FIXME: Try to canonicalize Arm64Ty more thoroughly?
+  Arm64Ty = T;
+  if (TypeSize == 1 || TypeSize == 2 || TypeSize == 4 || TypeSize == 8) {
+    // Pass directly in an integer register
+    X64Ty = llvm::Type::getIntNTy(M->getContext(), TypeSize * 8);
+  } else {
+    // Passed directly on Arm64, but indirectly on X64.
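+    // For example, a 12-byte struct is mangled as "m12": it stays a struct
+    // on the Arm64 side but becomes a pointer on the X64 side.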
+    X64Ty = Arm64Ty->getPointerTo(0);
+  }
+}
+
+Function *AArch64Arm64ECCallLowering::buildExitThunk(FunctionType *FT,
+                                                     AttributeList Attrs) {
+  SmallString<256> ExitThunkName;
+  llvm::raw_svector_ostream ExitThunkStream(ExitThunkName);
+  FunctionType *Arm64Ty, *X64Ty;
+  getThunkType(FT, Attrs, /*EntryThunk*/ false, ExitThunkStream, Arm64Ty,
+               X64Ty);
+  if (Function *F = M->getFunction(ExitThunkName))
+    return F;
+
+  Function *F = Function::Create(Arm64Ty, GlobalValue::LinkOnceODRLinkage, 0,
+                                 ExitThunkName, M);
+  F->setCallingConv(CallingConv::ARM64EC_Thunk_Native);
+  F->setSection(".wowthk$aa");
+  F->setComdat(M->getOrInsertComdat(ExitThunkName));
+  // Match MSVC, and always set up a frame pointer. (Maybe this isn't
+  // necessary.)
+  F->addFnAttr("frame-pointer", "all");
+  // Only copy sret from the first argument. For C++ instance methods, clang
+  // can stick an sret marking on a later argument, but it doesn't actually
+  // affect the ABI, so we can omit it. This avoids triggering a verifier
+  // assertion.
+  if (FT->getNumParams()) {
+    auto SRet = Attrs.getParamAttr(0, Attribute::StructRet);
+    auto InReg = Attrs.getParamAttr(0, Attribute::InReg);
+    if (SRet.isValid() && !InReg.isValid())
+      F->addParamAttr(1, SRet);
+  }
+  // FIXME: Copy anything other than sret? Shouldn't be necessary for normal
+  // C ABI, but might show up in other cases.
+  BasicBlock *BB = BasicBlock::Create(M->getContext(), "", F);
+  IRBuilder<> IRB(BB);
+  PointerType *DispatchPtrTy =
+      FunctionType::get(IRB.getVoidTy(), false)->getPointerTo(0);
+  Value *CalleePtr = M->getOrInsertGlobal(
+      "__os_arm64x_dispatch_call_no_redirect", DispatchPtrTy);
+  Value *Callee = IRB.CreateLoad(DispatchPtrTy, CalleePtr);
+  auto &DL = M->getDataLayout();
+  SmallVector<Value *> Args;
+
+  // Pass the called function in x9.
+  Args.push_back(F->arg_begin());
+
+  Type *RetTy = Arm64Ty->getReturnType();
+  if (RetTy != X64Ty->getReturnType()) {
+    // If the return type is an array or struct, translate it. Values of size
+    // 8 or less go into RAX; bigger values go into memory, and we pass a
+    // pointer.
+    if (DL.getTypeStoreSize(RetTy) > 8) {
+      Args.push_back(IRB.CreateAlloca(RetTy));
+    }
+  }
+
+  for (auto &Arg : make_range(F->arg_begin() + 1, F->arg_end())) {
+    // Translate arguments from AArch64 calling convention to x86 calling
+    // convention.
+    //
+    // For simple types, we don't need to do any translation: they're
+    // represented the same way. (Implicit sign extension is not part of
+    // either convention.)
+    //
+    // The big thing we have to worry about is struct types... but
+    // fortunately AArch64 clang is pretty friendly here: the cases that need
+    // translation are always passed as a struct or array. (If we run into
+    // some cases where this doesn't work, we can teach clang to mark it up
+    // with an attribute.)
+    //
+    // The first argument is the called function, stored in x9.
+    if (Arg.getType()->isArrayTy() || Arg.getType()->isStructTy() ||
+        DL.getTypeStoreSize(Arg.getType()) > 8) {
+      Value *Mem = IRB.CreateAlloca(Arg.getType());
+      IRB.CreateStore(&Arg, Mem);
+      if (DL.getTypeStoreSize(Arg.getType()) <= 8) {
+        Type *IntTy = IRB.getIntNTy(DL.getTypeStoreSizeInBits(Arg.getType()));
+        Args.push_back(IRB.CreateLoad(
+            IntTy, IRB.CreateBitCast(Mem, IntTy->getPointerTo(0))));
+      } else
+        Args.push_back(Mem);
+    } else {
+      Args.push_back(&Arg);
+    }
+  }
+  // FIXME: Transfer necessary attributes? sret? anything else?
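+
+  // Call the emulator's dispatch routine, which transfers control to the
+  // x64 callee passed in x9 and returns here with its return value.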
+  Callee = IRB.CreateBitCast(Callee, X64Ty->getPointerTo(0));
+  CallInst *Call = IRB.CreateCall(X64Ty, Callee, Args);
+  Call->setCallingConv(CallingConv::ARM64EC_Thunk_X64);
+
+  Value *RetVal = Call;
+  if (RetTy != X64Ty->getReturnType()) {
+    // If we rewrote the return type earlier, convert the return value to
+    // the proper type.
+    if (DL.getTypeStoreSize(RetTy) > 8) {
+      RetVal = IRB.CreateLoad(RetTy, Args[1]);
+    } else {
+      Value *CastAlloca = IRB.CreateAlloca(RetTy);
+      IRB.CreateStore(Call, IRB.CreateBitCast(
+                                CastAlloca, Call->getType()->getPointerTo(0)));
+      RetVal = IRB.CreateLoad(RetTy, CastAlloca);
+    }
+  }
+
+  if (RetTy->isVoidTy())
+    IRB.CreateRetVoid();
+  else
+    IRB.CreateRet(RetVal);
+  return F;
+}
+
+Function *AArch64Arm64ECCallLowering::buildEntryThunk(Function *F) {
+  SmallString<256> EntryThunkName;
+  llvm::raw_svector_ostream EntryThunkStream(EntryThunkName);
+  FunctionType *Arm64Ty, *X64Ty;
+  getThunkType(F->getFunctionType(), F->getAttributes(), /*EntryThunk*/ true,
+               EntryThunkStream, Arm64Ty, X64Ty);
+  if (Function *F = M->getFunction(EntryThunkName))
+    return F;
+
+  Function *Thunk = Function::Create(X64Ty, GlobalValue::LinkOnceODRLinkage, 0,
+                                     EntryThunkName, M);
+  Thunk->setCallingConv(CallingConv::ARM64EC_Thunk_X64);
+  Thunk->setSection(".wowthk$aa");
+  Thunk->setComdat(M->getOrInsertComdat(EntryThunkName));
+  // Match MSVC, and always set up a frame pointer. (Maybe this isn't
+  // necessary.)
+  Thunk->addFnAttr("frame-pointer", "all");
+
+  auto &DL = M->getDataLayout();
+  BasicBlock *BB = BasicBlock::Create(M->getContext(), "", Thunk);
+  IRBuilder<> IRB(BB);
+
+  Type *RetTy = Arm64Ty->getReturnType();
+  Type *X64RetType = X64Ty->getReturnType();
+
+  bool TransformDirectToSRet = X64RetType->isVoidTy() && !RetTy->isVoidTy();
+  unsigned ThunkArgOffset = TransformDirectToSRet ? 2 : 1;
+
+  // Translate arguments to call.
+  SmallVector<Value *> Args;
+  for (unsigned i = ThunkArgOffset, e = Thunk->arg_size(); i != e; ++i) {
+    Value *Arg = Thunk->getArg(i);
+    Type *ArgTy = Arm64Ty->getParamType(i - ThunkArgOffset);
+    if (ArgTy->isArrayTy() || ArgTy->isStructTy() ||
+        DL.getTypeStoreSize(ArgTy) > 8) {
+      // Translate array/struct arguments to the expected type.
+      if (DL.getTypeStoreSize(ArgTy) <= 8) {
+        Value *CastAlloca = IRB.CreateAlloca(ArgTy);
+        IRB.CreateStore(Arg, IRB.CreateBitCast(
+                                 CastAlloca, Arg->getType()->getPointerTo(0)));
+        Arg = IRB.CreateLoad(ArgTy, CastAlloca);
+      } else {
+        Arg = IRB.CreateLoad(ArgTy,
+                             IRB.CreateBitCast(Arg, ArgTy->getPointerTo(0)));
+      }
+    }
+    Args.push_back(Arg);
+  }
+
+  // Call the function passed to the thunk.
+  Value *Callee = Thunk->getArg(0);
+  Callee = IRB.CreateBitCast(Callee, Arm64Ty->getPointerTo(0));
+  Value *Call = IRB.CreateCall(Arm64Ty, Callee, Args);
+
+  Value *RetVal = Call;
+  if (TransformDirectToSRet) {
+    IRB.CreateStore(RetVal,
+                    IRB.CreateBitCast(Thunk->getArg(1),
+                                      RetVal->getType()->getPointerTo(0)));
+  } else if (X64RetType != RetTy) {
+    Value *CastAlloca = IRB.CreateAlloca(X64RetType);
+    IRB.CreateStore(
+        Call, IRB.CreateBitCast(CastAlloca, Call->getType()->getPointerTo(0)));
+    RetVal = IRB.CreateLoad(X64RetType, CastAlloca);
+  }
+
+  // Return to the caller. Note that the isel has code to translate this
+  // "ret" to a tail call to __os_arm64x_dispatch_ret. (Alternatively, we
+  // could emit a tail call here, but that would require a dedicated calling
+  // convention, which seems more complicated overall.)
+  if (X64RetType->isVoidTy())
+    IRB.CreateRetVoid();
+  else
+    IRB.CreateRet(RetVal);
+
+  return Thunk;
+}
+
+Function *AArch64Arm64ECCallLowering::buildGuestExitThunk(Function *F) {
+  llvm::raw_null_ostream NullThunkName;
+  FunctionType *Arm64Ty, *X64Ty;
+  getThunkType(F->getFunctionType(), F->getAttributes(), /*EntryThunk*/ true,
+               NullThunkName, Arm64Ty, X64Ty);
+  auto MangledName = getArm64ECMangledFunctionName(F->getName().str());
+  assert(MangledName && "Can't guest exit to function that's already native");
+  std::string ThunkName = *MangledName;
+  if (ThunkName[0] == '?' && ThunkName.find("@") != std::string::npos) {
+    ThunkName.insert(ThunkName.find("@"), "$exit_thunk");
+  } else {
+    ThunkName.append("$exit_thunk");
+  }
+  Function *GuestExit =
+      Function::Create(Arm64Ty, GlobalValue::WeakODRLinkage, 0, ThunkName, M);
+  GuestExit->setComdat(M->getOrInsertComdat(ThunkName));
+  GuestExit->setSection(".wowthk$aa");
+  GuestExit->setMetadata(
+      "arm64ec_unmangled_name",
+      MDNode::get(M->getContext(),
+                  MDString::get(M->getContext(), F->getName())));
+  GuestExit->setMetadata(
+      "arm64ec_ecmangled_name",
+      MDNode::get(M->getContext(),
+                  MDString::get(M->getContext(), *MangledName)));
+  F->setMetadata("arm64ec_hasguestexit", MDNode::get(M->getContext(), {}));
+  BasicBlock *BB = BasicBlock::Create(M->getContext(), "", GuestExit);
+  IRBuilder<> B(BB);
+
+  // Load the global symbol as a pointer to the check function.
+  Value *GuardFn;
+  if (cfguard_module_flag == 2 && !F->hasFnAttribute("guard_nocf"))
+    GuardFn = GuardFnCFGlobal;
+  else
+    GuardFn = GuardFnGlobal;
+  LoadInst *GuardCheckLoad = B.CreateLoad(GuardFnPtrType, GuardFn);
+
+  // Create new call instruction. The CFGuard check should always be a call,
+  // even if the original CallBase is an Invoke or CallBr instruction.
+  Function *Thunk = buildExitThunk(F->getFunctionType(), F->getAttributes());
+  CallInst *GuardCheck =
+      B.CreateCall(GuardFnType, GuardCheckLoad,
+                   {B.CreateBitCast(F, B.getInt8PtrTy()),
+                    B.CreateBitCast(Thunk, B.getInt8PtrTy())});
+
+  // Ensure that the arguments are passed in the correct registers
+  // (x11 and x10 for the Arm64EC CFGuard check convention).
+  GuardCheck->setCallingConv(CallingConv::CFGuard_Check);
+
+  Value *GuardRetVal = B.CreateBitCast(GuardCheck, Arm64Ty->getPointerTo(0));
+  SmallVector<Value *> Args;
+  for (Argument &Arg : GuestExit->args())
+    Args.push_back(&Arg);
+  CallInst *Call = B.CreateCall(Arm64Ty, GuardRetVal, Args);
+  Call->setTailCallKind(llvm::CallInst::TCK_MustTail);
+
+  if (Call->getType()->isVoidTy())
+    B.CreateRetVoid();
+  else
+    B.CreateRet(Call);
+
+  auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet);
+  auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg);
+  if (SRetAttr.isValid() && !InRegAttr.isValid()) {
+    GuestExit->addParamAttr(0, SRetAttr);
+    Call->addParamAttr(0, SRetAttr);
+  }
+
+  return GuestExit;
+}
+
+void AArch64Arm64ECCallLowering::lowerCall(CallBase *CB) {
+  assert(Triple(CB->getModule()->getTargetTriple()).isOSWindows() &&
+         "Only applicable for Windows targets");
+
+  IRBuilder<> B(CB);
+  Value *CalledOperand = CB->getCalledOperand();
+
+  // If the indirect call is called within catchpad or cleanuppad,
+  // we need to copy the "funclet" bundle of the call.
+  SmallVector<llvm::OperandBundleDef, 1> Bundles;
+  if (auto Bundle = CB->getOperandBundle(LLVMContext::OB_funclet))
+    Bundles.push_back(OperandBundleDef(*Bundle));
+
+  // Load the global symbol as a pointer to the check function.
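+  // The check function takes the target address and its exit thunk and
+  // returns the address to actually call: the target itself if it turns out
+  // to be Arm64 code, or a route through the exit thunk for x64 code.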
+  Value *GuardFn;
+  if (cfguard_module_flag == 2 && !CB->hasFnAttr("guard_nocf"))
+    GuardFn = GuardFnCFGlobal;
+  else
+    GuardFn = GuardFnGlobal;
+  LoadInst *GuardCheckLoad = B.CreateLoad(GuardFnPtrType, GuardFn);
+
+  // Create new call instruction. The CFGuard check should always be a call,
+  // even if the original CallBase is an Invoke or CallBr instruction.
+  Function *Thunk = buildExitThunk(CB->getFunctionType(), CB->getAttributes());
+  CallInst *GuardCheck =
+      B.CreateCall(GuardFnType, GuardCheckLoad,
+                   {B.CreateBitCast(CalledOperand, B.getInt8PtrTy()),
+                    B.CreateBitCast(Thunk, B.getInt8PtrTy())},
+                   Bundles);
+
+  // Ensure that the arguments are passed in the correct registers
+  // (x11 and x10 for the Arm64EC CFGuard check convention).
+  GuardCheck->setCallingConv(CallingConv::CFGuard_Check);
+
+  Value *GuardRetVal = B.CreateBitCast(GuardCheck, CalledOperand->getType());
+  CB->setCalledOperand(GuardRetVal);
+}
+
+bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
+  if (!GenerateThunks)
+    return false;
+
+  M = &Mod;
+
+  // Check if this module has the cfguard flag and read its value.
+  if (auto *MD =
+          mdconst::extract_or_null<ConstantInt>(M->getModuleFlag("cfguard")))
+    cfguard_module_flag = MD->getZExtValue();
+
+  I8PtrTy = Type::getInt8PtrTy(M->getContext());
+  I64Ty = Type::getInt64Ty(M->getContext());
+  VoidTy = Type::getVoidTy(M->getContext());
+
+  GuardFnType = FunctionType::get(I8PtrTy, {I8PtrTy, I8PtrTy}, false);
+  GuardFnPtrType = PointerType::get(GuardFnType, 0);
+  GuardFnCFGlobal =
+      M->getOrInsertGlobal("__os_arm64x_check_icall_cfg", GuardFnPtrType);
+  GuardFnGlobal =
+      M->getOrInsertGlobal("__os_arm64x_check_icall", GuardFnPtrType);
+
+  SetVector<Function *> DirectCalledFns;
+  for (Function &F : Mod)
+    if (!F.isDeclaration() &&
+        F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native &&
+        F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64)
+      processFunction(F, DirectCalledFns);
+
+  struct ThunkInfo {
+    Constant *Src;
+    Constant *Dst;
+    unsigned Kind;
+  };
+  SmallVector<ThunkInfo> ThunkMapping;
+  for (Function &F : Mod) {
+    if (!F.isDeclaration() && (!F.hasLocalLinkage() || F.hasAddressTaken()) &&
+        F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native &&
+        F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64) {
+      if (!F.hasComdat())
+        F.setComdat(Mod.getOrInsertComdat(F.getName()));
+      ThunkMapping.push_back({&F, buildEntryThunk(&F), 1});
+    }
+  }
+  for (Function *F : DirectCalledFns) {
+    ThunkMapping.push_back(
+        {F, buildExitThunk(F->getFunctionType(), F->getAttributes()), 4});
+    if (!F->hasDLLImportStorageClass())
+      ThunkMapping.push_back({buildGuestExitThunk(F), F, 0});
+  }
+
+  if (!ThunkMapping.empty()) {
+    Type *VoidPtr = Type::getInt8PtrTy(M->getContext());
+    SmallVector<Constant *> ThunkMappingArrayElems;
+    for (ThunkInfo &Thunk : ThunkMapping) {
+      ThunkMappingArrayElems.push_back(ConstantStruct::getAnon(
+          {ConstantExpr::getBitCast(Thunk.Src, VoidPtr),
+           ConstantExpr::getBitCast(Thunk.Dst, VoidPtr),
+           ConstantInt::get(M->getContext(), APInt(32, Thunk.Kind))}));
+    }
+    Constant *ThunkMappingArray = ConstantArray::get(
+        llvm::ArrayType::get(ThunkMappingArrayElems[0]->getType(),
+                             ThunkMappingArrayElems.size()),
+        ThunkMappingArrayElems);
+    new GlobalVariable(Mod, ThunkMappingArray->getType(), /*isConstant*/ false,
+                       GlobalValue::ExternalLinkage, ThunkMappingArray,
+                       "llvm.arm64ec.symbolmap");
+  }
+
+  return true;
+}
+
+bool AArch64Arm64ECCallLowering::processFunction(
+    Function &F, SetVector<Function *> &DirectCalledFns) {
+  SmallVector<CallBase *> IndirectCalls;
+
+  // For ARM64EC targets, a function definition's name is mangled differently
+  // from the normal symbol. We currently have no representation of this sort
+  // of symbol in IR, so we change the name to the mangled name, then store
+  // the unmangled name as metadata. Later passes that need the unmangled
+  // name (emitting the definition) can grab it from the metadata.
+  //
+  // FIXME: Handle functions with weak linkage?
+  if (F.hasExternalLinkage() || F.hasWeakLinkage() || F.hasLinkOnceLinkage()) {
+    if (std::optional<std::string> MangledName =
+            getArm64ECMangledFunctionName(F.getName().str())) {
+      F.setMetadata("arm64ec_unmangled_name",
+                    MDNode::get(M->getContext(),
+                                MDString::get(M->getContext(), F.getName())));
+      if (F.hasComdat() && F.getComdat()->getName() == F.getName()) {
+        Comdat *MangledComdat = M->getOrInsertComdat(MangledName.value());
+        SmallVector<GlobalObject *> ComdatUsers =
+            to_vector(F.getComdat()->getUsers());
+        for (GlobalObject *User : ComdatUsers)
+          User->setComdat(MangledComdat);
+      }
+      F.setName(MangledName.value());
+    }
+  }
+
+  // Iterate over the instructions to find all indirect call/invoke/callbr
+  // instructions. Make a separate list of pointers to indirect
+  // call/invoke/callbr instructions because the original instructions will be
+  // deleted as the checks are added.
+  for (BasicBlock &BB : F) {
+    for (Instruction &I : BB) {
+      auto *CB = dyn_cast<CallBase>(&I);
+      if (!CB || CB->getCallingConv() == CallingConv::ARM64EC_Thunk_X64 ||
+          CB->isInlineAsm())
+        continue;
+
+      // We need to instrument any call that isn't directly calling an
+      // ARM64 function.
+      //
+      // FIXME: getCalledFunction() fails if there's a bitcast (e.g.
+      // unprototyped functions in C)
+      if (Function *F = CB->getCalledFunction()) {
+        if (!LowerDirectToIndirect || F->hasLocalLinkage() ||
+            F->isIntrinsic() || !F->isDeclaration())
+          continue;
+
+        DirectCalledFns.insert(F);
+        continue;
+      }
+
+      IndirectCalls.push_back(CB);
+      ++Arm64ECCallsLowered;
+    }
+  }
+
+  if (IndirectCalls.empty())
+    return false;
+
+  for (CallBase *CB : IndirectCalls)
+    lowerCall(CB);
+
+  return true;
+}
+
+char AArch64Arm64ECCallLowering::ID = 0;
+INITIALIZE_PASS(AArch64Arm64ECCallLowering, "Arm64ECCallLowering",
+                "AArch64Arm64ECCallLowering", false, false)
+
+ModulePass *llvm::createAArch64Arm64ECCallLoweringPass() {
+  return new AArch64Arm64ECCallLowering;
+}
diff --git a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
--- a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
+++ b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
@@ -164,6 +164,8 @@
     return false;
   }
 
+  const MCExpr *lowerConstant(const Constant *CV) override;
+
 private:
   void printOperand(const MachineInstr *MI, unsigned OpNum, raw_ostream &O);
   bool printAsmMRegister(const MachineOperand &MO, char Mode, raw_ostream &O);
@@ -1102,6 +1104,48 @@
     TS->emitDirectiveVariantPCS(CurrentFnSym);
   }
 
+  if (TM.getTargetTriple().isWindowsArm64EC()) {
+    // For ARM64EC targets, a function definition's name is mangled differently
+    // from the normal symbol. We emit the alias from the unmangled symbol to
+    // the mangled symbol name here.
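+    // For example, the body of "func" is emitted under the mangled name
+    // "#func", and "func" is emitted as a weak anti-dependency alias that
+    // points at it.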
+    if (MDNode *Unmangled =
+            MF->getFunction().getMetadata("arm64ec_unmangled_name")) {
+      AsmPrinter::emitFunctionEntryLabel();
+
+      if (MDNode *ECMangled =
+              MF->getFunction().getMetadata("arm64ec_ecmangled_name")) {
+        StringRef UnmangledStr =
+            cast<MDString>(Unmangled->getOperand(0))->getString();
+        MCSymbol *UnmangledSym =
+            MMI->getContext().getOrCreateSymbol(UnmangledStr);
+        StringRef ECMangledStr =
+            cast<MDString>(ECMangled->getOperand(0))->getString();
+        MCSymbol *ECMangledSym =
+            MMI->getContext().getOrCreateSymbol(ECMangledStr);
+        OutStreamer->emitSymbolAttribute(UnmangledSym, MCSA_WeakAntiDep);
+        OutStreamer->emitAssignment(
+            UnmangledSym,
+            MCSymbolRefExpr::create(ECMangledSym, MCSymbolRefExpr::VK_WEAKREF,
+                                    MMI->getContext()));
+        OutStreamer->emitSymbolAttribute(ECMangledSym, MCSA_WeakAntiDep);
+        OutStreamer->emitAssignment(
+            ECMangledSym,
+            MCSymbolRefExpr::create(CurrentFnSym, MCSymbolRefExpr::VK_WEAKREF,
+                                    MMI->getContext()));
+        return;
+      } else {
+        StringRef UnmangledStr =
+            cast<MDString>(Unmangled->getOperand(0))->getString();
+        MCSymbol *UnmangledSym =
+            MMI->getContext().getOrCreateSymbol(UnmangledStr);
+        OutStreamer->emitSymbolAttribute(UnmangledSym, MCSA_WeakAntiDep);
+        OutStreamer->emitAssignment(
+            UnmangledSym,
+            MCSymbolRefExpr::create(CurrentFnSym, MCSymbolRefExpr::VK_WEAKREF,
+                                    MMI->getContext()));
+        return;
+      }
+    }
+  }
+
   return AsmPrinter::emitFunctionEntryLabel();
 }
 
@@ -1801,6 +1845,28 @@
   case AArch64::SEH_PACSignLR:
     TS->emitARM64WinCFIPACSignLR();
     return;
+
+  case AArch64::SEH_SaveAnyRegQP:
+    assert(MI->getOperand(1).getImm() - MI->getOperand(0).getImm() == 1 &&
+           "Non-consecutive registers not allowed for save_any_reg");
+    assert(MI->getOperand(2).getImm() >= 0 &&
+           "SaveAnyRegQP SEH opcode offset must be non-negative");
+    assert(MI->getOperand(2).getImm() <= 1008 &&
+           "SaveAnyRegQP SEH opcode offset must fit into 6 bits");
+    TS->emitARM64WinCFISaveAnyRegQP(MI->getOperand(0).getImm(),
+                                    MI->getOperand(2).getImm());
+    return;
+
+  case AArch64::SEH_SaveAnyRegQPX:
+    assert(MI->getOperand(1).getImm() - MI->getOperand(0).getImm() == 1 &&
+           "Non-consecutive registers not allowed for save_any_reg");
+    assert(MI->getOperand(2).getImm() < 0 &&
+           "SaveAnyRegQPX SEH opcode offset must be negative");
+    assert(MI->getOperand(2).getImm() >= -1008 &&
+           "SaveAnyRegQPX SEH opcode offset must fit into 6 bits");
+    TS->emitARM64WinCFISaveAnyRegQPX(MI->getOperand(0).getImm(),
+                                     -MI->getOperand(2).getImm());
+    return;
   }
 
   // Finally, do the automated lowerings for everything else.
@@ -1809,6 +1875,15 @@
   EmitToStreamer(*OutStreamer, TmpInst);
 }
 
+const MCExpr *AArch64AsmPrinter::lowerConstant(const Constant *CV) {
+  if (const GlobalValue *GV = dyn_cast<GlobalValue>(CV)) {
+    return MCSymbolRefExpr::create(MCInstLowering.GetGlobalValueSymbol(GV, 0),
+                                   OutContext);
+  }
+
+  return AsmPrinter::lowerConstant(CV);
+}
+
 // Force static initialization.
extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeAArch64AsmPrinter() { RegisterAsmPrinter X(getTheAArch64leTarget()); diff --git a/llvm/lib/Target/AArch64/AArch64CallingConvention.h b/llvm/lib/Target/AArch64/AArch64CallingConvention.h --- a/llvm/lib/Target/AArch64/AArch64CallingConvention.h +++ b/llvm/lib/Target/AArch64/AArch64CallingConvention.h @@ -22,6 +22,12 @@ bool CC_AArch64_Arm64EC_VarArg(unsigned ValNo, MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo, ISD::ArgFlagsTy ArgFlags, CCState &State); +bool CC_AArch64_Arm64EC_Thunk(unsigned ValNo, MVT ValVT, MVT LocVT, + CCValAssign::LocInfo LocInfo, + ISD::ArgFlagsTy ArgFlags, CCState &State); +bool CC_AArch64_Arm64EC_Thunk_Native(unsigned ValNo, MVT ValVT, MVT LocVT, + CCValAssign::LocInfo LocInfo, + ISD::ArgFlagsTy ArgFlags, CCState &State); bool CC_AArch64_DarwinPCS_VarArg(unsigned ValNo, MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo, ISD::ArgFlagsTy ArgFlags, CCState &State); @@ -37,6 +43,9 @@ bool CC_AArch64_Win64_CFGuard_Check(unsigned ValNo, MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo, ISD::ArgFlagsTy ArgFlags, CCState &State); +bool CC_AArch64_Arm64EC_CFGuard_Check(unsigned ValNo, MVT ValVT, MVT LocVT, + CCValAssign::LocInfo LocInfo, + ISD::ArgFlagsTy ArgFlags, CCState &State); bool CC_AArch64_WebKit_JS(unsigned ValNo, MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo, ISD::ArgFlagsTy ArgFlags, CCState &State); @@ -46,6 +55,13 @@ bool RetCC_AArch64_AAPCS(unsigned ValNo, MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo, ISD::ArgFlagsTy ArgFlags, CCState &State); +bool RetCC_AArch64_Arm64EC_Thunk(unsigned ValNo, MVT ValVT, MVT LocVT, + CCValAssign::LocInfo LocInfo, + ISD::ArgFlagsTy ArgFlags, CCState &State); +bool RetCC_AArch64_Arm64EC_CFGuard_Check(unsigned ValNo, MVT ValVT, MVT LocVT, + CCValAssign::LocInfo LocInfo, + ISD::ArgFlagsTy ArgFlags, + CCState &State); bool RetCC_AArch64_WebKit_JS(unsigned ValNo, MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo, ISD::ArgFlagsTy ArgFlags, CCState &State); diff --git a/llvm/lib/Target/AArch64/AArch64CallingConvention.td b/llvm/lib/Target/AArch64/AArch64CallingConvention.td --- a/llvm/lib/Target/AArch64/AArch64CallingConvention.td +++ b/llvm/lib/Target/AArch64/AArch64CallingConvention.td @@ -202,6 +202,119 @@ CCIfType<[i32, i64], CCAssignToStack<8, 8>>, ]>; +// Arm64EC thunks use a calling convention that's precisely the x64 calling +// convention, except that the registers have different names, and the callee +// address is passed in X9. +let Entry = 1 in +def CC_AArch64_Arm64EC_Thunk : CallingConv<[ + // Byval aggregates are passed by pointer + CCIfByVal>, + + // ARM64EC-specific: promote small integers to i32. (x86 only promotes i1, + // but that would confuse ARM64 lowering code.) + CCIfType<[i1, i8, i16], CCPromoteToType>, + + // The 'nest' parameter, if any, is passed in R10 (X4). + CCIfNest>, + + // A SwiftError is passed in R12 (X19). + CCIfSwiftError>>, + + // Pass SwiftSelf in R13 (X20). + CCIfSwiftSelf>>, + + // Pass SwiftAsync in an otherwise callee saved register so that calls to + // normal functions don't need to save it somewhere. + CCIfSwiftAsync>>, + + // The 'CFGuardTarget' parameter, if any, is passed in RAX (R8). 
+  CCIfCFGuardTarget<CCAssignToReg<[X8]>>,
+
+  // 128 bit vectors are passed by pointer
+  CCIfType<[v16i8, v8i16, v4i32, v2i64, v8f16, v4f32, v2f64],
+           CCPassIndirect<i64>>,
+
+  // 256 bit vectors are passed by pointer
+  CCIfType<[v32i8, v16i16, v8i32, v4i64, v16f16, v8f32, v4f64],
+           CCPassIndirect<i64>>,
+
+  // 512 bit vectors are passed by pointer
+  CCIfType<[v64i8, v32i16, v16i32, v32f16, v16f32, v8f64, v8i64],
+           CCPassIndirect<i64>>,
+
+  // Long doubles are passed by pointer
+  CCIfType<[f80], CCPassIndirect<i64>>,
+
+  // The first 4 MMX vector arguments are passed in GPRs.
+  CCIfType<[x86mmx], CCBitConvertToType<i64>>,
+
+  // The first 4 FP/Vector arguments are passed in XMM registers.
+  CCIfType<[f16],
+           CCAssignToRegWithShadow<[H0, H1, H2, H3],
+                                   [X0, X1, X2, X3]>>,
+  CCIfType<[f32],
+           CCAssignToRegWithShadow<[S0, S1, S2, S3],
+                                   [X0, X1, X2, X3]>>,
+  CCIfType<[f64],
+           CCAssignToRegWithShadow<[D0, D1, D2, D3],
+                                   [X0, X1, X2, X3]>>,
+
+  // The first 4 integer arguments are passed in integer registers.
+  CCIfType<[i32], CCAssignToRegWithShadow<[W0, W1, W2, W3],
+                                          [Q0, Q1, Q2, Q3]>>,
+
+  // Arm64EC thunks: the first argument is always a pointer to the destination
+  // address, stored in x9.
+  CCIfType<[i64], CCAssignToReg<[X9]>>,
+
+  CCIfType<[i64], CCAssignToRegWithShadow<[X0, X1, X2, X3],
+                                          [Q0, Q1, Q2, Q3]>>,
+
+  // Integer/FP values get stored in stack slots that are 8 bytes in size and
+  // 8-byte aligned if there are no more registers to hold them.
+  CCIfType<[i8, i16, i32, i64, f16, f32, f64], CCAssignToStack<8, 8>>
+]>;
+
+// The native side of ARM64EC thunks
+let Entry = 1 in
+def CC_AArch64_Arm64EC_Thunk_Native : CallingConv<[
+  CCIfType<[i64], CCAssignToReg<[X9]>>,
+  CCDelegateTo<CC_AArch64_AAPCS>
+]>;
+
+let Entry = 1 in
+def RetCC_AArch64_Arm64EC_Thunk : CallingConv<[
+  // The X86-Win64 calling convention always returns __m64 values in RAX.
+  CCIfType<[x86mmx], CCBitConvertToType<i64>>,
+
+  // Otherwise, everything is the same as 'normal' X86-64 C CC.
+
+  // The X86-64 calling convention always returns FP values in XMM0.
+  CCIfType<[f16], CCAssignToReg<[H0, H1]>>,
+  CCIfType<[f32], CCAssignToReg<[S0, S1]>>,
+  CCIfType<[f64], CCAssignToReg<[D0, D1]>>,
+  CCIfType<[f128], CCAssignToReg<[Q0, Q1]>>,
+
+  CCIfSwiftError<CCIfType<[i64], CCAssignToReg<[X19]>>>,
+
+  // Scalar values are returned in AX first, then DX. For i8, the ABI
+  // requires the values to be in AL and AH, however this code uses AL and DL
+  // instead. This is because using AH for the second register conflicts with
+  // the way LLVM does multiple return values -- a return of {i16,i8} would end
+  // up in AX and AH, which overlap. Front-ends wishing to conform to the ABI
+  // for functions that return two i8 values are currently expected to pack the
+  // values into an i16 (which uses AX, and thus AL:AH).
+  //
+  // For code that doesn't care about the ABI, we allow returning more than two
+  // integer values in registers.
+  CCIfType<[i1, i8, i16], CCPromoteToType<i32>>,
+  CCIfType<[i32], CCAssignToReg<[W8, W1, W0]>>,
+  CCIfType<[i64], CCAssignToReg<[X8, X1, X0]>>,
+
+  // Vector types are returned in XMM0 and XMM1, when they fit. XMM2 and XMM3
+  // can only be used by ABI non-compliant code. If the target doesn't have XMM
+  // registers, it won't have vector types.
+  CCIfType<[v16i8, v8i16, v4i32, v2i64, v8f16, v4f32, v2f64],
+           CCAssignToReg<[Q0, Q1, Q2, Q3]>>
+]>;
+
 // Windows Control Flow Guard checks take a single argument (the target function
 // address) and have no return value.
 let Entry = 1 in
@@ -209,6 +322,16 @@
   CCIfType<[i64], CCAssignToReg<[X15]>>
 ]>;
 
+let Entry = 1 in
+def CC_AArch64_Arm64EC_CFGuard_Check : CallingConv<[
+  CCIfType<[i64], CCAssignToReg<[X11, X10]>>
+]>;
+
+let Entry = 1 in
+def RetCC_AArch64_Arm64EC_CFGuard_Check : CallingConv<[
+  CCIfType<[i64], CCAssignToReg<[X11]>>
+]>;
+
 // Darwin uses a calling convention which differs in only two ways
 // from the standard one at this level:
@@ -428,6 +551,11 @@
                                   (sequence "X%u", 0, 8),
                                   (sequence "Q%u", 0, 7))>;
 
+// To match the x64 calling convention, Arm64EC thunks preserve q6-q15.
+def CSR_Win_AArch64_Arm64EC_Thunk : CalleeSavedRegs<(add (sequence "Q%u", 6, 15),
+                                                     X19, X20, X21, X22, X23, X24,
+                                                     X25, X26, X27, X28, FP, LR)>;
+
 // AArch64 PCS for vector functions (VPCS)
 // must (additionally) preserve full Q8-Q23 registers
 def CSR_AArch64_AAVPCS : CalleeSavedRegs<(add X19, X20, X21, X22, X23, X24,
diff --git a/llvm/lib/Target/AArch64/AArch64FastISel.cpp b/llvm/lib/Target/AArch64/AArch64FastISel.cpp
--- a/llvm/lib/Target/AArch64/AArch64FastISel.cpp
+++ b/llvm/lib/Target/AArch64/AArch64FastISel.cpp
@@ -3183,6 +3183,9 @@
   if (IsVarArg)
     return false;
 
+  if (Subtarget->isWindowsArm64EC())
+    return false;
+
   for (auto Flag : CLI.OutFlags)
     if (Flag.isInReg() || Flag.isSRet() || Flag.isNest() || Flag.isByVal() ||
         Flag.isSwiftSelf() || Flag.isSwiftAsync() || Flag.isSwiftError())
diff --git a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
@@ -1081,6 +1081,30 @@
         .setMIFlag(Flag);
     break;
   }
+  case AArch64::STPQi:
+  case AArch64::LDPQi: {
+    unsigned Reg0 = RegInfo->getSEHRegNum(MBBI->getOperand(0).getReg());
+    unsigned Reg1 = RegInfo->getSEHRegNum(MBBI->getOperand(1).getReg());
+    MIB = BuildMI(MF, DL, TII.get(AArch64::SEH_SaveAnyRegQP))
+              .addImm(Reg0)
+              .addImm(Reg1)
+              .addImm(Imm * 16)
+              .setMIFlag(Flag);
+    break;
+  }
+  case AArch64::LDPQpost:
+    Imm = -Imm;
+    LLVM_FALLTHROUGH;
+  case AArch64::STPQpre: {
+    unsigned Reg0 = RegInfo->getSEHRegNum(MBBI->getOperand(1).getReg());
+    unsigned Reg1 = RegInfo->getSEHRegNum(MBBI->getOperand(2).getReg());
+    MIB = BuildMI(MF, DL, TII.get(AArch64::SEH_SaveAnyRegQPX))
+              .addImm(Reg0)
+              .addImm(Reg1)
+              .addImm(Imm * 16)
+              .setMIFlag(Flag);
+    break;
+  }
   }
   auto I = MBB->insertAfter(MBBI, MIB);
   return I;
@@ -1099,6 +1123,8 @@
   case AArch64::SEH_SaveReg:
   case AArch64::SEH_SaveFRegP:
   case AArch64::SEH_SaveFReg:
+  case AArch64::SEH_SaveAnyRegQP:
+  case AArch64::SEH_SaveAnyRegQPX:
     ImmOpnd = &MBBI->getOperand(ImmIdx);
     break;
   }
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -62,6 +62,9 @@
   SMSTOP,
   RESTORE_ZA,
 
+  // A call with the callee in x16, i.e. "blr x16".
+  CALL_ARM64EC_TO_X64,
+
   // Produces the full sequence of instructions for getting the thread pointer
   // offset of a variable into X0, using the TLSDesc model.
   TLSDESC_CALLSEQ,
@@ -1019,6 +1022,8 @@
                         unsigned Flag) const;
   SDValue getTargetNode(BlockAddressSDNode *N, EVT Ty, SelectionDAG &DAG,
                         unsigned Flag) const;
+  SDValue getTargetNode(ExternalSymbolSDNode *N, EVT Ty, SelectionDAG &DAG,
+                        unsigned Flag) const;
   template <class NodeTy>
   SDValue getGOT(NodeTy *N, SelectionDAG &DAG, unsigned Flags = 0) const;
   template <class NodeTy>
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1639,6 +1639,44 @@
   PredictableSelectIsExpensive = Subtarget->predictableSelectIsExpensive();
 
   IsStrictFPEnabled = true;
+
+  if (Subtarget->isWindowsArm64EC()) {
+    // FIXME: are there other intrinsics we need to add here?
+    setLibcallName(RTLIB::MEMCPY, "#memcpy");
+    setLibcallName(RTLIB::MEMSET, "#memset");
+    setLibcallName(RTLIB::MEMMOVE, "#memmove");
+    setLibcallName(RTLIB::REM_F32, "#fmodf");
+    setLibcallName(RTLIB::REM_F64, "#fmod");
+    setLibcallName(RTLIB::FMA_F32, "#fmaf");
+    setLibcallName(RTLIB::FMA_F64, "#fma");
+    setLibcallName(RTLIB::SQRT_F32, "#sqrtf");
+    setLibcallName(RTLIB::SQRT_F64, "#sqrt");
+    setLibcallName(RTLIB::CBRT_F32, "#cbrtf");
+    setLibcallName(RTLIB::CBRT_F64, "#cbrt");
+    setLibcallName(RTLIB::LOG_F32, "#logf");
+    setLibcallName(RTLIB::LOG_F64, "#log");
+    setLibcallName(RTLIB::LOG2_F32, "#log2f");
+    setLibcallName(RTLIB::LOG2_F64, "#log2");
+    setLibcallName(RTLIB::LOG10_F32, "#log10f");
+    setLibcallName(RTLIB::LOG10_F64, "#log10");
+    setLibcallName(RTLIB::EXP_F32, "#expf");
+    setLibcallName(RTLIB::EXP_F64, "#exp");
+    setLibcallName(RTLIB::EXP2_F32, "#exp2f");
+    setLibcallName(RTLIB::EXP2_F64, "#exp2");
+    setLibcallName(RTLIB::EXP10_F32, "#exp10f");
+    setLibcallName(RTLIB::EXP10_F64, "#exp10");
+    setLibcallName(RTLIB::SIN_F32, "#sinf");
+    setLibcallName(RTLIB::SIN_F64, "#sin");
+    setLibcallName(RTLIB::COS_F32, "#cosf");
+    setLibcallName(RTLIB::COS_F64, "#cos");
+    setLibcallName(RTLIB::POW_F32, "#powf");
+    setLibcallName(RTLIB::POW_F64, "#pow");
+    setLibcallName(RTLIB::LDEXP_F32, "#ldexpf");
+    setLibcallName(RTLIB::LDEXP_F64, "#ldexp");
+    setLibcallName(RTLIB::FREXP_F32, "#frexpf");
+    setLibcallName(RTLIB::FREXP_F64, "#frexp");
+  }
 }
 
 void AArch64TargetLowering::addTypeForNEON(MVT VT) {
@@ -2620,6 +2658,7 @@
     MAKE_CASE(AArch64ISD::MRRS)
     MAKE_CASE(AArch64ISD::MSRR)
     MAKE_CASE(AArch64ISD::RSHRNB_I)
+    MAKE_CASE(AArch64ISD::CALL_ARM64EC_TO_X64)
   }
 #undef MAKE_CASE
   return nullptr;
@@ -6333,19 +6372,35 @@
     }
     return CC_AArch64_AAPCS;
   case CallingConv::CFGuard_Check:
+    if (Subtarget->isWindowsArm64EC())
+      return CC_AArch64_Arm64EC_CFGuard_Check;
     return CC_AArch64_Win64_CFGuard_Check;
   case CallingConv::AArch64_VectorCall:
   case CallingConv::AArch64_SVE_VectorCall:
   case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0:
   case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2:
     return CC_AArch64_AAPCS;
+  case CallingConv::ARM64EC_Thunk_X64:
+    return CC_AArch64_Arm64EC_Thunk;
+  case CallingConv::ARM64EC_Thunk_Native:
+    return CC_AArch64_Arm64EC_Thunk_Native;
   }
 }
 
 CCAssignFn *
 AArch64TargetLowering::CCAssignFnForReturn(CallingConv::ID CC) const {
-  return CC == CallingConv::WebKit_JS ? RetCC_AArch64_WebKit_JS
-                                       : RetCC_AArch64_AAPCS;
+  switch (CC) {
+  default:
+    return RetCC_AArch64_AAPCS;
+  case CallingConv::WebKit_JS:
+    return RetCC_AArch64_WebKit_JS;
+  case CallingConv::ARM64EC_Thunk_X64:
+    return RetCC_AArch64_Arm64EC_Thunk;
+  case CallingConv::CFGuard_Check:
+    if (Subtarget->isWindowsArm64EC())
+      return RetCC_AArch64_Arm64EC_CFGuard_Check;
+    return RetCC_AArch64_AAPCS;
+  }
 }
 
@@ -6386,6 +6441,8 @@
   const Function &F = MF.getFunction();
   MachineFrameInfo &MFI = MF.getFrameInfo();
   bool IsWin64 = Subtarget->isCallingConvWin64(F.getCallingConv());
+  bool StackViaX4 = CallConv == CallingConv::ARM64EC_Thunk_X64 ||
+                    (isVarArg && Subtarget->isWindowsArm64EC());
   AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
 
   SmallVector<ISD::OutputArg, 4> Outs;
@@ -6555,10 +6612,14 @@
       SDValue FIN;
       MachinePointerInfo PtrInfo;
-      if (isVarArg && Subtarget->isWindowsArm64EC()) {
-        // In the ARM64EC varargs convention, fixed arguments on the stack are
-        // accessed relative to x4, not sp.
+      if (StackViaX4) {
+        // In both the ARM64EC varargs convention and the thunk convention,
+        // arguments on the stack are accessed relative to x4, not sp. In
+        // the thunk convention, there's an additional offset of 32 bytes
+        // to account for the shadow store.
         unsigned ObjOffset = ArgOffset + BEAlign;
+        if (CallConv == CallingConv::ARM64EC_Thunk_X64)
+          ObjOffset += 32;
         Register VReg = MF.addLiveIn(AArch64::X4, &AArch64::GPR64RegClass);
         SDValue Val = DAG.getCopyFromReg(Chain, DL, VReg, MVT::i64);
         FIN = DAG.getNode(ISD::ADD, DL, MVT::i64, Val,
@@ -6725,9 +6786,11 @@
   // On Windows, InReg pointers must be returned, so record the pointer in a
   // virtual register at the start of the function so it can be returned in the
   // epilogue.
-  if (IsWin64) {
+  if (IsWin64 || F.getCallingConv() == CallingConv::ARM64EC_Thunk_X64) {
     for (unsigned I = 0, E = Ins.size(); I != E; ++I) {
-      if (Ins[I].Flags.isInReg() && Ins[I].Flags.isSRet()) {
+      if ((F.getCallingConv() == CallingConv::ARM64EC_Thunk_X64 ||
+           Ins[I].Flags.isInReg()) &&
+          Ins[I].Flags.isSRet()) {
         assert(!FuncInfo->getSRetReturnReg());
 
         MVT PtrTy = getPointerTy(DAG.getDataLayout());
@@ -6958,6 +7021,11 @@
   const SmallVector<ISD::OutputArg, 32> &Outs = CLI.Outs;
   bool IsCalleeWin64 = Subtarget->isCallingConvWin64(CalleeCC);
 
+  // For Arm64EC thunks, allocate 32 extra bytes at the bottom of the stack
+  // for the shadow store.
+  if (CalleeCC == CallingConv::ARM64EC_Thunk_X64)
+    CCInfo.AllocateStack(32, Align(16));
+
   unsigned NumArgs = Outs.size();
   for (unsigned i = 0; i != NumArgs; ++i) {
     MVT ArgVT = Outs[i].VT;
@@ -7670,7 +7738,7 @@
       Callee = DAG.getNode(AArch64ISD::LOADgot, DL, PtrVT, Callee);
     } else {
       const GlobalValue *GV = G->getGlobal();
-      Callee = DAG.getTargetGlobalAddress(GV, DL, PtrVT, 0, 0);
+      Callee = DAG.getTargetGlobalAddress(GV, DL, PtrVT, 0, OpFlags);
     }
   } else if (auto *S = dyn_cast<ExternalSymbolSDNode>(Callee)) {
     if (getTargetMachine().getCodeModel() == CodeModel::Large &&
@@ -7765,8 +7833,11 @@
     Function *ARCFn = *objcarc::getAttachedARCFunction(CLI.CB);
     auto GA = DAG.getTargetGlobalAddress(ARCFn, DL, PtrVT);
     Ops.insert(Ops.begin() + 1, GA);
-  } else if (GuardWithBTI)
+  } else if (CallConv == CallingConv::ARM64EC_Thunk_X64) {
+    CallOpc = AArch64ISD::CALL_ARM64EC_TO_X64;
+  } else if (GuardWithBTI) {
     CallOpc = AArch64ISD::CALL_BTI;
+  }
 
   // Returns a chain and a flag for retval copy to use.
   Chain = DAG.getNode(CallOpc, DL, NodeTys, Ops);
@@ -7953,6 +8024,8 @@
                                      getPointerTy(MF.getDataLayout()));
 
     unsigned RetValReg = AArch64::X0;
+    if (CallConv == CallingConv::ARM64EC_Thunk_X64)
+      RetValReg = AArch64::X8;
     Chain = DAG.getCopyToReg(Chain, DL, RetValReg, Val, Glue);
     Glue = Chain.getValue(1);
 
@@ -7978,6 +8051,21 @@
   if (Glue.getNode())
     RetOps.push_back(Glue);
 
+  if (CallConv == CallingConv::ARM64EC_Thunk_X64) {
+    // ARM64EC entry thunks use a special return sequence: instead of a regular
+    // "ret" instruction, they need to explicitly call the emulator.
+    EVT PtrVT = getPointerTy(DAG.getDataLayout());
+    SDValue Arm64ECRetDest =
+        DAG.getExternalSymbol("__os_arm64x_dispatch_ret", PtrVT);
+    Arm64ECRetDest =
+        getAddr(cast<ExternalSymbolSDNode>(Arm64ECRetDest), DAG, 0);
+    Arm64ECRetDest = DAG.getLoad(PtrVT, DL, DAG.getEntryNode(), Arm64ECRetDest,
+                                 MachinePointerInfo());
+    RetOps.insert(RetOps.begin() + 1, Arm64ECRetDest);
+    RetOps.insert(RetOps.begin() + 2, DAG.getTargetConstant(0, DL, MVT::i32));
+    return DAG.getNode(AArch64ISD::TC_RETURN, DL, MVT::Other, RetOps);
+  }
+
   return DAG.getNode(AArch64ISD::RET_GLUE, DL, MVT::Other, RetOps);
 }
 
@@ -8011,6 +8099,12 @@
   return DAG.getTargetBlockAddress(N->getBlockAddress(), Ty, 0, Flag);
 }
 
+SDValue AArch64TargetLowering::getTargetNode(ExternalSymbolSDNode *N, EVT Ty,
+                                             SelectionDAG &DAG,
+                                             unsigned Flag) const {
+  return DAG.getTargetExternalSymbol(N->getSymbol(), Ty, Flag);
+}
+
 // (loadGOT sym)
 template <class NodeTy>
 SDValue AArch64TargetLowering::getGOT(NodeTy *N, SelectionDAG &DAG,
@@ -8091,8 +8185,7 @@
   }
   EVT PtrVT = getPointerTy(DAG.getDataLayout());
   SDLoc DL(GN);
-  if (OpFlags & (AArch64II::MO_DLLIMPORT | AArch64II::MO_DLLIMPORTAUX |
-                 AArch64II::MO_COFFSTUB))
+  if (OpFlags & (AArch64II::MO_DLLIMPORT | AArch64II::MO_COFFSTUB))
     Result = DAG.getLoad(PtrVT, DL, DAG.getEntryNode(), Result,
                          MachinePointerInfo::getGOT(DAG.getMachineFunction()));
   return Result;
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -1066,6 +1066,8 @@
   case AArch64::SEH_EpilogStart:
   case AArch64::SEH_EpilogEnd:
   case AArch64::SEH_PACSignLR:
+  case AArch64::SEH_SaveAnyRegQP:
+  case AArch64::SEH_SaveAnyRegQPX:
     return true;
   }
 }
@@ -7984,9 +7986,10 @@
       {MO_S, "aarch64-s"},
       {MO_TLS, "aarch64-tls"},
       {MO_DLLIMPORT, "aarch64-dllimport"},
-      {MO_DLLIMPORTAUX, "aarch64-dllimportaux"},
       {MO_PREL, "aarch64-prel"},
-      {MO_TAGGED, "aarch64-tagged"}};
+      {MO_TAGGED, "aarch64-tagged"},
+      {MO_ARM64EC_CALLMANGLE, "aarch64-arm64ec-callmangle"},
+  };
   return ArrayRef(TargetFlags);
 }
 
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -596,6 +596,11 @@
                                 [SDNPHasChain, SDNPOptInGlue, SDNPOutGlue,
                                  SDNPVariadic]>;
 
+def AArch64call_arm64ec_to_x64 : SDNode<"AArch64ISD::CALL_ARM64EC_TO_X64",
+                                        SDTypeProfile<0, -1, [SDTCisPtrTy<0>]>,
+                                        [SDNPHasChain, SDNPOptInGlue,
+                                         SDNPOutGlue, SDNPVariadic]>;
+
 def AArch64brcond : SDNode<"AArch64ISD::BRCOND", SDT_AArch64Brcond,
                            [SDNPHasChain]>;
 def AArch64cbz : SDNode<"AArch64ISD::CBZ", SDT_AArch64cbz,
@@ -2670,6 +2675,10 @@
             Sched<[WriteBrReg]>;
   def BLR_BTI : Pseudo<(outs), (ins variable_ops), []>,
                 Sched<[WriteBrReg]>;
+  let Uses = [X16, SP] in
+  def BLR_X16 : Pseudo<(outs), (ins), [(AArch64call_arm64ec_to_x64 X16)]>,
+                Sched<[WriteBrReg]>,
+                PseudoInstExpansion<(BLR X16)>;
 } // isCall
 
 def : Pat<(AArch64call GPR64:$Rn),
@@ -4655,6 +4664,8 @@
   def SEH_EpilogStart : Pseudo<(outs), (ins), []>, Sched<[]>;
   def SEH_EpilogEnd : Pseudo<(outs), (ins), []>, Sched<[]>;
   def SEH_PACSignLR : Pseudo<(outs), (ins), []>, Sched<[]>;
+  def SEH_SaveAnyRegQP : Pseudo<(outs), (ins i32imm:$reg0, i32imm:$reg1, i32imm:$offs), []>, Sched<[]>;
+  def SEH_SaveAnyRegQPX : Pseudo<(outs), (ins i32imm:$reg0, i32imm:$reg1, i32imm:$offs), []>, Sched<[]>;
 }
 
 // Pseudo instructions for Windows EH
diff --git a/llvm/lib/Target/AArch64/AArch64MCInstLower.h b/llvm/lib/Target/AArch64/AArch64MCInstLower.h
--- a/llvm/lib/Target/AArch64/AArch64MCInstLower.h
+++ b/llvm/lib/Target/AArch64/AArch64MCInstLower.h
@@ -9,6 +9,7 @@
 #ifndef LLVM_LIB_TARGET_AARCH64_AARCH64MCINSTLOWER_H
 #define LLVM_LIB_TARGET_AARCH64_AARCH64MCINSTLOWER_H
 
+#include "llvm/IR/GlobalValue.h"
 #include "llvm/Support/Compiler.h"
 #include "llvm/TargetParser/Triple.h"
 
@@ -42,6 +43,8 @@
                                    MCSymbol *Sym) const;
   MCOperand LowerSymbolOperand(const MachineOperand &MO, MCSymbol *Sym) const;
 
+  MCSymbol *GetGlobalValueSymbol(const GlobalValue *GV,
+                                 unsigned TargetFlags) const;
   MCSymbol *GetGlobalAddressSymbol(const MachineOperand &MO) const;
   MCSymbol *GetExternalSymbolSymbol(const MachineOperand &MO) const;
 };
diff --git a/llvm/lib/Target/AArch64/AArch64MCInstLower.cpp b/llvm/lib/Target/AArch64/AArch64MCInstLower.cpp
--- a/llvm/lib/Target/AArch64/AArch64MCInstLower.cpp
+++ b/llvm/lib/Target/AArch64/AArch64MCInstLower.cpp
@@ -36,8 +36,11 @@
 
 MCSymbol *
 AArch64MCInstLower::GetGlobalAddressSymbol(const MachineOperand &MO) const {
-  const GlobalValue *GV = MO.getGlobal();
-  unsigned TargetFlags = MO.getTargetFlags();
+  return GetGlobalValueSymbol(MO.getGlobal(), MO.getTargetFlags());
+}
+
+MCSymbol *AArch64MCInstLower::GetGlobalValueSymbol(const GlobalValue *GV,
+                                                   unsigned TargetFlags) const {
   const Triple &TheTriple = Printer.TM.getTargetTriple();
   if (!TheTriple.isOSBinFormatCOFF())
     return Printer.getSymbolPreferLocal(*GV);
@@ -46,14 +49,54 @@
          "Windows is the only supported COFF target");
 
   bool IsIndirect =
-      (TargetFlags & (AArch64II::MO_DLLIMPORT | AArch64II::MO_DLLIMPORTAUX |
-                      AArch64II::MO_COFFSTUB));
-  if (!IsIndirect)
+      (TargetFlags & (AArch64II::MO_DLLIMPORT | AArch64II::MO_COFFSTUB));
+  if (!IsIndirect) {
+    // For ARM64EC, symbol lookup in the MSVC linker has limited awareness
+    // of ARM64EC mangling ("#"/"$$h"). So object files need to refer to both
+    // the mangled and unmangled names of ARM64EC symbols, even if they aren't
+    // actually used by any relocations. Emit the necessary references here.
+    if (!TheTriple.isWindowsArm64EC() || !isa<Function>(GV) ||
+        !GV->hasExternalLinkage())
+      return Printer.getSymbol(GV);
+
+    StringRef Name = Printer.getSymbol(GV)->getName();
+    // Don't mangle ARM64EC runtime functions.
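+    // These helpers are referenced by the generated thunks under exactly
+    // these names, so they must not get the "#" mangling.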
+    static constexpr StringLiteral ExcludedFns[] = {
+        "__os_arm64x_check_icall_cfg", "__os_arm64x_dispatch_call_no_redirect",
+        "__os_arm64x_check_icall"};
+    if (is_contained(ExcludedFns, Name))
+      return Printer.getSymbol(GV);
+
+    if (std::optional<std::string> MangledName =
+            getArm64ECMangledFunctionName(Name.str())) {
+      MCSymbol *MangledSym = Ctx.getOrCreateSymbol(MangledName.value());
+      if (!cast<Function>(GV)->hasMetadata("arm64ec_hasguestexit")) {
+        Printer.OutStreamer->emitSymbolAttribute(Printer.getSymbol(GV),
+                                                 MCSA_WeakAntiDep);
+        Printer.OutStreamer->emitAssignment(
+            Printer.getSymbol(GV),
+            MCSymbolRefExpr::create(MangledSym, MCSymbolRefExpr::VK_WEAKREF,
+                                    Ctx));
+        Printer.OutStreamer->emitSymbolAttribute(MangledSym, MCSA_WeakAntiDep);
+        Printer.OutStreamer->emitAssignment(
+            MangledSym,
+            MCSymbolRefExpr::create(Printer.getSymbol(GV),
+                                    MCSymbolRefExpr::VK_WEAKREF, Ctx));
+      }
+
+      if (TargetFlags & AArch64II::MO_ARM64EC_CALLMANGLE)
+        return MangledSym;
+    }
+
+    return Printer.getSymbol(GV);
+  }
 
   SmallString<128> Name;
-  if (TargetFlags & AArch64II::MO_DLLIMPORTAUX) {
+
+  if ((TargetFlags & AArch64II::MO_DLLIMPORT) &&
+      TheTriple.isWindowsArm64EC() &&
+      !(TargetFlags & AArch64II::MO_ARM64EC_CALLMANGLE) &&
+      isa<Function>(GV)) {
     // __imp_aux is specific to arm64EC; it represents the actual address of
     // an imported function without any thunks.
     //
diff --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp b/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
--- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
@@ -78,6 +78,9 @@
   if (MF->getFunction().getCallingConv() == CallingConv::AnyReg)
     return CSR_AArch64_AllRegs_SaveList;
 
+  if (MF->getFunction().getCallingConv() == CallingConv::ARM64EC_Thunk_X64)
+    return CSR_Win_AArch64_Arm64EC_Thunk_SaveList;
+
   // Darwin has its own CSR_AArch64_AAPCS_SaveList, which means most CSR save
   // lists depending on that will need to have their Darwin variant as well.
   if (MF->getSubtarget<AArch64Subtarget>().isTargetDarwin())
diff --git a/llvm/lib/Target/AArch64/AArch64Subtarget.h b/llvm/lib/Target/AArch64/AArch64Subtarget.h
--- a/llvm/lib/Target/AArch64/AArch64Subtarget.h
+++ b/llvm/lib/Target/AArch64/AArch64Subtarget.h
@@ -423,13 +423,13 @@
 
   const char* getChkStkName() const {
     if (isWindowsArm64EC())
-      return "__chkstk_arm64ec";
+      return "#__chkstk_arm64ec";
     return "__chkstk";
   }
 
   const char* getSecurityCheckCookieName() const {
     if (isWindowsArm64EC())
-      return "__security_check_cookie_arm64ec";
+      return "#__security_check_cookie_arm64ec";
     return "__security_check_cookie";
   }
 };
diff --git a/llvm/lib/Target/AArch64/AArch64Subtarget.cpp b/llvm/lib/Target/AArch64/AArch64Subtarget.cpp
--- a/llvm/lib/Target/AArch64/AArch64Subtarget.cpp
+++ b/llvm/lib/Target/AArch64/AArch64Subtarget.cpp
@@ -376,8 +376,6 @@
   if (!TM.shouldAssumeDSOLocal(*GV->getParent(), GV)) {
     if (GV->hasDLLImportStorageClass()) {
-      if (isWindowsArm64EC() && GV->getValueType()->isFunctionTy())
-        return AArch64II::MO_GOT | AArch64II::MO_DLLIMPORTAUX;
       return AArch64II::MO_GOT | AArch64II::MO_DLLIMPORT;
     }
     if (getTargetTriple().isOSWindows())
@@ -417,11 +415,18 @@
     return AArch64II::MO_GOT;
 
   if (getTargetTriple().isOSWindows()) {
-    if (isWindowsArm64EC() && GV->getValueType()->isFunctionTy() &&
-        GV->hasDLLImportStorageClass()) {
-      // On Arm64EC, if we're calling a function directly, use MO_DLLIMPORT,
-      // not MO_DLLIMPORTAUX.
-      return AArch64II::MO_GOT | AArch64II::MO_DLLIMPORT;
+    if (isWindowsArm64EC() && GV->getValueType()->isFunctionTy()) {
+      if (GV->hasDLLImportStorageClass()) {
+        // On Arm64EC, if we're calling a symbol from the import table
+        // directly, use MO_ARM64EC_CALLMANGLE.
+        return AArch64II::MO_GOT | AArch64II::MO_DLLIMPORT |
+               AArch64II::MO_ARM64EC_CALLMANGLE;
+      }
+      if (GV->hasExternalLinkage()) {
+        // If we're calling a symbol directly, use the mangled form in the
+        // call instruction.
+        return AArch64II::MO_ARM64EC_CALLMANGLE;
+      }
     }

     // Use ClassifyGlobalReference for setting MO_DLLIMPORT/MO_COFFSTUB.
diff --git a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
--- a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
@@ -619,8 +619,12 @@
     addPass(createSMEABIPass());

   // Add Control Flow Guard checks.
-  if (TM->getTargetTriple().isOSWindows())
-    addPass(createCFGuardCheckPass());
+  if (TM->getTargetTriple().isOSWindows()) {
+    if (TM->getTargetTriple().isWindowsArm64EC())
+      addPass(createAArch64Arm64ECCallLoweringPass());
+    else
+      addPass(createCFGuardCheckPass());
+  }

   if (TM->Options.JMCInstrument)
     addPass(createJMCInstrumenterPass());
diff --git a/llvm/lib/Target/AArch64/CMakeLists.txt b/llvm/lib/Target/AArch64/CMakeLists.txt
--- a/llvm/lib/Target/AArch64/CMakeLists.txt
+++ b/llvm/lib/Target/AArch64/CMakeLists.txt
@@ -42,6 +42,7 @@
   GISel/AArch64RegisterBankInfo.cpp
   AArch64A57FPLoadBalancing.cpp
   AArch64AdvSIMDScalarPass.cpp
+  AArch64Arm64ECCallLowering.cpp
   AArch64AsmPrinter.cpp
   AArch64BranchTargets.cpp
   AArch64CallingConvention.cpp
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp b/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp
--- a/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp
@@ -634,7 +634,18 @@
   MachineRegisterInfo &MRI = MF.getRegInfo();
   auto &DL = F.getParent()->getDataLayout();
   auto &Subtarget = MF.getSubtarget<AArch64Subtarget>();
-  // TODO: Support Arm64EC
+
+  // Arm64EC has extra requirements for varargs calls which are only implemented
+  // in SelectionDAG; bail out for now.
+  if (F.isVarArg() && Subtarget.isWindowsArm64EC())
+    return false;
+
+  // Arm64EC thunks have a special calling convention which is only implemented
+  // in SelectionDAG; bail out for now.
+  if (F.getCallingConv() == CallingConv::ARM64EC_Thunk_Native ||
+      F.getCallingConv() == CallingConv::ARM64EC_Thunk_X64)
+    return false;
+
   bool IsWin64 = Subtarget.isCallingConvWin64(F.getCallingConv()) &&
                  !Subtarget.isWindowsArm64EC();
   SmallVector<ArgInfo, 8> SplitArgs;
@@ -1199,7 +1210,16 @@
   const AArch64Subtarget &Subtarget = MF.getSubtarget<AArch64Subtarget>();

   // Arm64EC has extra requirements for varargs calls; bail out for now.
-  if (Info.IsVarArg && Subtarget.isWindowsArm64EC())
+  //
+  // Arm64EC has special mangling rules for calls; bail out on all calls for
+  // now.
+  if (Subtarget.isWindowsArm64EC())
+    return false;
+
+  // Arm64EC thunks have a special calling convention which is only implemented
+  // in SelectionDAG; bail out for now.
+  if (Info.CallConv == CallingConv::ARM64EC_Thunk_Native ||
+      Info.CallConv == CallingConv::ARM64EC_Thunk_X64)
     return false;

   SmallVector<ArgInfo, 8> OutArgs;
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h b/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h
--- a/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h
@@ -248,6 +248,34 @@
   return false;
 }

+static inline std::optional<std::string>
+getArm64ECMangledFunctionName(std::string Name) {
+  bool IsCppFn = Name[0] == '?';
+  if (IsCppFn && Name.find("$$h") != std::string::npos)
+    return std::nullopt;
+  if (!IsCppFn && Name[0] == '#')
+    return std::nullopt;
+
+  StringRef Prefix = "$$h";
+  size_t InsertIdx = 0;
+  if (IsCppFn) {
+    InsertIdx = Name.find("@@");
+    size_t ThreeAtSignsIdx = Name.find("@@@");
+    if (InsertIdx != std::string::npos && InsertIdx != ThreeAtSignsIdx) {
+      InsertIdx += 2;
+    } else {
+      InsertIdx = Name.find("@");
+      if (InsertIdx != std::string::npos)
+        InsertIdx++;
+    }
+  } else {
+    Prefix = "#";
+  }
+
+  Name.insert(Name.begin() + InsertIdx, Prefix.begin(), Prefix.end());
+  return std::optional<std::string>(Name);
+}
+
 namespace AArch64CC {

 // The CondCodes constants map directly to the 4-bit encoding of the condition
@@ -795,12 +823,11 @@
   /// an LDG instruction to obtain the tag value.
   MO_TAGGED = 0x400,

-  /// MO_DLLIMPORTAUX - Symbol refers to "auxilliary" import stub. On
-  /// Arm64EC, there are two kinds of import stubs used for DLL import of
-  /// functions: MO_DLLIMPORT refers to natively callable Arm64 code, and
-  /// MO_DLLIMPORTAUX refers to the original address which can be compared
-  /// for equality.
-  MO_DLLIMPORTAUX = 0x800,
+  /// MO_ARM64EC_CALLMANGLE - Operand refers to the Arm64EC-mangled version
+  /// of a symbol, not the original. For dllimport symbols, this means it
+  /// uses "__imp_aux". For other symbols, this means it uses the mangled
+  /// ("#" prefix for C) name.
+  MO_ARM64EC_CALLMANGLE = 0x800,
 };
 } // end namespace AArch64II
diff --git a/llvm/test/CodeGen/AArch64/arm64ec-dllimport.ll b/llvm/test/CodeGen/AArch64/arm64ec-dllimport.ll
--- a/llvm/test/CodeGen/AArch64/arm64ec-dllimport.ll
+++ b/llvm/test/CodeGen/AArch64/arm64ec-dllimport.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc -mtriple=arm64ec-pc-windows-msvc < %s | FileCheck %s
+; RUN: llc -mtriple=arm64ec-pc-windows-msvc -arm64ec-generate-thunks=false < %s | FileCheck %s

 @a = external dllimport global i32
 declare dllimport void @b()
diff --git a/llvm/test/CodeGen/AArch64/arm64ec-reservedregs.ll b/llvm/test/CodeGen/AArch64/arm64ec-reservedregs.ll
--- a/llvm/test/CodeGen/AArch64/arm64ec-reservedregs.ll
+++ b/llvm/test/CodeGen/AArch64/arm64ec-reservedregs.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc -mtriple=arm64ec-pc-windows-msvc < %s | FileCheck %s
+; RUN: llc -mtriple=arm64ec-pc-windows-msvc -arm64ec-generate-thunks=false < %s | FileCheck %s

 ; Make sure we're reserving all the registers that are supposed to be
 ; reserved. Integer regs x13, x15, x23, x24, x28. Float regs v16-v31.
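The mangling rule added to AArch64BaseInfo.h above is easiest to see on concrete names. Below is a standalone sketch of the same logic with the LLVM types swapped for standard-library ones so it compiles on its own; the helper name arm64ECMangle, the main() driver, and the sample symbols are illustrative only, not part of the patch.

#include <cassert>
#include <optional>
#include <string>

// Standalone re-implementation of getArm64ECMangledFunctionName, with
// StringRef replaced by std::string. Already-mangled names are rejected;
// C++ names without any '@' are not handled, as in the original.
static std::optional<std::string> arm64ECMangle(std::string Name) {
  bool IsCppFn = Name[0] == '?';
  if (IsCppFn && Name.find("$$h") != std::string::npos)
    return std::nullopt;
  if (!IsCppFn && Name[0] == '#')
    return std::nullopt;

  std::string Prefix = "$$h";
  size_t InsertIdx = 0;
  if (IsCppFn) {
    // For MSVC C++ names, "$$h" is spliced in after the "@@" that ends the
    // qualified name (skipping "@@@", which means the name itself ends in
    // '@'); otherwise after the first '@'.
    InsertIdx = Name.find("@@");
    size_t ThreeAtSignsIdx = Name.find("@@@");
    if (InsertIdx != std::string::npos && InsertIdx != ThreeAtSignsIdx) {
      InsertIdx += 2;
    } else {
      InsertIdx = Name.find("@");
      if (InsertIdx != std::string::npos)
        InsertIdx++;
    }
  } else {
    Prefix = "#"; // Plain C names just get a '#' prefix.
  }

  Name.insert(InsertIdx, Prefix);
  return Name;
}

int main() {
  assert(arm64ECMangle("memcpy") == "#memcpy");
  assert(arm64ECMangle("__chkstk_arm64ec") == "#__chkstk_arm64ec");
  assert(arm64ECMangle("?foo@@YAHXZ") == "?foo@@$$hYAHXZ");
  assert(!arm64ECMangle("#memcpy"));        // already mangled
  assert(!arm64ECMangle("?foo@@$$hYAHXZ")); // already mangled
}

This is also why the test updates below expect call targets such as "#varargs_callee", "#_Z7CapturePi" (an Itanium-mangled name, which does not start with '?' and is therefore treated like a C name), and "#__chkstk_arm64ec".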
diff --git a/llvm/test/CodeGen/AArch64/arm64ec-varargs.ll b/llvm/test/CodeGen/AArch64/arm64ec-varargs.ll
--- a/llvm/test/CodeGen/AArch64/arm64ec-varargs.ll
+++ b/llvm/test/CodeGen/AArch64/arm64ec-varargs.ll
@@ -1,6 +1,6 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc -mtriple=arm64ec-pc-windows-msvc < %s | FileCheck %s
-; RUN: llc -mtriple=arm64ec-pc-windows-msvc < %s -global-isel=1 -global-isel-abort=0 | FileCheck %s
+; RUN: llc -mtriple=arm64ec-pc-windows-msvc -arm64ec-generate-thunks=false < %s | FileCheck %s
+; RUN: llc -mtriple=arm64ec-pc-windows-msvc -arm64ec-generate-thunks=false < %s -global-isel=1 -global-isel-abort=0 | FileCheck %s

 define void @varargs_callee(double %x, ...) nounwind {
 ; CHECK-LABEL: varargs_callee:
@@ -44,7 +44,11 @@
 ; CHECK-NEXT:    stp xzr, x30, [sp, #24] // 8-byte Folded Spill
 ; CHECK-NEXT:    stp x9, x8, [sp]
 ; CHECK-NEXT:    str xzr, [sp, #16]
-; CHECK-NEXT:    bl varargs_callee
+; CHECK-NEXT:    .weak_anti_dep varargs_callee
+; CHECK-NEXT:    .set varargs_callee, "#varargs_callee"@WEAKREF
+; CHECK-NEXT:    .weak_anti_dep "#varargs_callee"
+; CHECK-NEXT:    .set "#varargs_callee", varargs_callee@WEAKREF
+; CHECK-NEXT:    bl "#varargs_callee"
 ; CHECK-NEXT:    ldr x30, [sp, #32] // 8-byte Folded Reload
 ; CHECK-NEXT:    add sp, sp, #48
 ; CHECK-NEXT:    ret
@@ -81,7 +85,11 @@
 ; CHECK-NEXT:    str x30, [sp, #48] // 8-byte Folded Spill
 ; CHECK-NEXT:    stp x9, x8, [sp]
 ; CHECK-NEXT:    stp q0, q0, [sp, #16]
-; CHECK-NEXT:    bl varargs_many_argscallee
+; CHECK-NEXT:    .weak_anti_dep varargs_many_argscallee
+; CHECK-NEXT:    .set varargs_many_argscallee, "#varargs_many_argscallee"@WEAKREF
+; CHECK-NEXT:    .weak_anti_dep "#varargs_many_argscallee"
+; CHECK-NEXT:    .set "#varargs_many_argscallee", varargs_many_argscallee@WEAKREF
+; CHECK-NEXT:    bl "#varargs_many_argscallee"
 ; CHECK-NEXT:    ldr x30, [sp, #48] // 8-byte Folded Reload
 ; CHECK-NEXT:    add sp, sp, #64
 ; CHECK-NEXT:    ret
diff --git a/llvm/test/CodeGen/AArch64/stack-protector-target.ll b/llvm/test/CodeGen/AArch64/stack-protector-target.ll
--- a/llvm/test/CodeGen/AArch64/stack-protector-target.ll
+++ b/llvm/test/CodeGen/AArch64/stack-protector-target.ll
@@ -39,6 +39,6 @@
 ; WINDOWS-ARM64EC: adrp x8, __security_cookie
 ; WINDOWS-ARM64EC: ldr x8, [x8, :lo12:__security_cookie]
 ; WINDOWS-ARM64EC: str x8, [sp, #8]
-; WINDOWS-ARM64EC: bl _Z7CapturePi
+; WINDOWS-ARM64EC: bl "#_Z7CapturePi"
 ; WINDOWS-ARM64EC: ldr x0, [sp, #8]
-; WINDOWS-ARM64EC: bl __security_check_cookie_arm64ec
+; WINDOWS-ARM64EC: bl "#__security_check_cookie_arm64ec"
diff --git a/llvm/test/CodeGen/AArch64/win-alloca.ll b/llvm/test/CodeGen/AArch64/win-alloca.ll
--- a/llvm/test/CodeGen/AArch64/win-alloca.ll
+++ b/llvm/test/CodeGen/AArch64/win-alloca.ll
@@ -21,4 +21,4 @@
 ; CHECK-OPT: sub [[REG3:x[0-9]+]], sp, x15, lsl #4
 ; CHECK-OPT: mov sp, [[REG3]]
 ; CHECK: bl func2
-; CHECK-ARM64EC: bl __chkstk_arm64ec
+; CHECK-ARM64EC: bl "#__chkstk_arm64ec"
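The .weak_anti_dep/.set pairs in the arm64ec-varargs.ll checks above are the printed form of the two-way anti-dependency aliases emitted by GetGlobalValueSymbol: each of the unmangled and mangled names weakly resolves to the other, so a relocation against either spelling links as long as one of the two symbols is actually defined. A condensed sketch of that emission using the same MC APIs follows; the free-standing helper emitAntiDepAliasPair is invented here for illustration and is not part of the patch.

#include "llvm/MC/MCContext.h"
#include "llvm/MC/MCExpr.h"
#include "llvm/MC/MCStreamer.h"

using namespace llvm;

// Tie an unmangled ARM64EC symbol and its "#"/"$$h"-mangled twin together.
// Each direction emits ".weak_anti_dep <sym>" plus ".set <sym>, <other>@WEAKREF",
// matching the directives checked in arm64ec-varargs.ll.
static void emitAntiDepAliasPair(MCStreamer &S, MCContext &Ctx,
                                 MCSymbol *Unmangled, MCSymbol *Mangled) {
  S.emitSymbolAttribute(Unmangled, MCSA_WeakAntiDep);
  S.emitAssignment(Unmangled, MCSymbolRefExpr::create(
                                  Mangled, MCSymbolRefExpr::VK_WEAKREF, Ctx));
  S.emitSymbolAttribute(Mangled, MCSA_WeakAntiDep);
  S.emitAssignment(Mangled, MCSymbolRefExpr::create(
                                Unmangled, MCSymbolRefExpr::VK_WEAKREF, Ctx));
}

Using anti-dependency weak references rather than plain weak aliases keeps either name from pinning the other's definition into the link, which matches the limited awareness of ARM64EC mangling in the MSVC linker described in the AArch64MCInstLower change above.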