diff --git a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
--- a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
@@ -252,9 +252,10 @@
 
   /// Emit a CSet for an integer compare.
   ///
-  /// \p DefReg is expected to be a 32-bit scalar register.
+  /// \p DefReg and \p SrcReg are expected to be 32-bit scalar registers.
   MachineInstr *emitCSetForICMP(Register DefReg, unsigned Pred,
-                                MachineIRBuilder &MIRBuilder) const;
+                                MachineIRBuilder &MIRBuilder,
+                                Register SrcReg = AArch64::WZR) const;
   /// Emit a CSet for a FP compare.
   ///
   /// \p Dst is expected to be a 32-bit scalar register.
@@ -2155,6 +2156,41 @@
     I.setDesc(TII.get(TargetOpcode::COPY));
     return true;
   }
+
+  case TargetOpcode::G_ADD: {
+    // Check if this is being fed by a G_ICMP on either side.
+    //
+    // (cmp pred, x, y) + z
+    //
+    // In the above case, when the cmp is true, we increment z by 1. So, we can
+    // fold the add into the cset for the cmp by using cinc.
+    //
+    // FIXME: This would probably be a lot nicer in PostLegalizerLowering.
+    Register X = I.getOperand(1).getReg();
+
+    // Only handle scalars. Scalar G_ICMP is only legal for s32, so bail out
+    // early if we see anything else.
+    LLT Ty = MRI.getType(X);
+    if (Ty.isVector() || Ty.getSizeInBits() != 32)
+      return false;
+
+    Register CmpReg = I.getOperand(2).getReg();
+    MachineInstr *Cmp = getOpcodeDef(TargetOpcode::G_ICMP, CmpReg, MRI);
+    if (!Cmp) {
+      std::swap(X, CmpReg);
+      Cmp = getOpcodeDef(TargetOpcode::G_ICMP, CmpReg, MRI);
+      if (!Cmp)
+        return false;
+    }
+    MachineIRBuilder MIRBuilder(I);
+    auto Pred =
+        static_cast<CmpInst::Predicate>(Cmp->getOperand(1).getPredicate());
+    emitIntegerCompare(Cmp->getOperand(2), Cmp->getOperand(3),
+                       Cmp->getOperand(1), MIRBuilder);
+    emitCSetForICMP(I.getOperand(0).getReg(), Pred, MIRBuilder, X);
+    I.eraseFromParent();
+    return true;
+  }
   default:
     return false;
   }
@@ -4367,14 +4403,13 @@
 
 MachineInstr *
 AArch64InstructionSelector::emitCSetForICMP(Register DefReg, unsigned Pred,
-                                            MachineIRBuilder &MIRBuilder) const {
+                                            MachineIRBuilder &MIRBuilder,
+                                            Register SrcReg) const {
   // CSINC increments the result when the predicate is false. Invert it.
   const AArch64CC::CondCode InvCC = changeICMPPredToAArch64CC(
       CmpInst::getInversePredicate((CmpInst::Predicate)Pred));
-  auto I =
-      MIRBuilder
-          .buildInstr(AArch64::CSINCWr, {DefReg}, {Register(AArch64::WZR), Register(AArch64::WZR)})
-          .addImm(InvCC);
+  auto I = MIRBuilder.buildInstr(AArch64::CSINCWr, {DefReg}, {SrcReg, SrcReg})
+               .addImm(InvCC);
   constrainSelectedInstRegOperands(*I, TII, TRI, RBI);
   return &*I;
 }
diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/select-cmp.mir b/llvm/test/CodeGen/AArch64/GlobalISel/select-cmp.mir
--- a/llvm/test/CodeGen/AArch64/GlobalISel/select-cmp.mir
+++ b/llvm/test/CodeGen/AArch64/GlobalISel/select-cmp.mir
@@ -270,3 +270,91 @@
     RET_ReallyLR implicit $w0
 
 ...
+---
+name:            cmp_add_rhs
+legalized:       true
+regBankSelected: true
+tracksRegLiveness: true
+machineFunctionInfo: {}
+body:             |
+  bb.0:
+    liveins: $w0, $w1, $w2
+
+    ; The CSINC should use the add's RHS.
+
+    ; CHECK-LABEL: name: cmp_add_rhs
+    ; CHECK: liveins: $w0, $w1, $w2
+    ; CHECK: %cmp_lhs:gpr32 = COPY $w0
+    ; CHECK: %cmp_rhs:gpr32 = COPY $w1
+    ; CHECK: %add_rhs:gpr32 = COPY $w2
+    ; CHECK: [[SUBSWrr:%[0-9]+]]:gpr32 = SUBSWrr %cmp_lhs, %cmp_rhs, implicit-def $nzcv
+    ; CHECK: %add:gpr32 = CSINCWr %add_rhs, %add_rhs, 1, implicit $nzcv
+    ; CHECK: $w0 = COPY %add
+    ; CHECK: RET_ReallyLR implicit $w0
+    %cmp_lhs:gpr(s32) = COPY $w0
+    %cmp_rhs:gpr(s32) = COPY $w1
+    %add_rhs:gpr(s32) = COPY $w2
+    %cmp:gpr(s32) = G_ICMP intpred(eq), %cmp_lhs(s32), %cmp_rhs
+    %add:gpr(s32) = G_ADD %cmp, %add_rhs
+    $w0 = COPY %add(s32)
+    RET_ReallyLR implicit $w0
+
+...
+---
+name:            cmp_add_lhs
+legalized:       true
+regBankSelected: true
+tracksRegLiveness: true
+machineFunctionInfo: {}
+body:             |
+  bb.0:
+    liveins: $w0, $w1, $w2
+
+    ; The CSINC should use the add's LHS.
+
+    ; CHECK-LABEL: name: cmp_add_lhs
+    ; CHECK: liveins: $w0, $w1, $w2
+    ; CHECK: %cmp_lhs:gpr32 = COPY $w0
+    ; CHECK: %cmp_rhs:gpr32 = COPY $w1
+    ; CHECK: %add_lhs:gpr32 = COPY $w2
+    ; CHECK: [[SUBSWrr:%[0-9]+]]:gpr32 = SUBSWrr %cmp_lhs, %cmp_rhs, implicit-def $nzcv
+    ; CHECK: %add:gpr32 = CSINCWr %add_lhs, %add_lhs, 1, implicit $nzcv
+    ; CHECK: $w0 = COPY %add
+    ; CHECK: RET_ReallyLR implicit $w0
+    %cmp_lhs:gpr(s32) = COPY $w0
+    %cmp_rhs:gpr(s32) = COPY $w1
+    %add_lhs:gpr(s32) = COPY $w2
+    %cmp:gpr(s32) = G_ICMP intpred(eq), %cmp_lhs(s32), %cmp_rhs
+    %add:gpr(s32) = G_ADD %add_lhs, %cmp
+    $w0 = COPY %add(s32)
+    RET_ReallyLR implicit $w0
+
+...
+---
+name:            cmp_add_lhs_vector
+legalized:       true
+regBankSelected: true
+tracksRegLiveness: true
+machineFunctionInfo: {}
+body:             |
+  bb.0:
+    liveins: $q0, $q1, $q2
+
+    ; We don't emit CSINC with vectors, so there should be no optimization here.
+
+    ; CHECK-LABEL: name: cmp_add_lhs_vector
+    ; CHECK: liveins: $q0, $q1, $q2
+    ; CHECK: %cmp_lhs:fpr128 = COPY $q0
+    ; CHECK: %cmp_rhs:fpr128 = COPY $q1
+    ; CHECK: %add_lhs:fpr128 = COPY $q2
+    ; CHECK: [[CMEQv4i32_:%[0-9]+]]:fpr128 = CMEQv4i32 %cmp_lhs, %cmp_rhs
+    ; CHECK: %add:fpr128 = ADDv4i32 %add_lhs, [[CMEQv4i32_]]
+    ; CHECK: $q0 = COPY %add
+    ; CHECK: RET_ReallyLR implicit $q0
+    %cmp_lhs:fpr(<4 x s32>) = COPY $q0
+    %cmp_rhs:fpr(<4 x s32>) = COPY $q1
+    %add_lhs:fpr(<4 x s32>) = COPY $q2
+    %cmp:fpr(<4 x s32>) = G_ICMP intpred(eq), %cmp_lhs(<4 x s32>), %cmp_rhs
+    %add:fpr(<4 x s32>) = G_ADD %add_lhs, %cmp
+    $q0 = COPY %add(<4 x s32>)
+    RET_ReallyLR implicit $q0
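For reference, a minimal C++ sketch (not part of the patch) of the source-level pattern this fold targets. The function name is made up for illustration, the register assignments assume AAPCS64 argument passing, and the exact selected sequence may differ:

  // The i1 compare result feeds an add. Before this patch the compare
  // selects to a cset followed by a separate add; with the fold above, the
  // G_ICMP + G_ADD pair is intended to select to a compare plus a single
  // conditional increment, roughly:
  //   cmp  w0, w1
  //   cinc w0, w2, eq
  int icmp_add_fold(int x, int y, int z) { return (x == y) + z; }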