diff --git a/llvm/lib/Target/AArch64/AArch64.h b/llvm/lib/Target/AArch64/AArch64.h
--- a/llvm/lib/Target/AArch64/AArch64.h
+++ b/llvm/lib/Target/AArch64/AArch64.h
@@ -68,6 +68,7 @@
 FunctionPass *createAArch64PostSelectOptimize();
 FunctionPass *createAArch64StackTaggingPass(bool IsOptNone);
 FunctionPass *createAArch64StackTaggingPreRAPass();
+FunctionPass *createAArch64Arm64ECCallLoweringPass();
 
 void initializeAArch64A53Fix835769Pass(PassRegistry&);
 void initializeAArch64A57FPLoadBalancingPass(PassRegistry&);
@@ -101,6 +102,7 @@
 void initializeSVEIntrinsicOptsPass(PassRegistry&);
 void initializeAArch64StackTaggingPass(PassRegistry&);
 void initializeAArch64StackTaggingPreRAPass(PassRegistry&);
+void initializeAArch64Arm64ECCallLoweringPass(PassRegistry &);
 
 } // end namespace llvm
 
 #endif
diff --git a/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp b/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp
@@ -0,0 +1,280 @@
+//===-- AArch64Arm64ECCallLowering.cpp - Lower Arm64EC calls ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file contains the IR transform to lower external or indirect calls for
+/// the ARM64EC calling convention. Such calls must go through the runtime, so
+/// we can translate the calling convention for calls into the emulator.
+///
+/// This subsumes Control Flow Guard handling.
+///
+//===----------------------------------------------------------------------===//
+
+#include "AArch64.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/ADT/Triple.h"
+#include "llvm/IR/CallingConv.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+
+using namespace llvm;
+
+using OperandBundleDef = OperandBundleDefT<Value *>;
+
+#define DEBUG_TYPE "arm64eccalllowering"
+
+STATISTIC(Arm64ECCallsLowered, "Number of Arm64EC calls lowered");
+
+namespace {
+
+class AArch64Arm64ECCallLowering : public FunctionPass {
+public:
+  static char ID;
+  AArch64Arm64ECCallLowering() : FunctionPass(ID) {
+    initializeAArch64Arm64ECCallLoweringPass(*PassRegistry::getPassRegistry());
+  }
+
+  Function *buildExitThunk(CallBase *CB);
+  void lowerCall(CallBase *CB);
+  bool doInitialization(Module &M) override;
+  bool runOnFunction(Function &F) override;
+
+private:
+  int cfguard_module_flag = 0;
+  FunctionType *GuardFnType = nullptr;
+  PointerType *GuardFnPtrType = nullptr;
+  Constant *GuardFnCFGlobal = nullptr;
+  Constant *GuardFnGlobal = nullptr;
+  Module *M = nullptr;
+};
+
+} // end anonymous namespace
+
+Function *AArch64Arm64ECCallLowering::buildExitThunk(CallBase *CB) {
+  Type *RetTy = CB->getFunctionType()->getReturnType();
+  SmallVector<Type *> DefArgTypes;
+  // The first argument to a thunk is the called function, stored in x9.
+  // (Normally, we won't explicitly refer to this in the assembly; it just
+  // gets passed on by the call.)
+  DefArgTypes.push_back(Type::getInt8PtrTy(M->getContext()));
+  for (unsigned i = 0; i < CB->arg_size(); ++i) {
+    DefArgTypes.push_back(CB->getArgOperand(i)->getType());
+  }
+  FunctionType *Ty = FunctionType::get(RetTy, DefArgTypes, false);
+  Function *F =
+      Function::Create(Ty, GlobalValue::InternalLinkage, 0, "thunk", M);
+  F->setCallingConv(CallingConv::ARM64EC_Thunk_Native);
+  // Copy MSVC, and always set up a frame pointer. (Maybe this isn't
+  // necessary.)
+  F->addFnAttr("frame-pointer", "all");
+  // Only copy sret from the first argument. For C++ instance methods, clang
+  // can stick an sret marking on a later argument, but it doesn't actually
+  // affect the ABI, so we can omit it. This avoids triggering a verifier
+  // assertion.
+  if (CB->arg_size() > 0) {
+    auto Attr = CB->getParamAttr(0, Attribute::StructRet);
+    if (Attr.isValid())
+      F->addParamAttr(1, Attr);
+  }
+  // FIXME: Copy anything other than sret? Shouldn't be necessary for normal
+  // C ABI, but might show up in other cases.
+  BasicBlock *BB = BasicBlock::Create(M->getContext(), "", F);
+  IRBuilder<> IRB(BB);
+  PointerType *DispatchPtrTy =
+      FunctionType::get(IRB.getVoidTy(), false)->getPointerTo(0);
+  Value *CalleePtr = M->getOrInsertGlobal(
+      "__os_arm64x_dispatch_call_no_redirect", DispatchPtrTy);
+  Value *Callee = IRB.CreateLoad(DispatchPtrTy, CalleePtr);
+  auto &DL = M->getDataLayout();
+  SmallVector<Value *> Args;
+  SmallVector<Type *> ArgTypes;
+
+  // Pass the called function in x9.
+  Args.push_back(F->arg_begin());
+  ArgTypes.push_back(Args.back()->getType());
+
+  Type *X64RetType = RetTy;
+  if (RetTy->isArrayTy() || RetTy->isStructTy()) {
+    // If the return type is an array or struct, translate it. Values of size
+    // 8 or less go into RAX; bigger values go into memory, and we pass a
+    // pointer.
+    if (DL.getTypeStoreSize(RetTy) > 8) {
+      Args.push_back(IRB.CreateAlloca(RetTy));
+      ArgTypes.push_back(Args.back()->getType());
+      X64RetType = IRB.getVoidTy();
+    } else {
+      X64RetType = IRB.getIntNTy(DL.getTypeStoreSizeInBits(RetTy));
+    }
+  }
+
+  for (auto &Arg : make_range(F->arg_begin() + 1, F->arg_end())) {
+    // Translate arguments from AArch64 calling convention to x86 calling
+    // convention.
+    //
+    // For simple types, we don't need to do any translation: they're
+    // represented the same way. (Implicit sign extension is not part of
+    // either convention.)
+    //
+    // The big thing we have to worry about is struct types... but
+    // fortunately AArch64 clang is pretty friendly here: the cases that need
+    // translation are always passed as a struct or array. (If we run into
+    // some cases where this doesn't work, we can teach clang to mark it up
+    // with an attribute.)
+    //
+    // The first argument is the called function, stored in x9.
+    if (Arg.getType()->isArrayTy() || Arg.getType()->isStructTy()) {
+      Value *Mem = IRB.CreateAlloca(Arg.getType());
+      IRB.CreateStore(&Arg, Mem);
+      if (DL.getTypeStoreSize(Arg.getType()) <= 8)
+        Args.push_back(IRB.CreateLoad(
+            IRB.getIntNTy(DL.getTypeStoreSizeInBits(Arg.getType())), Mem));
+      else
+        Args.push_back(Mem);
+    } else {
+      Args.push_back(&Arg);
+    }
+    ArgTypes.push_back(Args.back()->getType());
+  }
+  // FIXME: Transfer necessary attributes? sret? anything else?
+  // FIXME: Try to share thunks. This probably involves simplifying the
+  // argument types (translating all integers/pointers to i64, etc.)
+  auto *CallTy = FunctionType::get(X64RetType, ArgTypes, false);
+
+  Callee = IRB.CreateBitCast(Callee, CallTy->getPointerTo(0));
+  CallInst *Call = IRB.CreateCall(CallTy, Callee, Args);
+  Call->setCallingConv(CallingConv::ARM64EC_Thunk_X64);
+
+  Value *RetVal = Call;
+  if (RetTy->isArrayTy() || RetTy->isStructTy()) {
+    // If we rewrote the return type earlier, convert the return value to
+    // the proper type.
+    if (DL.getTypeStoreSize(RetTy) > 8) {
+      RetVal = IRB.CreateLoad(RetTy, Args[1]);
+    } else {
+      Value *CastAlloca = IRB.CreateAlloca(RetTy);
+      IRB.CreateStore(Call, IRB.CreateBitCast(
+                                CastAlloca, Call->getType()->getPointerTo(0)));
+      RetVal = IRB.CreateLoad(RetTy, CastAlloca);
+    }
+  }
+
+  if (RetTy->isVoidTy())
+    IRB.CreateRetVoid();
+  else
+    IRB.CreateRet(RetVal);
+  return F;
+}
+
+void AArch64Arm64ECCallLowering::lowerCall(CallBase *CB) {
+  assert(Triple(CB->getModule()->getTargetTriple()).isOSWindows() &&
+         "Only applicable for Windows targets");
+
+  IRBuilder<> B(CB);
+  Value *CalledOperand = CB->getCalledOperand();
+
+  // If the indirect call is called within catchpad or cleanuppad,
+  // we need to copy "funclet" bundle of the call.
+  SmallVector<OperandBundleDef, 1> Bundles;
+  if (auto Bundle = CB->getOperandBundle(LLVMContext::OB_funclet))
+    Bundles.push_back(OperandBundleDef(*Bundle));
+
+  // Load the global symbol as a pointer to the check function.
+  Value *GuardFn;
+  if (cfguard_module_flag == 2 && !CB->hasFnAttr("guard_nocf"))
+    GuardFn = GuardFnCFGlobal;
+  else
+    GuardFn = GuardFnGlobal;
+  LoadInst *GuardCheckLoad = B.CreateLoad(GuardFnPtrType, GuardFn);
+
+  // Create new call instruction. The CFGuard check should always be a call,
+  // even if the original CallBase is an Invoke or CallBr instruction.
+  Function *Thunk = buildExitThunk(CB);
+  CallInst *GuardCheck =
+      B.CreateCall(GuardFnType, GuardCheckLoad,
+                   {B.CreateBitCast(CalledOperand, B.getInt8PtrTy()),
+                    B.CreateBitCast(Thunk, B.getInt8PtrTy())},
+                   Bundles);
+
+  // Ensure that the first argument is passed in the correct register
+  // (e.g. ECX on 32-bit X86 targets).
+  GuardCheck->setCallingConv(CallingConv::CFGuard_Check);
+
+  Value *GuardRetVal = B.CreateBitCast(GuardCheck, CalledOperand->getType());
+  CB->setCalledOperand(GuardRetVal);
+}
+
+bool AArch64Arm64ECCallLowering::doInitialization(Module &Mod) {
+  M = &Mod;
+
+  // Check if this module has the cfguard flag and read its value.
+  if (auto *MD =
+          mdconst::extract_or_null<ConstantInt>(M->getModuleFlag("cfguard")))
+    cfguard_module_flag = MD->getZExtValue();
+
+  Type *Int8Ptr = Type::getInt8PtrTy(M->getContext());
+  GuardFnType = FunctionType::get(Int8Ptr, {Int8Ptr, Int8Ptr}, false);
+  GuardFnPtrType = PointerType::get(GuardFnType, 0);
+  GuardFnCFGlobal =
+      M->getOrInsertGlobal("__os_arm64x_check_icall_cfg", GuardFnPtrType);
+  GuardFnGlobal =
+      M->getOrInsertGlobal("__os_arm64x_check_icall", GuardFnPtrType);
+  return true;
+}
+
+bool AArch64Arm64ECCallLowering::runOnFunction(Function &F) {
+  SmallVector<CallBase *> IndirectCalls;
+
+  // Iterate over the instructions to find all indirect call/invoke/callbr
+  // instructions. Make a separate list of pointers to indirect
+  // call/invoke/callbr instructions because the original instructions will be
+  // deleted as the checks are added.
+  for (BasicBlock &BB : F.getBasicBlockList()) {
+    for (Instruction &I : BB.getInstList()) {
+      auto *CB = dyn_cast<CallBase>(&I);
+      if (!CB || CB->getCallingConv() == CallingConv::ARM64EC_Thunk_X64 ||
+          CB->isInlineAsm())
+        continue;
+
+      // We need to instrument any call that isn't directly calling an
+      // ARM64 function.
+      //
+      // FIXME: isDSOLocal() doesn't do what we want; even if the symbol is
+      // technically local, automatic dllimport means the function it refers
+      // to might not be.
+      //
+      // FIXME: If a function is dllimport, we can just mark up the symbol
+      // using hybmp$x, and everything just works. If the function is not
+      // marked dllimport, we can still mark up the symbol, but we somehow
+      // need an extra stub to compute the correct callee. Not really
+      // understanding how this works.
+      if (Function *F = CB->getCalledFunction()) {
+        if (F->isDSOLocal() || F->isIntrinsic())
+          continue;
+      }
+
+      IndirectCalls.push_back(CB);
+      ++Arm64ECCallsLowered;
+    }
+  }
+
+  if (IndirectCalls.empty())
+    return false;
+
+  for (CallBase *CB : IndirectCalls)
+    lowerCall(CB);
+
+  return true;
+}
+
+char AArch64Arm64ECCallLowering::ID = 0;
+INITIALIZE_PASS(AArch64Arm64ECCallLowering, "Arm64ECCallLowering",
+                "AArch64Arm64ECCallLowering", false, false)
+
+FunctionPass *llvm::createAArch64Arm64ECCallLoweringPass() {
+  return new AArch64Arm64ECCallLowering;
+}
diff --git a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
--- a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
@@ -229,6 +229,7 @@
   initializeAArch64StackTaggingPass(*PR);
   initializeAArch64StackTaggingPreRAPass(*PR);
   initializeAArch64LowerHomogeneousPrologEpilogPass(*PR);
