diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h --- a/llvm/lib/Target/RISCV/RISCVISelLowering.h +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h @@ -36,6 +36,7 @@ /// The lhs and rhs are XLenVT integers. The true and false values can be /// integer or floating point. SELECT_CC, + BR_CC, BuildPairF64, SplitF64, TAIL, @@ -441,6 +442,7 @@ SDValue lowerJumpTable(SDValue Op, SelectionDAG &DAG) const; SDValue lowerGlobalTLSAddress(SDValue Op, SelectionDAG &DAG) const; SDValue lowerSELECT(SDValue Op, SelectionDAG &DAG) const; + SDValue lowerBRCOND(SDValue Op, SelectionDAG &DAG) const; SDValue lowerVASTART(SDValue Op, SelectionDAG &DAG) const; SDValue lowerFRAMEADDR(SDValue Op, SelectionDAG &DAG) const; SDValue lowerRETURNADDR(SDValue Op, SelectionDAG &DAG) const; diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -181,6 +181,7 @@ setOperationAction(ISD::BR_JT, MVT::Other, Expand); setOperationAction(ISD::BR_CC, XLenVT, Expand); + setOperationAction(ISD::BRCOND, MVT::Other, Custom); setOperationAction(ISD::SELECT_CC, XLenVT, Expand); setOperationAction(ISD::STACKSAVE, MVT::Other, Expand); @@ -673,7 +674,6 @@ // We can use any register for comparisons setHasMultipleConditionRegisters(); - setTargetDAGCombine(ISD::SETCC); if (Subtarget.hasStdExtZbp()) { setTargetDAGCombine(ISD::OR); } @@ -1204,6 +1204,8 @@ return lowerGlobalTLSAddress(Op, DAG); case ISD::SELECT: return lowerSELECT(Op, DAG); + case ISD::BRCOND: + return lowerBRCOND(Op, DAG); case ISD::VASTART: return lowerVASTART(Op, DAG); case ISD::FRAMEADDR: @@ -1879,11 +1881,10 @@ // (select (setcc lhs, rhs, cc), truev, falsev) // -> (riscvisd::select_cc lhs, rhs, cc, truev, falsev) if (Op.getSimpleValueType() == XLenVT && CondV.getOpcode() == ISD::SETCC && - CondV.getOperand(0).getSimpleValueType() == XLenVT) { + 
CondV.getOperand(0).getValueType() == XLenVT) { SDValue LHS = CondV.getOperand(0); SDValue RHS = CondV.getOperand(1); - auto CC = cast<CondCodeSDNode>(CondV.getOperand(2)); - ISD::CondCode CCVal = CC->get(); + ISD::CondCode CCVal = cast<CondCodeSDNode>(CondV.getOperand(2))->get(); translateSetCCForBranch(DL, LHS, RHS, CCVal, DAG); @@ -1903,6 +1904,29 @@ return DAG.getNode(RISCVISD::SELECT_CC, DL, Op.getValueType(), Ops); } +SDValue RISCVTargetLowering::lowerBRCOND(SDValue Op, SelectionDAG &DAG) const { + SDValue CondV = Op.getOperand(1); + SDLoc DL(Op); + MVT XLenVT = Subtarget.getXLenVT(); + + if (CondV.getOpcode() == ISD::SETCC && + CondV.getOperand(0).getValueType() == XLenVT) { + SDValue LHS = CondV.getOperand(0); + SDValue RHS = CondV.getOperand(1); + ISD::CondCode CCVal = cast<CondCodeSDNode>(CondV.getOperand(2))->get(); + + translateSetCCForBranch(DL, LHS, RHS, CCVal, DAG); + + SDValue TargetCC = DAG.getCondCode(CCVal); + return DAG.getNode(RISCVISD::BR_CC, DL, Op.getValueType(), Op.getOperand(0), + LHS, RHS, TargetCC, Op.getOperand(2)); + } + + return DAG.getNode(RISCVISD::BR_CC, DL, Op.getValueType(), Op.getOperand(0), + CondV, DAG.getConstant(0, DL, XLenVT), + DAG.getCondCode(ISD::SETNE), Op.getOperand(2)); +} + SDValue RISCVTargetLowering::lowerVASTART(SDValue Op, SelectionDAG &DAG) const { MachineFunction &MF = DAG.getMachineFunction(); RISCVMachineFunctionInfo *FuncInfo = MF.getInfo<RISCVMachineFunctionInfo>(); @@ -4206,21 +4230,53 @@ break; } - case ISD::SETCC: { - // (setcc X, 1, setne) -> (setcc X, 0, seteq) if we can prove X is 0/1. - // Comparing with 0 may allow us to fold into bnez/beqz.
- SDValue LHS = N->getOperand(0); - SDValue RHS = N->getOperand(1); - if (LHS.getValueType().isScalableVector()) + case RISCVISD::BR_CC: { + SDValue LHS = N->getOperand(1); + SDValue RHS = N->getOperand(2); + ISD::CondCode CCVal = cast<CondCodeSDNode>(N->getOperand(3))->get(); + if (!ISD::isIntEqualitySetCC(CCVal)) break; - auto CC = cast<CondCodeSDNode>(N->getOperand(2))->get(); + + // Fold (br_cc (setlt X, Y), 0, ne, dest) -> + // (br_cc X, Y, lt, dest) + // Sometimes the setcc is introduced after br_cc has been formed. + if (LHS.getOpcode() == ISD::SETCC && isNullConstant(RHS) && + LHS.getOperand(0).getValueType() == Subtarget.getXLenVT()) { + // If we're looking for eq 0 instead of ne 0, we need to invert the + // condition. + bool Invert = CCVal == ISD::SETEQ; + CCVal = cast<CondCodeSDNode>(LHS.getOperand(2))->get(); + if (Invert) + CCVal = ISD::getSetCCInverse(CCVal, LHS.getValueType()); + + SDLoc DL(N); + RHS = LHS.getOperand(1); + LHS = LHS.getOperand(0); + translateSetCCForBranch(DL, LHS, RHS, CCVal, DAG); + + return DAG.getNode(RISCVISD::BR_CC, DL, N->getValueType(0), + N->getOperand(0), LHS, RHS, DAG.getCondCode(CCVal), + N->getOperand(4)); + } + + // Fold (br_cc (xor X, Y), 0, eq/ne, dest) -> + // (br_cc X, Y, eq/ne, dest) + if (LHS.getOpcode() == ISD::XOR && isNullConstant(RHS)) + return DAG.getNode(RISCVISD::BR_CC, SDLoc(N), N->getValueType(0), + N->getOperand(0), LHS.getOperand(0), LHS.getOperand(1), + N->getOperand(3), N->getOperand(4)); + // (br_cc X, 1, setne/seteq, dest) -> + // (br_cc X, 0, seteq/setne, dest) if we can prove X is 0/1. + // This can occur when legalizing some floating point comparisons.
APInt Mask = APInt::getBitsSetFrom(LHS.getValueSizeInBits(), 1); - if (isOneConstant(RHS) && ISD::isIntEqualitySetCC(CC) && - DAG.MaskedValueIsZero(LHS, Mask)) { + if (isOneConstant(RHS) && DAG.MaskedValueIsZero(LHS, Mask)) { SDLoc DL(N); - SDValue Zero = DAG.getConstant(0, DL, LHS.getValueType()); - CC = ISD::getSetCCInverse(CC, LHS.getValueType()); - return DAG.getSetCC(DL, N->getValueType(0), LHS, Zero, CC); + CCVal = ISD::getSetCCInverse(CCVal, LHS.getValueType()); + SDValue TargetCC = DAG.getCondCode(CCVal); + RHS = DAG.getConstant(0, DL, LHS.getValueType()); + return DAG.getNode(RISCVISD::BR_CC, DL, N->getValueType(0), + N->getOperand(0), LHS, RHS, TargetCC, + N->getOperand(4)); } break; } @@ -4381,6 +4437,17 @@ Known.resetAll(); switch (Opc) { default: break; + case RISCVISD::SELECT_CC: { + Known = DAG.computeKnownBits(Op.getOperand(4), Depth + 1); + // If nothing is known about the false value, early out. + if (Known.isUnknown()) + break; + KnownBits Known2 = DAG.computeKnownBits(Op.getOperand(3), Depth + 1); + + // Only known if known in both the true and false values.
+ Known = KnownBits::commonBits(Known, Known2); + break; + } case RISCVISD::REMUW: { KnownBits Known2; Known = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1); @@ -6103,6 +6170,7 @@ NODE_NAME_CASE(MRET_FLAG) NODE_NAME_CASE(CALL) NODE_NAME_CASE(SELECT_CC) + NODE_NAME_CASE(BR_CC) NODE_NAME_CASE(BuildPairF64) NODE_NAME_CASE(SplitF64) NODE_NAME_CASE(TAIL) diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.td b/llvm/lib/Target/RISCV/RISCVInstrInfo.td --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.td @@ -25,6 +25,9 @@ def SDT_RISCVSelectCC : SDTypeProfile<1, 5, [SDTCisSameAs<1, 2>, SDTCisSameAs<0, 4>, SDTCisSameAs<4, 5>]>; +def SDT_RISCVBrCC : SDTypeProfile<0, 4, [SDTCisSameAs<0, 1>, + SDTCisVT<2, OtherVT>, + SDTCisVT<3, OtherVT>]>; def SDT_RISCVReadCycleWide : SDTypeProfile<2, 0, [SDTCisVT<0, i32>, SDTCisVT<1, i32>]>; def SDT_RISCVIntBinOpW : SDTypeProfile<1, 2, [ @@ -50,6 +53,8 @@ def riscv_mret_flag : SDNode<"RISCVISD::MRET_FLAG", SDTNone, [SDNPHasChain, SDNPOptInGlue]>; def riscv_selectcc : SDNode<"RISCVISD::SELECT_CC", SDT_RISCVSelectCC>; +def riscv_brcc : SDNode<"RISCVISD::BR_CC", SDT_RISCVBrCC, + [SDNPHasChain]>; def riscv_tail : SDNode<"RISCVISD::TAIL", SDT_RISCVCall, [SDNPHasChain, SDNPOptInGlue, SDNPOutGlue, SDNPVariadic]>; @@ -961,41 +966,17 @@ /// Branches and jumps -// Match `(brcond (CondOp ..), ..)` and lower to the appropriate RISC-V branch -// instruction. -class BccPat - : Pat<(brcond (XLenVT (CondOp GPR:$rs1, GPR:$rs2)), bb:$imm12), +// Match `riscv_brcc` and lower to the appropriate RISC-V branch instruction. 
+class BccPat + : Pat<(riscv_brcc GPR:$rs1, GPR:$rs2, Cond, bb:$imm12), (Inst GPR:$rs1, GPR:$rs2, simm13_lsb0:$imm12)>; -def : BccPat; -def : BccPat; -def : BccPat; -def : BccPat; -def : BccPat; -def : BccPat; - -class BccSwapPat - : Pat<(brcond (XLenVT (CondOp GPR:$rs1, GPR:$rs2)), bb:$imm12), - (InstBcc GPR:$rs2, GPR:$rs1, bb:$imm12)>; - -// Condition codes that don't have matching RISC-V branch instructions, but -// are trivially supported by swapping the two input operands -def : BccSwapPat; -def : BccSwapPat; -def : BccSwapPat; -def : BccSwapPat; - -// Extra patterns are needed for a brcond without a setcc (i.e. where the -// condition was calculated elsewhere). -def : Pat<(brcond GPR:$cond, bb:$imm12), (BNE GPR:$cond, X0, bb:$imm12)>; -// In this pattern, the `(xor $cond, 1)` functions like (boolean) `not`, as the -// `brcond` only uses the lowest bit. -def : Pat<(brcond (XLenVT (xor GPR:$cond, 1)), bb:$imm12), - (BEQ GPR:$cond, X0, bb:$imm12)>; - -// Match X > -1, the canonical form of X >= 0, to the bgez pattern. 
-def : Pat<(brcond (XLenVT (setgt GPR:$rs1, -1)), bb:$imm12), - (BGE GPR:$rs1, X0, bb:$imm12)>; +def : BccPat; +def : BccPat; +def : BccPat; +def : BccPat; +def : BccPat; +def : BccPat; let isBarrier = 1, isBranch = 1, isTerminator = 1 in def PseudoBR : Pseudo<(outs), (ins simm21_lsb0_jal:$imm20), [(br bb:$imm20)]>, diff --git a/llvm/test/CodeGen/RISCV/hoist-global-addr-base.ll b/llvm/test/CodeGen/RISCV/hoist-global-addr-base.ll --- a/llvm/test/CodeGen/RISCV/hoist-global-addr-base.ll +++ b/llvm/test/CodeGen/RISCV/hoist-global-addr-base.ll @@ -29,8 +29,7 @@ ; CHECK-NEXT: lui a0, %hi(s) ; CHECK-NEXT: addi a0, a0, %lo(s) ; CHECK-NEXT: lw a1, 164(a0) -; CHECK-NEXT: addi a2, zero, 1 -; CHECK-NEXT: blt a1, a2, .LBB1_2 +; CHECK-NEXT: blez a1, .LBB1_2 ; CHECK-NEXT: # %bb.1: # %if.then ; CHECK-NEXT: addi a1, zero, 10 ; CHECK-NEXT: sw a1, 160(a0) diff --git a/llvm/test/CodeGen/RISCV/xaluo.ll b/llvm/test/CodeGen/RISCV/xaluo.ll --- a/llvm/test/CodeGen/RISCV/xaluo.ll +++ b/llvm/test/CodeGen/RISCV/xaluo.ll @@ -1451,8 +1451,7 @@ ; RV32-NEXT: add a2, a0, a1 ; RV32-NEXT: slt a0, a2, a0 ; RV32-NEXT: slti a1, a1, 0 -; RV32-NEXT: xor a0, a1, a0 -; RV32-NEXT: beqz a0, .LBB46_2 +; RV32-NEXT: beq a1, a0, .LBB46_2 ; RV32-NEXT: # %bb.1: # %overflow ; RV32-NEXT: mv a0, zero ; RV32-NEXT: ret @@ -1510,8 +1509,7 @@ ; RV64-NEXT: add a2, a0, a1 ; RV64-NEXT: slt a0, a2, a0 ; RV64-NEXT: slti a1, a1, 0 -; RV64-NEXT: xor a0, a1, a0 -; RV64-NEXT: beqz a0, .LBB47_2 +; RV64-NEXT: beq a1, a0, .LBB47_2 ; RV64-NEXT: # %bb.1: # %overflow ; RV64-NEXT: mv a0, zero ; RV64-NEXT: ret @@ -1620,8 +1618,7 @@ ; RV32-NEXT: sgtz a2, a1 ; RV32-NEXT: sub a1, a0, a1 ; RV32-NEXT: slt a0, a1, a0 -; RV32-NEXT: xor a0, a2, a0 -; RV32-NEXT: beqz a0, .LBB50_2 +; RV32-NEXT: beq a2, a0, .LBB50_2 ; RV32-NEXT: # %bb.1: # %overflow ; RV32-NEXT: mv a0, zero ; RV32-NEXT: ret @@ -1677,8 +1674,7 @@ ; RV64-NEXT: sgtz a2, a1 ; RV64-NEXT: sub a1, a0, a1 ; RV64-NEXT: slt a0, a1, a0 -; RV64-NEXT: xor a0, a2, a0 -; RV64-NEXT: beqz a0, 
.LBB51_2 +; RV64-NEXT: beq a2, a0, .LBB51_2 ; RV64-NEXT: # %bb.1: # %overflow ; RV64-NEXT: mv a0, zero ; RV64-NEXT: ret