diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.h b/llvm/lib/Target/AArch64/AArch64InstrInfo.h
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.h
@@ -334,7 +334,12 @@
                              MachineBasicBlock *TBB,
                              ArrayRef<MachineOperand> Cond) const;
   bool substituteCmpToZero(MachineInstr &CmpInstr, unsigned SrcReg,
-                           const MachineRegisterInfo *MRI) const;
+                           const MachineRegisterInfo &MRI) const;
+  /// Try to erase \p CmpInstr when it compares \p SrcReg, defined by
+  /// 'CSINC %vreg, zr, zr, <cc>', with #0 or #1 (\p CmpValue). NZCV users may
+  /// have their condition inverted. \returns true if \p CmpInstr was erased.
+  bool removeCmpToZeroOrOne(MachineInstr &CmpInstr, unsigned SrcReg,
+                            int CmpValue, const MachineRegisterInfo &MRI) const;
 
   /// Returns an unused general-purpose register which can be used for
   /// constructing an outlined call if one exists. Returns 0 otherwise.
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -1463,14 +1463,16 @@
   // FIXME:CmpValue has already been converted to 0 or 1 in analyzeCompare
   // function.
   assert((CmpValue == 0 || CmpValue == 1) && "CmpValue must be 0 or 1!");
-  if (CmpValue != 0 || SrcReg2 != 0)
+  if (SrcReg2 != 0)
     return false;
 
   // CmpInstr is a Compare instruction if destination register is not used.
   if (!MRI->use_nodbg_empty(CmpInstr.getOperand(0).getReg()))
     return false;
 
-  return substituteCmpToZero(CmpInstr, SrcReg, MRI);
+  if (!CmpValue && substituteCmpToZero(CmpInstr, SrcReg, *MRI))
+    return true;
+  return removeCmpToZeroOrOne(CmpInstr, SrcReg, CmpValue, *MRI);
 }
 
 /// Get opcode of S version of Instr.
@@ -1524,13 +1526,44 @@
 }
 
 /// Check if AArch64::NZCV should be alive in successors of MBB.
