diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -5037,6 +5037,7 @@
   unsigned NumArgs = Ins.size();
   Function::const_arg_iterator CurOrigArg = MF.getFunction().arg_begin();
   unsigned CurArgIdx = 0;
+  SmallVector<unsigned> BoolArgIdx;
   for (unsigned i = 0; i != NumArgs; ++i) {
     MVT ValVT = Ins[i].VT;
     if (Ins[i].isOrigArg()) {
@@ -5052,6 +5053,9 @@
         ValVT = MVT::i8;
       else if (ActualMVT == MVT::i16)
         ValVT = MVT::i16;
+
+      if (ActualMVT == MVT::i1)
+        BoolArgIdx.push_back(i);
     }
     bool UseVarArgCC = false;
     if (IsWin64)
@@ -5290,6 +5294,18 @@
     }
   }
 
+  // Record registers that have i1 as the actual type: they are
+  // zero-extended to 8-bits by the caller, so when we forward them to
+  // callees, we don't have to zext them again.
+  for (unsigned I : BoolArgIdx) {
+    SDValue Val = InVals[I];
+    if (Val->getOpcode() == ISD::CopyFromReg) {
+      if (auto *RegVal = dyn_cast<RegisterSDNode>(Val->getOperand(1))) {
+        FuncInfo->getBoolRegParms().push_back(RegVal->getReg());
+      }
+    }
+  }
+
   unsigned StackArgSize = CCInfo.getNextStackOffset();
   bool TailCallOpt = MF.getTarget().Options.GuaranteedTailCallOpt;
   if (DoesCalleeRestoreStack(CallConv, TailCallOpt)) {
@@ -5670,6 +5686,28 @@
          CallCC == CallingConv::Tail || CallCC == CallingConv::SwiftTail;
 }
 
+// If the register is already zero-extended by the caller, we don't
+// have to do this again.
+static bool shouldZExtBoolArg(SDValue Arg,
+                              const AArch64FunctionInfo &FuncInfo) {
+  if (Arg->getOpcode() != ISD::CopyFromReg)
+    return true;
+
+  auto *ArgVal = dyn_cast<RegisterSDNode>(Arg->getOperand(1));
+  if (!ArgVal)
+    return true;
+
+  Register Reg = ArgVal->getReg();
+  for (Register BoolReg : FuncInfo.getBoolRegParms()) {
+    if (Reg == BoolReg) {
+      // The register is an i1 argument of this function. No need to
+      // zero-extend it for the callee.
+      return false;
+    }
+  }
+  return true;
+}
+
 /// LowerCall - Lower a call to a callseq_start + CALL + callseq_end chain,
 /// and add input and output parameter nodes.
 SDValue
@@ -5868,8 +5906,10 @@
     case CCValAssign::AExt:
       if (Outs[i].ArgVT == MVT::i1) {
         // AAPCS requires i1 to be zero-extended to 8-bits by the caller.
-        Arg = DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, Arg);
-        Arg = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i8, Arg);
+        if (shouldZExtBoolArg(Arg, *FuncInfo)) {
+          Arg = DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, Arg);
+          Arg = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i8, Arg);
+        }
       }
       Arg = DAG.getNode(ISD::ANY_EXTEND, DL, VA.getLocVT(), Arg);
       break;
diff --git a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
--- a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
@@ -134,6 +134,10 @@
   /// that must be forwarded to every musttail call.
   SmallVector<ForwardedRegister, 1> ForwardedMustTailRegParms;
 
+  /// Parameters that originally have i1 type - they are zero-extended
+  /// to 8-bits by the caller.
+  SmallVector<Register> BoolRegParms;
+
   /// FrameIndex for the tagged base pointer.
   Optional<int> TaggedBasePointerIndex;
 
@@ -372,6 +376,12 @@
     return ForwardedMustTailRegParms;
   }
 
+  SmallVectorImpl<Register> &getBoolRegParms() { return BoolRegParms; }
+
+  const SmallVectorImpl<Register> &getBoolRegParms() const {
+    return BoolRegParms;
+  }
+
   Optional<int> getTaggedBasePointerIndex() const {
     return TaggedBasePointerIndex;
   }
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp b/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp
--- a/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp
@@ -588,6 +588,18 @@
   // in this function later.
   FuncInfo->setBytesInStackArgArea(StackOffset);
 
