Index: llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
===================================================================
--- llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
+++ llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
@@ -102,11 +102,19 @@
   bool selectVaStartDarwin(MachineInstr &I, MachineFunction &MF,
                            MachineRegisterInfo &MRI) const;

-  bool tryOptAndIntoCompareBranch(MachineInstr *LHS,
-                                  int64_t CmpConstant,
-                                  const CmpInst::Predicate &Pred,
+  ///@{
+  /// Helper functions for selectCompareBranch.
+  bool selectCompareBranchFedByFCmp(MachineInstr &I, MachineInstr &FCmp,
+                                    MachineIRBuilder &MIB) const;
+  bool selectCompareBranchFedByICmp(MachineInstr &I, MachineInstr &ICmp,
+                                    MachineIRBuilder &MIB) const;
+  bool tryOptCompareBranchFedByICmp(MachineInstr &I, MachineInstr &ICmp,
+                                    MachineIRBuilder &MIB) const;
+  bool tryOptAndIntoCompareBranch(MachineInstr &AndInst, bool Invert,
                                   MachineBasicBlock *DstMBB,
                                   MachineIRBuilder &MIB) const;
+  ///@}
+
   bool selectCompareBranch(MachineInstr &I, MachineFunction &MF,
                            MachineRegisterInfo &MRI) const;

@@ -1368,8 +1376,9 @@
 }

 bool AArch64InstructionSelector::tryOptAndIntoCompareBranch(
-    MachineInstr *AndInst, int64_t CmpConstant, const CmpInst::Predicate &Pred,
-    MachineBasicBlock *DstMBB, MachineIRBuilder &MIB) const {
+    MachineInstr &AndInst, bool Invert, MachineBasicBlock *DstMBB,
+    MachineIRBuilder &MIB) const {
+  assert(AndInst.getOpcode() == TargetOpcode::G_AND && "Expected G_AND only?");
   // Given something like this:
   //
   // %x = ...Something...
   //
@@ -1387,31 +1396,17 @@
   //
   // TBNZ %x %bb.3
   //
-  if (!AndInst || AndInst->getOpcode() != TargetOpcode::G_AND)
-    return false;
-
-  // Need to be comparing against 0 to fold.
-  if (CmpConstant != 0)
-    return false;
-
-  MachineRegisterInfo &MRI = *MIB.getMRI();
-
-  // Only support EQ and NE. If we have LT, then it *is* possible to fold, but
-  // we don't want to do this. When we have an AND and LT, we need a TST/ANDS,
-  // so folding would be redundant.
-  assert(ICmpInst::isEquality(Pred) && "Expected only eq/ne?");

   // Check if the AND has a constant on its RHS which we can use as a mask.
   // If it's a power of 2, then it's the same as checking a specific bit.
   // (e.g, ANDing with 8 == ANDing with 000...100 == testing if bit 3 is set)
-  auto MaybeBit =
-      getConstantVRegValWithLookThrough(AndInst->getOperand(2).getReg(), MRI);
+  auto MaybeBit = getConstantVRegValWithLookThrough(
+      AndInst.getOperand(2).getReg(), *MIB.getMRI());
   if (!MaybeBit || !isPowerOf2_64(MaybeBit->Value))
     return false;

   uint64_t Bit = Log2_64(static_cast<uint64_t>(MaybeBit->Value));
-  Register TestReg = AndInst->getOperand(1).getReg();
-  bool Invert = Pred == CmpInst::Predicate::ICMP_NE;
+  Register TestReg = AndInst.getOperand(1).getReg();

   // Emit a TB(N)Z.
   emitTestBit(TestReg, Bit, Invert, DstMBB, MIB);
@@ -1439,47 +1434,54 @@
   return &*BranchMI;
 }