-static bool areCFlagsAliveInSuccessors(MachineBasicBlock *MBB) {
+static bool areCFlagsAliveInSuccessors(const MachineBasicBlock *MBB) {
   for (auto *BB : MBB->successors())
     if (BB->isLiveIn(AArch64::NZCV))
       return true;
   return false;
 }
 
+/// \returns The index of the condition-code operand of \p Instr if it is Bcc
+/// or a CSEL/CSINC/CSINV/CSNEG/FCSEL variant handled here, and -1 otherwise.
+static int
+findCondCodeUseOperandIdxForBranchOrSelect(const MachineInstr &Instr) {
+  switch (Instr.getOpcode()) {
+  default:
+    return -1;
+
+  case AArch64::Bcc: {
+    int Idx = Instr.findRegisterUseOperandIdx(AArch64::NZCV);
+    assert(Idx >= 2);
+    return Idx - 2;
+  }
+
+  case AArch64::CSINVWr:
+  case AArch64::CSINVXr:
+  case AArch64::CSINCWr:
+  case AArch64::CSINCXr:
+  case AArch64::CSELWr:
+  case AArch64::CSELXr:
+  case AArch64::CSNEGWr:
+  case AArch64::CSNEGXr:
+  case AArch64::FCSELSrrr:
+  case AArch64::FCSELDrrr: {
+    int Idx = Instr.findRegisterUseOperandIdx(AArch64::NZCV);
+    assert(Idx >= 1);
+    return Idx - 1;
+  }
+  }
+}
+
 namespace {
 
 struct UsedNZCV {
@@ -1556,31 +1589,10 @@
 /// Returns AArch64CC::Invalid if either the instruction does not use condition
 /// codes or we don't optimize CmpInstr in the presence of such instructions.
 static AArch64CC::CondCode findCondCodeUsedByInstr(const MachineInstr &Instr) {
-  switch (Instr.getOpcode()) {
-  default:
-    return AArch64CC::Invalid;
-
-  case AArch64::Bcc: {
-    int Idx = Instr.findRegisterUseOperandIdx(AArch64::NZCV);
-    assert(Idx >= 2);
-    return static_cast<AArch64CC::CondCode>(Instr.getOperand(Idx - 2).getImm());
-  }
-
-  case AArch64::CSINVWr:
-  case AArch64::CSINVXr:
-  case AArch64::CSINCWr:
-  case AArch64::CSINCXr:
-  case AArch64::CSELWr:
-  case AArch64::CSELXr:
-  case AArch64::CSNEGWr:
-  case AArch64::CSNEGXr:
-  case AArch64::FCSELSrrr:
-  case AArch64::FCSELDrrr: {
-    int Idx = Instr.findRegisterUseOperandIdx(AArch64::NZCV);
-    assert(Idx >= 1);
-    return static_cast<AArch64CC::CondCode>(Instr.getOperand(Idx - 1).getImm());
-  }
-  }
+  int CCIdx = findCondCodeUseOperandIdxForBranchOrSelect(Instr);
+  return CCIdx >= 0 ? static_cast<AArch64CC::CondCode>(
+                          Instr.getOperand(CCIdx).getImm())
+                    : AArch64CC::Invalid;
 }
 
 static UsedNZCV getUsedNZCV(AArch64CC::CondCode CC) {
@@ -1627,6 +1639,41 @@
   return UsedFlags;
 }
 
+/// \returns The NZCV flags read after \p CmpInstr until NZCV is redefined in
+/// its block; None if \p MI is in another block, NZCV is live-in to a
+/// successor, a reader's condition code is unsupported, or C/V is read.
+///
+/// Readers are appended to \p CCUseInstrs if non-null, even when None results.
+static Optional<UsedNZCV>
+examineCFlagsUse(MachineInstr &MI, MachineInstr &CmpInstr,
+                 const TargetRegisterInfo &TRI,
+                 SmallVectorImpl<MachineInstr *> *CCUseInstrs = nullptr) {
+  MachineBasicBlock *CmpParent = CmpInstr.getParent();
+  if (MI.getParent() != CmpParent)
+    return None;
+
+  if (areCFlagsAliveInSuccessors(CmpParent))
+    return None;
+
+  UsedNZCV NZCVUsedAfterCmp;
+  for (MachineInstr &Instr : instructionsWithoutDebug(
+           std::next(CmpInstr.getIterator()), CmpParent->instr_end())) {
+    if (Instr.readsRegister(AArch64::NZCV, &TRI)) {
+      AArch64CC::CondCode CC = findCondCodeUsedByInstr(Instr);
+      if (CC == AArch64CC::Invalid) // Unsupported conditional instruction
+        return None;
+      NZCVUsedAfterCmp |= getUsedNZCV(CC);
+      if (CCUseInstrs)
+        CCUseInstrs->push_back(&Instr);
+    }
+    if (Instr.modifiesRegister(AArch64::NZCV, &TRI))
+      break;
+  }
+  if (NZCVUsedAfterCmp.C || NZCVUsedAfterCmp.V)
+    return None;
+  return NZCVUsedAfterCmp;
+}
+
 static bool isADDSRegImm(unsigned Opcode) {
   return Opcode == AArch64::ADDSWri || Opcode == AArch64::ADDSXri;
 }
@@ -1646,44 +1693,21 @@
 /// or if MI opcode is not the S form there must be neither defs of flags
 /// nor uses of flags between MI and CmpInstr.
 /// - and C/V flags are not used after CmpInstr
-static bool canInstrSubstituteCmpInstr(MachineInstr *MI, MachineInstr *CmpInstr,
-                                       const TargetRegisterInfo *TRI) {
-  assert(MI);
-  assert(sForm(*MI) != AArch64::INSTRUCTION_LIST_END);
-  assert(CmpInstr);
+static bool canInstrSubstituteCmpInstr(MachineInstr &MI, MachineInstr &CmpInstr,
+                                       const TargetRegisterInfo &TRI) {
+  assert(sForm(MI) != AArch64::INSTRUCTION_LIST_END);
 
-  const unsigned CmpOpcode = CmpInstr->getOpcode();
+  const unsigned CmpOpcode = CmpInstr.getOpcode();
   if (!isADDSRegImm(CmpOpcode) && !isSUBSRegImm(CmpOpcode))
     return false;
 
-  if (MI->getParent() != CmpInstr->getParent())
-    return false;
-
-  if (areCFlagsAliveInSuccessors(CmpInstr->getParent()))
+  if (!examineCFlagsUse(MI, CmpInstr, TRI))
     return false;
 
   AccessKind AccessToCheck = AK_Write;
-  if (sForm(*MI) != MI->getOpcode())
+  if (sForm(MI) != MI.getOpcode())
     AccessToCheck = AK_All;
-  if (areCFlagsAccessedBetweenInstrs(MI, CmpInstr, TRI, AccessToCheck))
-    return false;
-
-  UsedNZCV NZCVUsedAfterCmp;
-  for (const MachineInstr &Instr :
-       instructionsWithoutDebug(std::next(CmpInstr->getIterator()),
-                                CmpInstr->getParent()->instr_end())) {
-    if (Instr.readsRegister(AArch64::NZCV, TRI)) {
-      AArch64CC::CondCode CC = findCondCodeUsedByInstr(Instr);
-      if (CC == AArch64CC::Invalid) // Unsupported conditional instruction
-        return false;
-      NZCVUsedAfterCmp |= getUsedNZCV(CC);
-    }
-
-    if (Instr.modifiesRegister(AArch64::NZCV, TRI))
-      break;
-  }
-
-  return !NZCVUsedAfterCmp.C && !NZCVUsedAfterCmp.V;
+  return !areCFlagsAccessedBetweenInstrs(&MI, &CmpInstr, &TRI, AccessToCheck);
 }
 
 /// Substitute an instruction comparing to zero with another instruction
@@ -1692,20 +1716,19 @@
 /// Return true on success.
 bool AArch64InstrInfo::substituteCmpToZero(
     MachineInstr &CmpInstr, unsigned SrcReg,
-    const MachineRegisterInfo *MRI) const {
-  assert(MRI);
+    const MachineRegisterInfo &MRI) const {
   // Get the unique definition of SrcReg.
-  MachineInstr *MI = MRI->getUniqueVRegDef(SrcReg);
+  MachineInstr *MI = MRI.getUniqueVRegDef(SrcReg);
   if (!MI)
     return false;
 
-  const TargetRegisterInfo *TRI = &getRegisterInfo();
+  const TargetRegisterInfo &TRI = getRegisterInfo();
 
   unsigned NewOpc = sForm(*MI);
   if (NewOpc == AArch64::INSTRUCTION_LIST_END)
     return false;
 
-  if (!canInstrSubstituteCmpInstr(MI, &CmpInstr, TRI))
+  if (!canInstrSubstituteCmpInstr(*MI, CmpInstr, TRI))
     return false;
 
   // Update the instruction to set NZCV.
@@ -1714,7 +1737,141 @@
   bool succeeded = UpdateOperandRegClass(*MI);
   (void)succeeded;
   assert(succeeded && "Some operands reg class are incompatible!");
-  MI->addRegisterDefined(AArch64::NZCV, TRI);
+  MI->addRegisterDefined(AArch64::NZCV, &TRI);
+  return true;
+}
+
+/// \returns True if \p CmpInstr can be removed.
+///
+/// \p IsInvertCC is true if, after removing \p CmpInstr, condition
+/// codes used in \p CCUseInstrs must be inverted.
+static bool canCmpInstrBeRemoved(MachineInstr &MI, MachineInstr &CmpInstr,
+                                 int CmpValue, const TargetRegisterInfo &TRI,
+                                 SmallVectorImpl<MachineInstr *> &CCUseInstrs,
+                                 bool &IsInvertCC) {
+  assert(CmpValue == 0 || CmpValue == 1);
+
+  // MI is 'CSINCWr %vreg, wzr, wzr, <cc>' or 'CSINCXr %vreg, xzr, xzr, <cc>'
+  unsigned MIOpc = MI.getOpcode();
+  if (MIOpc == AArch64::CSINCWr) {
+    if (MI.getOperand(1).getReg() != AArch64::WZR ||
+        MI.getOperand(2).getReg() != AArch64::WZR)
+      return false;
+  } else if (MIOpc == AArch64::CSINCXr) {
+    if (MI.getOperand(1).getReg() != AArch64::XZR ||
+        MI.getOperand(2).getReg() != AArch64::XZR)
+      return false;
+  } else {
+    return false;
+  }
+  AArch64CC::CondCode MICC = findCondCodeUsedByInstr(MI);
+  if (MICC == AArch64CC::Invalid)
+    return false;
+
+  // Bail out if MI has a dead def of NZCV: with isDead = true,
+  // findRegisterDefOperandIdx only matches dead defs.
+  if (MI.findRegisterDefOperandIdx(AArch64::NZCV, true) != -1)
+    return false;
+
+  // CmpInstr is 'ADDS %vreg, 0' or 'SUBS %vreg, 0' or 'SUBS %vreg, 1'
+  const unsigned CmpOpcode = CmpInstr.getOpcode();
+  bool IsSubsRegImm = isSUBSRegImm(CmpOpcode);
+  if (CmpValue && !IsSubsRegImm)
+    return false;
+  if (!CmpValue && !IsSubsRegImm && !isADDSRegImm(CmpOpcode))
+    return false;
+  // CmpValue only says whether the immediate is zero (analyzeCompare folds
+  // every non-zero immediate to 1), so check the real operands: anything but
+  // an unshifted #0 or #1, e.g. 'SUBS %vreg, 2', must keep its compare.
+  if (CmpInstr.getOperand(2).getImm() != CmpValue ||
+      CmpInstr.getOperand(3).getImm() != 0)
+    return false;
+
+  // MI conditions allowed: eq, ne, mi, pl. Whitelist them explicitly: AL and
+  // NV test no flags, so rejecting C/V users alone would not reject them.
+  if (MICC != AArch64CC::EQ && MICC != AArch64CC::NE &&
+      MICC != AArch64CC::MI && MICC != AArch64CC::PL)
+    return false;
+  UsedNZCV MIUsedNZCV = getUsedNZCV(MICC);
+
+  Optional<UsedNZCV> NZCVUsedAfterCmp =
+      examineCFlagsUse(MI, CmpInstr, TRI, &CCUseInstrs);
+  // Condition flags are not used in CmpInstr basic block successors and only
+  // Z or N flags allowed to be used after CmpInstr within its basic block
+  if (!NZCVUsedAfterCmp)
+    return false;
+  // Z or N flag used after CmpInstr must correspond to the flag used in MI
+  if ((MIUsedNZCV.Z && NZCVUsedAfterCmp->N) ||
+      (MIUsedNZCV.N && NZCVUsedAfterCmp->Z))
+    return false;
+  // If CmpInstr is comparison to zero MI conditions are limited to eq, ne
+  if (MIUsedNZCV.N && !CmpValue)
+    return false;
+
+  // There must be no defs of flags between MI and CmpInstr
+  if (areCFlagsAccessedBetweenInstrs(&MI, &CmpInstr, &TRI, AK_Write))
+    return false;
+
+  // Condition code is inverted in the following cases:
+  // 1. MI condition is ne; CmpInstr is 'ADDS %vreg, 0' or 'SUBS %vreg, 0'
+  // 2. MI condition is eq, pl; CmpInstr is 'SUBS %vreg, 1'
+  IsInvertCC = (CmpValue && (MICC == AArch64CC::EQ || MICC == AArch64CC::PL)) ||
+               (!CmpValue && MICC == AArch64CC::NE);
+  return true;
+}
+
+/// Remove comparison in csinc-cmp sequence
+///
+/// Examples:
+/// 1. \code
+///      csinc w9, wzr, wzr, ne
+///      cmp   w9, #0
+///      b.eq
+///    \endcode
+///    to
+///    \code
+///      csinc w9, wzr, wzr, ne
+///      b.ne
+///    \endcode
+///
+/// 2. \code
+///      csinc x2, xzr, xzr, mi
+///      cmp   x2, #1
+///      b.pl
+///    \endcode
+///    to
+///    \code
+///      csinc x2, xzr, xzr, mi
+///      b.pl
+///    \endcode
+///
+/// \param CmpInstr compare of \p SrcReg against \p CmpValue (0 or 1)
+/// \return True when \p CmpInstr was erased
+bool AArch64InstrInfo::removeCmpToZeroOrOne(
+    MachineInstr &CmpInstr, unsigned SrcReg, int CmpValue,
+    const MachineRegisterInfo &MRI) const {
+  MachineInstr *MI = MRI.getUniqueVRegDef(SrcReg);
+  if (!MI)
+    return false;
+  const TargetRegisterInfo &TRI = getRegisterInfo();
+  SmallVector<MachineInstr *, 4> CCUseInstrs;
+  bool IsInvertCC = false;
+  if (!canCmpInstrBeRemoved(*MI, CmpInstr, CmpValue, TRI, CCUseInstrs,
+                            IsInvertCC))
+    return false;
+  // Make transformation
+  CmpInstr.eraseFromParent();
+  if (IsInvertCC) {
+    // Invert condition codes in CmpInstr CC users
+    for (MachineInstr *CCUseInstr : CCUseInstrs) {
+      int Idx = findCondCodeUseOperandIdxForBranchOrSelect(*CCUseInstr);
+      assert(Idx >= 0 && "Unexpected instruction using CC.");
+      MachineOperand &CCOperand = CCUseInstr->getOperand(Idx);
+      AArch64CC::CondCode CCUse = AArch64CC::getInvertedCondCode(
+          static_cast<AArch64CC::CondCode>(CCOperand.getImm()));
+      CCOperand.setImm(CCUse);
+    }
+  }
   return true;
 }
 
diff --git a/llvm/test/CodeGen/AArch64/csinc-cmp-removal.mir b/llvm/test/CodeGen/AArch64/csinc-cmp-removal.mir
--- a/llvm/test/CodeGen/AArch64/csinc-cmp-removal.mir
+++ b/llvm/test/CodeGen/AArch64/csinc-cmp-removal.mir
@@ -12,7 +12,6 @@
   ; CHECK:   [[DEF:%[0-9]+]]:gpr64 = IMPLICIT_DEF
   ; CHECK:   [[SUBSXrr:%[0-9]+]]:gpr64 = SUBSXrr killed [[DEF]], [[COPY]], implicit-def $nzcv
   ; CHECK:   [[CSINCWr:%[0-9]+]]:gpr32common = CSINCWr $wzr, $wzr, 1, implicit $nzcv
-  ; CHECK:   [[SUBSWri:%[0-9]+]]:gpr32 = SUBSWri killed [[CSINCWr]], 1, 0, implicit-def $nzcv
   ; CHECK:   Bcc 1, %bb.2, implicit $nzcv
   ; CHECK:   B %bb.1
   ; CHECK: bb.1:
@@ -51,8 +50,7 @@
   ; CHECK:   [[DEF:%[0-9]+]]:gpr64 = IMPLICIT_DEF
   ; CHECK:   [[SUBSXrr:%[0-9]+]]:gpr64 = SUBSXrr killed [[DEF]], [[COPY]], implicit-def $nzcv
   ; CHECK:   [[CSINCXr:%[0-9]+]]:gpr64common = CSINCXr $xzr, $xzr, 1, implicit $nzcv
-  ; CHECK:   [[SUBSXri:%[0-9]+]]:gpr64 = SUBSXri killed [[CSINCXr]], 0, 0, implicit-def $nzcv
-  ; CHECK:   Bcc 0, %bb.2, implicit $nzcv
+  ; CHECK:   Bcc 1, %bb.2, implicit $nzcv
   ; CHECK:   B %bb.1
   ; CHECK: bb.1:
   ; CHECK:   successors: %bb.2(0x80000000)
@@ -155,8 +153,7 @@
   ; CHECK:   successors: %bb.1(0x40000000), %bb.2(0x40000000)
   ; CHECK:   liveins: $nzcv
   ; CHECK:   [[CSINCWr:%[0-9]+]]:gpr32common = CSINCWr $wzr, $wzr, 1, implicit $nzcv
-  ; CHECK:   [[ADDSWri:%[0-9]+]]:gpr32 = ADDSWri killed [[CSINCWr]], 0, 0, implicit-def $nzcv
-  ; CHECK:   Bcc 1, %bb.2, implicit $nzcv
+  ; CHECK:   Bcc 0, %bb.2, implicit $nzcv
   ; CHECK:   B %bb.1
   ; CHECK: bb.1:
   ; CHECK:   successors: %bb.2(0x80000000)
@@ -254,8 +251,7 @@
   ; CHECK:   successors: %bb.1(0x40000000), %bb.2(0x40000000)
   ; CHECK:   liveins: $nzcv
   ; CHECK:   [[CSINCWr:%[0-9]+]]:gpr32common = CSINCWr $wzr, $wzr, 5, implicit $nzcv
-  ; CHECK:   [[SUBSWri:%[0-9]+]]:gpr32 = SUBSWri killed [[CSINCWr]], 1, 0, implicit-def $nzcv
-  ; CHECK:   Bcc 4, %bb.2, implicit $nzcv
+  ; CHECK:   Bcc 5, %bb.2, implicit $nzcv
   ; CHECK:   B %bb.1
   ; CHECK: bb.1:
   ; CHECK:   successors: %bb.2(0x80000000)
diff --git a/llvm/test/CodeGen/AArch64/f16-instructions.ll b/llvm/test/CodeGen/AArch64/f16-instructions.ll
--- a/llvm/test/CodeGen/AArch64/f16-instructions.ll
+++ b/llvm/test/CodeGen/AArch64/f16-instructions.ll
@@ -189,8 +189,6 @@
 ; CHECK-CVT-DAG: fcvt s1, h1
 ; CHECK-CVT-DAG: fcvt s0, h0
 ; CHECK-CVT-DAG: fcmp s2, s3
-; CHECK-CVT-DAG: cset [[CC:w[0-9]+]], ne
-; CHECK-CVT-DAG: cmp [[CC]], #0
 ; CHECK-CVT-NEXT: fcsel s0, s0, s1, ne
 ; CHECK-CVT-NEXT: fcvt h0, s0
 ; CHECK-CVT-NEXT: ret
@@ -228,8 +226,6 @@
 ; CHECK-CVT-DAG: fcvt s0, h0
 ; CHECK-CVT-DAG: fcvt s1, h1
 ; CHECK-CVT-DAG: fcmp s2, s3
-; CHECK-CVT-DAG: cset w8, ne
-; CHECK-CVT-NEXT: cmp w8, #0
 ; CHECK-CVT-NEXT: fcsel s0, s0, s1, ne
 ; CHECK-CVT-NEXT: fcvt h0, s0
 ; CHECK-CVT-NEXT: ret