+  // Record registers that have i1 as the actual type: they are
+  // zero-extended to 8-bits by the caller, so when we forward them to
+  // callees, we don't have to zext them again.
+  SmallVectorImpl<Register> &BoolRegs = FuncInfo->getBoolRegParms();
+  for (const ArgInfo &AI : SplitArgs) {
+    if (!AI.Ty->isIntegerTy(1))
+      continue;
+    assert(AI.Regs.size() == 1 &&
+           "Unexpected number of registers used for i1 argument");
+    BoolRegs.push_back(AI.Regs[0]);
+  }
+
   auto &Subtarget = MF.getSubtarget<AArch64Subtarget>();
   if (Subtarget.hasCustomCallingConv())
     Subtarget.getRegisterInfo()->UpdateCustomCalleeSavedRegs(MF);
@@ -1030,6 +1042,34 @@
   return true;
 }
 
+// If the register is already zero-extended by the caller, we don't
+// have to do this again.
+static bool shouldZExtBoolArg(const CallLowering::ArgInfo &AI,
+                              const AArch64FunctionInfo &FuncInfo,
+                              const MachineRegisterInfo &MRI) {
+  assert(AI.Regs.size() == 1 &&
+         "Unexpected number of registers used for i1 argument");
+
+  MachineInstr *ArgDef = MRI.getVRegDef(AI.Regs[0]);
+  if (!ArgDef || ArgDef->getOpcode() != TargetOpcode::G_TRUNC)
+    return true;
+
+  MachineOperand &Op = ArgDef->getOperand(1);
+  if (!Op.isReg())
+    return true;
+
+  Register OrigReg = Op.getReg();
+  for (const Register &BoolReg : FuncInfo.getBoolRegParms()) {
+    if (BoolReg == OrigReg) {
+      // The register is an i1 argument of this function. No need to
+      // zero-extend it for the callee.
+      return false;
+    }
+  }
+
+  return true;
+}
+
 bool AArch64CallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
                                     CallLoweringInfo &Info) const {
   MachineFunction &MF = MIRBuilder.getMF();
@@ -1037,13 +1077,16 @@
   MachineRegisterInfo &MRI = MF.getRegInfo();
   auto &DL = F.getParent()->getDataLayout();
   const AArch64TargetLowering &TLI = *getTLI<AArch64TargetLowering>();
+  AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
 
   SmallVector<ArgInfo, 8> OutArgs;
   for (auto &OrigArg : Info.OrigArgs) {
     splitToValueTypes(OrigArg, OutArgs, DL, Info.CallConv);
     // AAPCS requires that we zero-extend i1 to 8 bits by the caller.
-    if (OrigArg.Ty->isIntegerTy(1))
-      OutArgs.back().Flags[0].setZExt();
+    if (OrigArg.Ty->isIntegerTy(1)) {
+      if (shouldZExtBoolArg(OrigArg, *FuncInfo, MRI))
+        OutArgs.back().Flags[0].setZExt();
+    }
   }
 
   SmallVector<ArgInfo, 8> InArgs;
diff --git a/llvm/test/CodeGen/AArch64/i1-contents.ll b/llvm/test/CodeGen/AArch64/i1-contents.ll
--- a/llvm/test/CodeGen/AArch64/i1-contents.ll
+++ b/llvm/test/CodeGen/AArch64/i1-contents.ll
@@ -1,4 +1,5 @@
 ; RUN: llc -mtriple=aarch64-linux-gnu -o - %s | FileCheck %s
+; RUN: llc -global-isel -mtriple=aarch64-linux-gnu -o - %s | FileCheck %s
 
 %big = type i32
 @var = dso_local global %big 0
@@ -48,6 +49,13 @@
   ret void
 }
 
+define dso_local void @forward_i1_arg(i1 %in) {
+; CHECK-LABEL: forward_i1_arg:
+; CHECK-NOT: and
+; CHECK: bl consume_i1_arg
+  call void @consume_i1_arg(i1 %in)
+  ret void
+}
 
 ;define zeroext i1 @foo(i8 %in) {
 ;  %val = trunc i8 %in to i1