+  initializeAArch64Arm64ECCallLoweringPass(*PR);
 }
 
 //===----------------------------------------------------------------------===//
@@ -588,8 +589,12 @@
   }
 
   // Add Control Flow Guard checks.
-  if (TM->getTargetTriple().isOSWindows())
-    addPass(createCFGuardCheckPass());
+  if (TM->getTargetTriple().isOSWindows()) {
+    if (TM->getTargetTriple().isWindowsArm64EC())
+      addPass(createAArch64Arm64ECCallLoweringPass());
+    else
+      addPass(createCFGuardCheckPass());
+  }
 
   if (TM->Options.JMCInstrument)
     addPass(createJMCInstrumenterPass());
diff --git a/llvm/lib/Target/AArch64/CMakeLists.txt b/llvm/lib/Target/AArch64/CMakeLists.txt
--- a/llvm/lib/Target/AArch64/CMakeLists.txt
+++ b/llvm/lib/Target/AArch64/CMakeLists.txt
@@ -84,6 +84,7 @@
   AArch64TargetTransformInfo.cpp
   SVEIntrinsicOpts.cpp
   AArch64SIMDInstrOpt.cpp
+  AArch64Arm64ECCallLowering.cpp
 
   DEPENDS
   intrinsics_gen
diff --git a/llvm/test/CodeGen/AArch64/arm64ec-cfg.ll b/llvm/test/CodeGen/AArch64/arm64ec-cfg.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/arm64ec-cfg.ll
@@ -0,0 +1,166 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --include-generated-funcs
+; RUN: llc -mtriple=aarch64-pc-windows-msvc_arm64ec < %s | FileCheck %s
+
+define void @f(ptr %g) {
+entry:
+  call void %g()
+  ret void
+}
+
+define void @f2(ptr %g) {
+entry:
+  call void %g(i32 1, i32 2, i32 3, i32 4, i32 5)
+  ret void
+}
+
+define void @f3(ptr %g) {
+entry:
+  call void %g([4 x float] zeroinitializer)
+  ret void
+}
+
+; CHECK-LABEL: f:
+; CHECK:       .seh_proc f
+; CHECK-NEXT:  // %bb.0: // %entry
+; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    .seh_save_reg_x x30, 16
+; CHECK-NEXT:    .seh_endprologue
+; CHECK-NEXT:    adrp x8, __os_arm64x_check_icall
+; CHECK-NEXT:    adrp x10, thunk
+; CHECK-NEXT:    add x10, x10, :lo12:thunk
+; CHECK-NEXT:    mov x11, x0
+; CHECK-NEXT:    ldr x8, [x8, :lo12:__os_arm64x_check_icall]
+; CHECK-NEXT:    blr x8
+; CHECK-NEXT:    blr x11
+; CHECK-NEXT:    .seh_startepilogue
+; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    .seh_save_reg_x x30, 16
+; CHECK-NEXT:    .seh_endepilogue
+; CHECK-NEXT:    ret
+; CHECK-NEXT:    .seh_endfunclet
+; CHECK-NEXT:    .seh_endproc
+;
+; CHECK-LABEL: f2:
+; CHECK:       .seh_proc f2
+; CHECK-NEXT:  // %bb.0: // %entry
+; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    .seh_save_reg_x x30, 16
+; CHECK-NEXT:    .seh_endprologue
+; CHECK-NEXT:    adrp x8, __os_arm64x_check_icall
+; CHECK-NEXT:    adrp x10, thunk.1
+; CHECK-NEXT:    add x10, x10, :lo12:thunk.1
+; CHECK-NEXT:    mov x11, x0
+; CHECK-NEXT:    ldr x8, [x8, :lo12:__os_arm64x_check_icall]
+; CHECK-NEXT:    blr x8
+; CHECK-NEXT:    mov w0, #1
+; CHECK-NEXT:    mov w1, #2
+; CHECK-NEXT:    mov w2, #3
+; CHECK-NEXT:    mov w3, #4
+; CHECK-NEXT:    mov w4, #5
+; CHECK-NEXT:    blr x11
+; CHECK-NEXT:    .seh_startepilogue
+; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    .seh_save_reg_x x30, 16
+; CHECK-NEXT:    .seh_endepilogue
+; CHECK-NEXT:    ret
+; CHECK-NEXT:    .seh_endfunclet
+; CHECK-NEXT:    .seh_endproc
+;
+; CHECK-LABEL: f3:
+; CHECK:       .seh_proc f3
+; CHECK-NEXT:  // %bb.0: // %entry
+; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    .seh_save_reg_x x30, 16
+; CHECK-NEXT:    .seh_endprologue
+; CHECK-NEXT:    adrp x8, __os_arm64x_check_icall
+; CHECK-NEXT:    adrp x10, thunk.2
+; CHECK-NEXT:    add x10, x10, :lo12:thunk.2
+; CHECK-NEXT:    mov x11, x0
+; CHECK-NEXT:    ldr x8, [x8, :lo12:__os_arm64x_check_icall]
+; CHECK-NEXT:    blr x8
+; CHECK-NEXT:    movi d0, #0000000000000000
+; CHECK-NEXT:    movi d1, #0000000000000000
+; CHECK-NEXT:    movi d2, #0000000000000000
+; CHECK-NEXT:    movi d3, #0000000000000000
+; CHECK-NEXT:    blr x11
+; CHECK-NEXT:    .seh_startepilogue
+; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    .seh_save_reg_x x30, 16
+; CHECK-NEXT:    .seh_endepilogue
+; CHECK-NEXT:    ret
+; CHECK-NEXT:    .seh_endfunclet
+; CHECK-NEXT:    .seh_endproc
+;
+; CHECK-LABEL: thunk:
+; CHECK:       .seh_proc thunk
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    sub sp, sp, #48
+; CHECK-NEXT:    .seh_stackalloc 48
+; CHECK-NEXT:    stp x29, x30, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEXT:    .seh_save_fplr 32
+; CHECK-NEXT:    add x29, sp, #32
+; CHECK-NEXT:    .seh_add_fp 32
+; CHECK-NEXT:    .seh_endprologue
+; CHECK-NEXT:    adrp x8, __os_arm64x_dispatch_call_no_redirect
+; CHECK-NEXT:    ldr x8, [x8, :lo12:__os_arm64x_dispatch_call_no_redirect]
+; CHECK-NEXT:    blr x8
+; CHECK-NEXT:    .seh_startepilogue
+; CHECK-NEXT:    ldp x29, x30, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT:    .seh_save_fplr 32
+; CHECK-NEXT:    add sp, sp, #48
+; CHECK-NEXT:    .seh_stackalloc 48
+; CHECK-NEXT:    .seh_endepilogue
+; CHECK-NEXT:    ret
+; CHECK-NEXT:    .seh_endfunclet
+; CHECK-NEXT:    .seh_endproc
+;
+; CHECK-LABEL: thunk.1:
+; CHECK:       .seh_proc thunk.1
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    sub sp, sp, #64
+; CHECK-NEXT:    .seh_stackalloc 64
+; CHECK-NEXT:    stp x29, x30, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEXT:    .seh_save_fplr 48
+; CHECK-NEXT:    add x29, sp, #48
+; CHECK-NEXT:    .seh_add_fp 48
+; CHECK-NEXT:    .seh_endprologue
+; CHECK-NEXT:    adrp x8, __os_arm64x_dispatch_call_no_redirect
+; CHECK-NEXT:    str w4, [sp, #32]
+; CHECK-NEXT:    ldr x8, [x8, :lo12:__os_arm64x_dispatch_call_no_redirect]
+; CHECK-NEXT:    blr x8
+; CHECK-NEXT:    .seh_startepilogue
+; CHECK-NEXT:    ldp x29, x30, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT:    .seh_save_fplr 48
+; CHECK-NEXT:    add sp, sp, #64
+; CHECK-NEXT:    .seh_stackalloc 64
+; CHECK-NEXT:    .seh_endepilogue
+; CHECK-NEXT:    ret
+; CHECK-NEXT:    .seh_endfunclet
+; CHECK-NEXT:    .seh_endproc
+;
+; CHECK-LABEL: thunk.2:
+; CHECK:       .seh_proc thunk.2
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    sub sp, sp, #64
+; CHECK-NEXT:    .seh_stackalloc 64
+; CHECK-NEXT:    stp x29, x30, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEXT:    .seh_save_fplr 48
+; CHECK-NEXT:    add x29, sp, #48
+; CHECK-NEXT:    .seh_add_fp 48
+; CHECK-NEXT:    .seh_endprologue
+; CHECK-NEXT:    adrp x8, __os_arm64x_dispatch_call_no_redirect
+; CHECK-NEXT:    sub x0, x29, #16
+; CHECK-NEXT:    stp s1, s2, [x29, #-12]
+; CHECK-NEXT:    stur s0, [x29, #-16]
+; CHECK-NEXT:    ldr x8, [x8, :lo12:__os_arm64x_dispatch_call_no_redirect]
+; CHECK-NEXT:    stur s3, [x29, #-4]
+; CHECK-NEXT:    blr x8
+; CHECK-NEXT:    .seh_startepilogue
+; CHECK-NEXT:    ldp x29, x30, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT:    .seh_save_fplr 48
+; CHECK-NEXT:    add sp, sp, #64
+; CHECK-NEXT:    .seh_stackalloc 64
+; CHECK-NEXT:    .seh_endepilogue
+; CHECK-NEXT:    ret
+; CHECK-NEXT:    .seh_endfunclet
+; CHECK-NEXT:    .seh_endproc