-bool AArch64InstructionSelector::selectCompareBranch(
-    MachineInstr &I, MachineFunction &MF, MachineRegisterInfo &MRI) const {
-
-  const Register CondReg = I.getOperand(0).getReg();
+bool AArch64InstructionSelector::selectCompareBranchFedByFCmp(
+    MachineInstr &I, MachineInstr &FCmp, MachineIRBuilder &MIB) const {
+  assert(FCmp.getOpcode() == TargetOpcode::G_FCMP);
+  assert(I.getOpcode() == TargetOpcode::G_BRCOND);
+  // Unfortunately, the mapping of LLVM FP CC's onto AArch64 CC's isn't
+  // totally clean. Some of them require two branches to implement.
+  emitFPCompare(FCmp.getOperand(2).getReg(), FCmp.getOperand(3).getReg(), MIB);
+  AArch64CC::CondCode CC1, CC2;
+  changeFCMPPredToAArch64CC(
+      static_cast<CmpInst::Predicate>(FCmp.getOperand(1).getPredicate()), CC1,
+      CC2);
   MachineBasicBlock *DestMBB = I.getOperand(1).getMBB();
-  MachineInstr *CCMI = MRI.getVRegDef(CondReg);
-  if (CCMI->getOpcode() == TargetOpcode::G_TRUNC)
-    CCMI = MRI.getVRegDef(CCMI->getOperand(1).getReg());
+  MIB.buildInstr(AArch64::Bcc, {}, {}).addImm(CC1).addMBB(DestMBB);
+  if (CC2 != AArch64CC::AL)
+    MIB.buildInstr(AArch64::Bcc, {}, {}).addImm(CC2).addMBB(DestMBB);
+  I.eraseFromParent();
+  return true;
+}

-  unsigned CCMIOpc = CCMI->getOpcode();
-  if (CCMIOpc != TargetOpcode::G_ICMP && CCMIOpc != TargetOpcode::G_FCMP)
+bool AArch64InstructionSelector::tryOptCompareBranchFedByICmp(
+    MachineInstr &I, MachineInstr &ICmp, MachineIRBuilder &MIB) const {
+  assert(ICmp.getOpcode() == TargetOpcode::G_ICMP);
+  assert(I.getOpcode() == TargetOpcode::G_BRCOND);
+  // Attempt to optimize the G_BRCOND + G_ICMP into a TB(N)Z/CB(N)Z.
+  //
+  // Speculation tracking/SLH assumes that optimized TB(N)Z/CB(N)Z
+  // instructions will not be produced, as they are conditional branch
+  // instructions that do not set flags.
+  if (!ProduceNonFlagSettingCondBr)
     return false;

-  MachineIRBuilder MIB(I);
-  Register LHS = CCMI->getOperand(2).getReg();
-  Register RHS = CCMI->getOperand(3).getReg();
+  MachineRegisterInfo &MRI = *MIB.getMRI();
+  MachineBasicBlock *DestMBB = I.getOperand(1).getMBB();
   auto Pred =
-      static_cast<CmpInst::Predicate>(CCMI->getOperand(1).getPredicate());
-
-  if (CCMIOpc == TargetOpcode::G_FCMP) {
-    // Unfortunately, the mapping of LLVM FP CC's onto AArch64 CC's isn't
-    // totally clean. Some of them require two branches to implement.
-    emitFPCompare(LHS, RHS, MIB);
-    AArch64CC::CondCode CC1, CC2;
-    changeFCMPPredToAArch64CC(Pred, CC1, CC2);
-    MIB.buildInstr(AArch64::Bcc, {}, {}).addImm(CC1).addMBB(DestMBB);
-    if (CC2 != AArch64CC::AL)
-      MIB.buildInstr(AArch64::Bcc, {}, {}).addImm(CC2).addMBB(DestMBB);
-    I.eraseFromParent();
-    return true;
-  }
+      static_cast<CmpInst::Predicate>(ICmp.getOperand(1).getPredicate());
+  Register LHS = ICmp.getOperand(2).getReg();
+  Register RHS = ICmp.getOperand(3).getReg();

+  // We're allowed to emit a TB(N)Z/CB(N)Z. Try to do that.
   auto VRegAndVal = getConstantVRegValWithLookThrough(RHS, MRI);
-  MachineInstr *LHSMI = getDefIgnoringCopies(LHS, MRI);
+  MachineInstr *AndInst = getOpcodeDef(TargetOpcode::G_AND, LHS, MRI);

   // When we can emit a TB(N)Z, prefer that.
   //
   // Handle non-commutative condition codes first.
   // Note that we don't want to do this when we have a G_AND because it can
   // become a tst. The tst will make the test bit in the TB(N)Z redundant.
-  if (VRegAndVal && LHSMI->getOpcode() != TargetOpcode::G_AND) {
+  if (VRegAndVal && !AndInst) {
     int64_t C = VRegAndVal->Value;

     // When we have a greater-than comparison, we can just test if the msb is
@@ -1507,14 +1509,19 @@
   if (!VRegAndVal) {
     std::swap(RHS, LHS);
     VRegAndVal = getConstantVRegValWithLookThrough(RHS, MRI);
-    LHSMI = getDefIgnoringCopies(LHS, MRI);
+    AndInst = getOpcodeDef(TargetOpcode::G_AND, LHS, MRI);
   }

   if (VRegAndVal && VRegAndVal->Value == 0) {
     // If there's a G_AND feeding into this branch, try to fold it away by
     // emitting a TB(N)Z instead.
-    if (tryOptAndIntoCompareBranch(LHSMI, VRegAndVal->Value, Pred, DestMBB,
-                                   MIB)) {
+    //
+    // Note: If we have LT, then it *is* possible to fold, but it wouldn't be
+    // beneficial. When we have an AND and LT, we need a TST/ANDS, so folding
+    // would be redundant.
+    if (AndInst &&
+        tryOptAndIntoCompareBranch(
+            *AndInst, /*Invert = */ Pred == CmpInst::ICMP_NE, DestMBB, MIB)) {
       I.eraseFromParent();
       return true;
     }
@@ -1529,15 +1536,69 @@
     }
   }

-  // Couldn't optimize. Emit a compare + bcc.
-  emitIntegerCompare(CCMI->getOperand(2), CCMI->getOperand(3),
-                     CCMI->getOperand(1), MIB);
-  const AArch64CC::CondCode CC = changeICMPPredToAArch64CC(Pred);
+  return false;
+}
+
+bool AArch64InstructionSelector::selectCompareBranchFedByICmp(
+    MachineInstr &I, MachineInstr &ICmp, MachineIRBuilder &MIB) const {
+  assert(ICmp.getOpcode() == TargetOpcode::G_ICMP);
+  assert(I.getOpcode() == TargetOpcode::G_BRCOND);
+  if (tryOptCompareBranchFedByICmp(I, ICmp, MIB))
+    return true;
+
+  // Couldn't optimize. Emit a compare + a Bcc.
+  MachineBasicBlock *DestMBB = I.getOperand(1).getMBB();
+  auto PredOp = ICmp.getOperand(1);
+  emitIntegerCompare(ICmp.getOperand(2), ICmp.getOperand(3), PredOp, MIB);
+  const AArch64CC::CondCode CC = changeICMPPredToAArch64CC(
+      static_cast<CmpInst::Predicate>(PredOp.getPredicate()));
   MIB.buildInstr(AArch64::Bcc, {}, {}).addImm(CC).addMBB(DestMBB);
   I.eraseFromParent();
   return true;
 }

+bool AArch64InstructionSelector::selectCompareBranch(
+    MachineInstr &I, MachineFunction &MF, MachineRegisterInfo &MRI) const {
+  Register CondReg = I.getOperand(0).getReg();
+  MachineInstr *CCMI = MRI.getVRegDef(CondReg);
+  if (CCMI->getOpcode() == TargetOpcode::G_TRUNC) {
+    CondReg = CCMI->getOperand(1).getReg();
+    CCMI = MRI.getVRegDef(CondReg);
+  }
+
+  // Try to select the G_BRCOND using whatever is feeding the condition if
+  // possible.
+  MachineIRBuilder MIB(I);
+  switch (CCMI->getOpcode()) {
+  default:
+    break;
+  case TargetOpcode::G_FCMP:
+    return selectCompareBranchFedByFCmp(I, *CCMI, MIB);
+  case TargetOpcode::G_ICMP:
+    return selectCompareBranchFedByICmp(I, *CCMI, MIB);
+  }
+
+  // Speculation tracking/SLH assumes that optimized TB(N)Z/CB(N)Z
+  // instructions will not be produced, as they are conditional branch
+  // instructions that do not set flags.
+  if (ProduceNonFlagSettingCondBr) {
+    emitTestBit(CondReg, /*Bit = */ 0, /*IsNegative = */ true,
+                I.getOperand(1).getMBB(), MIB);
+    I.eraseFromParent();
+    return true;
+  }
+
+  // Can't emit TB(N)Z/CB(N)Z. Emit a tst + bcc instead.
+  auto TstMI =
+      MIB.buildInstr(AArch64::ANDSWri, {LLT::scalar(32)}, {CondReg}).addImm(1);
+  constrainSelectedInstRegOperands(*TstMI, TII, TRI, RBI);
+  auto Bcc = MIB.buildInstr(AArch64::Bcc)
+                 .addImm(AArch64CC::EQ)
+                 .addMBB(I.getOperand(1).getMBB());
+  I.eraseFromParent();
+  return constrainSelectedInstRegOperands(*Bcc, TII, TRI, RBI);
+}
+
 /// Returns the element immediate value of a vector shift operand if found.
 /// This needs to detect a splat-like operation, e.g. a G_BUILD_VECTOR.
 static Optional<int64_t> getVectorShiftImm(Register Reg,
@@ -2107,31 +2168,8 @@
   MachineIRBuilder MIB(I);

   switch (Opcode) {
-  case TargetOpcode::G_BRCOND: {
-    Register CondReg = I.getOperand(0).getReg();
-    MachineBasicBlock *DestMBB = I.getOperand(1).getMBB();
-
-    // Speculation tracking/SLH assumes that optimized TB(N)Z/CB(N)Z
-    // instructions will not be produced, as they are conditional branch
-    // instructions that do not set flags.
-    if (ProduceNonFlagSettingCondBr && selectCompareBranch(I, MF, MRI))
-      return true;
-
-    if (ProduceNonFlagSettingCondBr) {
-      auto TestBit = emitTestBit(CondReg, /*Bit = */ 0, /*IsNegative = */ true,
-                                 DestMBB, MIB);
-      I.eraseFromParent();
-      return constrainSelectedInstRegOperands(*TestBit, TII, TRI, RBI);
-    } else {
-      auto CMP = MIB.buildInstr(AArch64::ANDSWri, {LLT::scalar(32)}, {CondReg})
-                     .addImm(1);
-      constrainSelectedInstRegOperands(*CMP, TII, TRI, RBI);
-      auto Bcc =
-          MIB.buildInstr(AArch64::Bcc).addImm(AArch64CC::EQ).addMBB(DestMBB);
-      I.eraseFromParent();
-      return constrainSelectedInstRegOperands(*Bcc.getInstr(), TII, TRI, RBI);
-    }
-  }
+  case TargetOpcode::G_BRCOND:
+    return selectCompareBranch(I, MF, MRI);
   case TargetOpcode::G_BRINDIRECT: {
     I.setDesc(TII.get(AArch64::BR));
Index: llvm/test/CodeGen/AArch64/GlobalISel/speculative-hardening-brcond.mir
===================================================================
--- llvm/test/CodeGen/AArch64/GlobalISel/speculative-hardening-brcond.mir
+++ llvm/test/CodeGen/AArch64/GlobalISel/speculative-hardening-brcond.mir
@@ -8,6 +8,7 @@
 --- |
   define void @no_tbnz() speculative_load_hardening { ret void }
   define void @no_cbz() speculative_load_hardening { ret void }
+  define void @fp() speculative_load_hardening { ret void }
 ...
 ---
@@ -44,8 +45,6 @@
   ; CHECK: successors: %bb.0(0x40000000), %bb.1(0x40000000)
   ; CHECK: %reg:gpr32sp = COPY $w0
   ; CHECK: [[SUBSWri:%[0-9]+]]:gpr32 = SUBSWri %reg, 0, 0, implicit-def $nzcv
-  ; CHECK: %cmp:gpr32 = CSINCWr $wzr, $wzr, 1, implicit $nzcv
-  ; CHECK: [[ANDSWri:%[0-9]+]]:gpr32 = ANDSWri %cmp, 1, implicit-def $nzcv
   ; CHECK: Bcc 0, %bb.1, implicit $nzcv
   ; CHECK: B %bb.0
   ; CHECK: bb.1:
@@ -62,3 +61,29 @@
   bb.1:
     RET_ReallyLR
 ...
+---
+name: fp
+legalized: true
+regBankSelected: true
+body: |
+  ; CHECK-LABEL: name: fp
+  ; CHECK: bb.0:
+  ; CHECK: successors: %bb.0(0x40000000), %bb.1(0x40000000)
+  ; CHECK: %reg0:fpr32 = COPY $s0
+  ; CHECK: %reg1:fpr32 = COPY $s1
+  ; CHECK: FCMPSrr %reg0, %reg1, implicit-def $nzcv
+  ; CHECK: Bcc 0, %bb.1, implicit $nzcv
+  ; CHECK: B %bb.0
+  ; CHECK: bb.1:
+  ; CHECK: RET_ReallyLR
+  bb.0:
+    liveins: $s0, $s1
+    successors: %bb.0, %bb.1
+    %reg0:fpr(s32) = COPY $s0
+    %reg1:fpr(s32) = COPY $s1
+    %cmp:gpr(s32) = G_FCMP floatpred(oeq), %reg0, %reg1
+    %cond:gpr(s1) = G_TRUNC %cmp(s32)
+    G_BRCOND %cond(s1), %bb.1
+    G_BR %bb.0
+  bb.1:
+    RET_ReallyLR