Index: lib/Target/RISCV/RISCVISelLowering.h
===================================================================
--- lib/Target/RISCV/RISCVISelLowering.h
+++ lib/Target/RISCV/RISCVISelLowering.h
@@ -25,7 +25,8 @@
 enum NodeType : unsigned {
   FIRST_NUMBER = ISD::BUILTIN_OP_END,
   RET_FLAG,
-  CALL
+  CALL,
+  SELECT_CC
 };
 }
 
@@ -42,6 +43,10 @@
   // This method returns the name of a target specific DAG node.
   const char *getTargetNodeName(unsigned Opcode) const override;
 
+  MachineBasicBlock *
+  EmitInstrWithCustomInserter(MachineInstr &MI,
+                              MachineBasicBlock *BB) const override;
+
 private:
   // Lower incoming arguments, copy physregs into vregs
   SDValue LowerFormalArguments(SDValue Chain, CallingConv::ID CallConv,
@@ -60,6 +65,7 @@
     return true;
   }
   SDValue lowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const;
+  SDValue lowerSELECT(SDValue Op, SelectionDAG &DAG) const;
 };
 }
 
Index: lib/Target/RISCV/RISCVISelLowering.cpp
===================================================================
--- lib/Target/RISCV/RISCVISelLowering.cpp
+++ lib/Target/RISCV/RISCVISelLowering.cpp
@@ -56,6 +56,9 @@
   setOperationAction(ISD::GlobalAddress, XLenVT, Custom);
 
   setOperationAction(ISD::BR_CC, XLenVT, Expand);
+  setOperationAction(ISD::SELECT, XLenVT, Custom);
+  setOperationAction(ISD::SELECT_CC, XLenVT, Expand);
+
   setBooleanContents(ZeroOrOneBooleanContent);
 
   // Function alignments (log2).
@@ -63,6 +66,45 @@
   setPrefFunctionAlignment(3);
 }
 
+// Changes the condition code and swaps operands if necessary, so the SetCC
+// operation matches one of the comparisons supported directly in the RISC-V
+// ISA.
+static void normaliseSetCC(SDValue &LHS, SDValue &RHS, ISD::CondCode &CC) {
+  switch (CC) {
+  default:
+    break;
+  case ISD::SETGT:
+  case ISD::SETLE:
+  case ISD::SETUGT:
+  case ISD::SETULE:
+    CC = ISD::getSetCCSwappedOperands(CC);
+    std::swap(LHS, RHS);
+    break;
+  }
+}
+
+// Return the RISC-V branch opcode that matches the given DAG integer
+// condition code. The CondCode must be one of those supported by the RISC-V
+// ISA (see normaliseSetCC).
+static unsigned getBranchOpcodeForIntCondCode(ISD::CondCode CC) {
+  switch (CC) {
+  default:
+    llvm_unreachable("Unsupported CondCode");
+  case ISD::SETEQ:
+    return RISCV::BEQ;
+  case ISD::SETNE:
+    return RISCV::BNE;
+  case ISD::SETLT:
+    return RISCV::BLT;
+  case ISD::SETGE:
+    return RISCV::BGE;
+  case ISD::SETULT:
+    return RISCV::BLTU;
+  case ISD::SETUGE:
+    return RISCV::BGEU;
+  }
+}
+
 SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
                                             SelectionDAG &DAG) const {
   switch (Op.getOpcode()) {
@@ -70,6 +112,8 @@
     report_fatal_error("unimplemented operand");
   case ISD::GlobalAddress:
     return lowerGlobalAddress(Op, DAG);
+  case ISD::SELECT:
+    return lowerSELECT(Op, DAG);
   }
 }
 
@@ -95,6 +139,112 @@
   }
 }
 
+SDValue RISCVTargetLowering::lowerSELECT(SDValue Op, SelectionDAG &DAG) const {
+  SDValue CondV = Op.getOperand(0);
+  SDValue TrueV = Op.getOperand(1);
+  SDValue FalseV = Op.getOperand(2);
+  SDLoc DL(Op);
+  MVT XLenVT = Subtarget.getXLenVT();
+
+  // If the result type is XLenVT and CondV is the output of a SETCC node
+  // which also operated on XLenVT inputs, then merge the SETCC node into the
+  // lowered RISCVISD::SELECT_CC to take advantage of the integer
+  // compare+branch instructions. i.e.:
+  // (select (setcc lhs, rhs, cc), truev, falsev)
+  // -> (riscvisd::select_cc lhs, rhs, cc, truev, falsev)
+  if (Op.getSimpleValueType() == XLenVT && CondV.getOpcode() == ISD::SETCC &&
+      CondV.getOperand(0).getSimpleValueType() == XLenVT) {
+    SDValue LHS = CondV.getOperand(0);
+    SDValue RHS = CondV.getOperand(1);
+    auto CC = cast<CondCodeSDNode>(CondV.getOperand(2));
+    ISD::CondCode CCVal = CC->get();
+
+    normaliseSetCC(LHS, RHS, CCVal);
+
+    SDValue TargetCC = DAG.getConstant(CCVal, DL, XLenVT);
+    SDVTList VTs = DAG.getVTList(Op.getValueType(), MVT::Glue);
+    SDValue Ops[] = {LHS, RHS, TargetCC, TrueV, FalseV};
+    return DAG.getNode(RISCVISD::SELECT_CC, DL, VTs, Ops);
+  }
+
+  // Otherwise:
+  // (select condv, truev, falsev)
+  // -> (riscvisd::select_cc condv, zero, setne, truev, falsev)
+  SDValue Zero = DAG.getConstant(0, DL, XLenVT);
+  SDValue SetNE = DAG.getConstant(ISD::SETNE, DL, XLenVT);
+
+  SDVTList VTs = DAG.getVTList(Op.getValueType(), MVT::Glue);
+  SDValue Ops[] = {CondV, Zero, SetNE, TrueV, FalseV};
+
+  return DAG.getNode(RISCVISD::SELECT_CC, DL, VTs, Ops);
+}
+
+MachineBasicBlock *
+RISCVTargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
+                                                 MachineBasicBlock *BB) const {
+  const TargetInstrInfo &TII = *BB->getParent()->getSubtarget().getInstrInfo();
+  DebugLoc DL = MI.getDebugLoc();
+
+  assert(MI.getOpcode() == RISCV::Select_GPR_Using_CC_GPR &&
+         "Unexpected instr type to insert");
+
+  // To "insert" a SELECT instruction, we actually have to insert the triangle
+  // control-flow pattern. The incoming instruction knows the destination vreg
+  // to set, the condition code register to branch on, the true/false values to
+  // select between, and the condcode to use to select the appropriate branch.
+  //
+  // We produce the following control flow:
+  //     HeadMBB
+  //     |  \
+  //     |  IfFalseMBB
+  //     | /
+  //    TailMBB
+  const BasicBlock *LLVM_BB = BB->getBasicBlock();
+  MachineFunction::iterator I = ++BB->getIterator();
+
+  MachineBasicBlock *HeadMBB = BB;
+  MachineFunction *F = BB->getParent();
+  MachineBasicBlock *TailMBB = F->CreateMachineBasicBlock(LLVM_BB);
+  MachineBasicBlock *IfFalseMBB = F->CreateMachineBasicBlock(LLVM_BB);
+
+  F->insert(I, IfFalseMBB);
+  F->insert(I, TailMBB);
+  // Move all remaining instructions to TailMBB.
+  TailMBB->splice(TailMBB->begin(), HeadMBB,
+                  std::next(MachineBasicBlock::iterator(MI)), HeadMBB->end());
+  // Update machine-CFG edges by transferring all successors of the current
+  // block to the new block which will contain the Phi node for the select.
+  TailMBB->transferSuccessorsAndUpdatePHIs(HeadMBB);
+  // Set the successors for HeadMBB.
+  HeadMBB->addSuccessor(IfFalseMBB);
+  HeadMBB->addSuccessor(TailMBB);
+
+  // Insert appropriate branch.
+  unsigned LHS = MI.getOperand(1).getReg();
+  unsigned RHS = MI.getOperand(2).getReg();
+  auto CC = static_cast<ISD::CondCode>(MI.getOperand(3).getImm());
+  unsigned Opcode = getBranchOpcodeForIntCondCode(CC);
+
+  BuildMI(HeadMBB, DL, TII.get(Opcode))
+      .addReg(LHS)
+      .addReg(RHS)
+      .addMBB(TailMBB);
+
+  // IfFalseMBB just falls through to TailMBB.
+  IfFalseMBB->addSuccessor(TailMBB);
+
+  // %Result = phi [ %TrueValue, HeadMBB ], [ %FalseValue, IfFalseMBB ]
+  BuildMI(*TailMBB, TailMBB->begin(), DL, TII.get(RISCV::PHI),
+          MI.getOperand(0).getReg())
+      .addReg(MI.getOperand(4).getReg())
+      .addMBB(HeadMBB)
+      .addReg(MI.getOperand(5).getReg())
+      .addMBB(IfFalseMBB);
+
+  MI.eraseFromParent(); // The pseudo instruction is gone now.
+  return TailMBB;
+}
+
 // Calling Convention Implementation.
 #include "RISCVGenCallingConv.inc"
 
@@ -326,6 +476,8 @@
     return "RISCVISD::RET_FLAG";
   case RISCVISD::CALL:
     return "RISCVISD::CALL";
+  case RISCVISD::SELECT_CC:
+    return "RISCVISD::SELECT_CC";
   }
   return nullptr;
 }
Index: lib/Target/RISCV/RISCVInstrInfo.td
===================================================================
--- lib/Target/RISCV/RISCVInstrInfo.td
+++ lib/Target/RISCV/RISCVInstrInfo.td
@@ -22,6 +22,9 @@
                                             SDTCisVT<1, i32>]>;
 def SDT_RISCVCallSeqEnd   : SDCallSeqEnd<[SDTCisVT<0, i32>,
                                           SDTCisVT<1, i32>]>;
+def SDT_RISCVSelectCC     : SDTypeProfile<1, 5, [SDTCisSameAs<1, 2>,
+                                                 SDTCisSameAs<0, 4>,
+                                                 SDTCisSameAs<4, 5>]>;
 
 
 def Call         : SDNode<"RISCVISD::CALL", SDT_RISCVCall,
@@ -33,6 +36,8 @@
                           [SDNPHasChain, SDNPOptInGlue, SDNPOutGlue]>;
 def RetFlag      : SDNode<"RISCVISD::RET_FLAG", SDTNone,
                           [SDNPHasChain, SDNPOptInGlue, SDNPVariadic]>;
+def SelectCC     : SDNode<"RISCVISD::SELECT_CC", SDT_RISCVSelectCC,
+                          [SDNPInGlue]>;
 
 //===----------------------------------------------------------------------===//
 // Operand and SDNode transformation definitions.
@@ -100,6 +105,9 @@
   let DecoderMethod = "decodeSImmOperandAndLsl1<21>";
 }
 
+// A parameterized register class alternative for i32imm/i64imm from Target.td.
+def ixlenimm : Operand<XLenVT>;
+
 // Standalone (codegen-only) immleaf patterns.
 def simm32 : ImmLeaf<XLenVT, [{return isInt<32>(Imm);}]>;
 
@@ -320,6 +328,13 @@
 def : PatGprGpr<setult, SLTU>;
 def : PatGprSimm12<setult, SLTIU>;
 
+let usesCustomInserter = 1 in
+def Select_GPR_Using_CC_GPR
+    : Pseudo<(outs GPR:$dst),
+             (ins GPR:$lhs, GPR:$rhs, ixlenimm:$imm, GPR:$src, GPR:$src2),
+             [(set XLenVT:$dst, (SelectCC GPR:$lhs, GPR:$rhs,
+              (XLenVT imm:$imm), GPR:$src, GPR:$src2))]>;
+
 /// Branches and jumps
 
 // Match `(brcond (CondOp ..), ..)` and lower to the appropriate RISC-V branch
Index: test/CodeGen/RISCV/bare-select.ll
===================================================================
--- /dev/null
+++ test/CodeGen/RISCV/bare-select.ll
@@ -0,0 +1,18 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv32 -verify-machineinstrs < %s \
+; RUN:   | FileCheck %s -check-prefix=RV32I
+
+define i32 @bare_select(i1 %a, i32 %b, i32 %c) {
+; RV32I-LABEL: bare_select:
+; RV32I:       # BB#0:
+; RV32I-NEXT:    andi a0, a0, 1
+; RV32I-NEXT:    addi a3, zero, 0
+; RV32I-NEXT:    bne a0, a3, .LBB0_2
+; RV32I-NEXT:  # BB#1:
+; RV32I-NEXT:    addi a1, a2, 0
+; RV32I-NEXT:  .LBB0_2:
+; RV32I-NEXT:    addi a0, a1, 0
+; RV32I-NEXT:    jalr zero, ra, 0
+  %1 = select i1 %a, i32 %b, i32 %c
+  ret i32 %1
+}
Index: test/CodeGen/RISCV/select-cc.ll
===================================================================
--- /dev/null
+++ test/CodeGen/RISCV/select-cc.ll
@@ -0,0 +1,100 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv32 -verify-machineinstrs < %s \
+; RUN:   | FileCheck -check-prefix=RV32I %s
+
+define i32 @foo(i32 %a, i32 *%b) {
+; RV32I-LABEL: foo:
+; RV32I:       # BB#0:
+; RV32I-NEXT:    lw a2, 0(a1)
+; RV32I-NEXT:    beq a0, a2, .LBB0_2
+; RV32I-NEXT:  # BB#1:
+; RV32I-NEXT:    addi a0, a2, 0
+; RV32I-NEXT:  .LBB0_2:
+; RV32I-NEXT:    lw a2, 0(a1)
+; RV32I-NEXT:    bne a0, a2, .LBB0_4
+; RV32I-NEXT:  # BB#3:
+; RV32I-NEXT:    addi a0, a2, 0
+; RV32I-NEXT:  .LBB0_4:
+; RV32I-NEXT:    lw a2, 0(a1)
+; RV32I-NEXT:    bltu a2, a0, .LBB0_6
+; RV32I-NEXT:  # BB#5:
+; RV32I-NEXT:    addi a0, a2, 0
+; RV32I-NEXT:  .LBB0_6:
+; RV32I-NEXT:    lw a2, 0(a1)
+; RV32I-NEXT:    bgeu a0, a2, .LBB0_8
+; RV32I-NEXT:  # BB#7:
+; RV32I-NEXT:    addi a0, a2, 0
+; RV32I-NEXT:  .LBB0_8:
+; RV32I-NEXT:    lw a2, 0(a1)
+; RV32I-NEXT:    bltu a0, a2, .LBB0_10
+; RV32I-NEXT:  # BB#9:
+; RV32I-NEXT:    addi a0, a2, 0
+; RV32I-NEXT:  .LBB0_10:
+; RV32I-NEXT:    lw a2, 0(a1)
+; RV32I-NEXT:    bgeu a2, a0, .LBB0_12
+; RV32I-NEXT:  # BB#11:
+; RV32I-NEXT:    addi a0, a2, 0
+; RV32I-NEXT:  .LBB0_12:
+; RV32I-NEXT:    lw a2, 0(a1)
+; RV32I-NEXT:    blt a2, a0, .LBB0_14
+; RV32I-NEXT:  # BB#13:
+; RV32I-NEXT:    addi a0, a2, 0
+; RV32I-NEXT:  .LBB0_14:
+; RV32I-NEXT:    lw a2, 0(a1)
+; RV32I-NEXT:    bge a0, a2, .LBB0_16
+; RV32I-NEXT:  # BB#15:
+; RV32I-NEXT:    addi a0, a2, 0
+; RV32I-NEXT:  .LBB0_16:
+; RV32I-NEXT:    lw a2, 0(a1)
+; RV32I-NEXT:    blt a0, a2, .LBB0_18
+; RV32I-NEXT:  # BB#17:
+; RV32I-NEXT:    addi a0, a2, 0
+; RV32I-NEXT:  .LBB0_18:
+; RV32I-NEXT:    lw a1, 0(a1)
+; RV32I-NEXT:    bge a1, a0, .LBB0_20
+; RV32I-NEXT:  # BB#19:
+; RV32I-NEXT:    addi a0, a1, 0
+; RV32I-NEXT:  .LBB0_20:
+; RV32I-NEXT:    jalr zero, ra, 0
+  %val1 = load volatile i32, i32* %b
+  %tst1 = icmp eq i32 %a, %val1
+  %val2 = select i1 %tst1, i32 %a, i32 %val1
+
+  %val3 = load volatile i32, i32* %b
+  %tst2 = icmp ne i32 %val2, %val3
+  %val4 = select i1 %tst2, i32 %val2, i32 %val3
+
+  %val5 = load volatile i32, i32* %b
+  %tst3 = icmp ugt i32 %val4, %val5
+  %val6 = select i1 %tst3, i32 %val4, i32 %val5
+
+  %val7 = load volatile i32, i32* %b
+  %tst4 = icmp uge i32 %val6, %val7
+  %val8 = select i1 %tst4, i32 %val6, i32 %val7
+
+  %val9 = load volatile i32, i32* %b
+  %tst5 = icmp ult i32 %val8, %val9
+  %val10 = select i1 %tst5, i32 %val8, i32 %val9
+
+  %val11 = load volatile i32, i32* %b
+  %tst6 = icmp ule i32 %val10, %val11
+  %val12 = select i1 %tst6, i32 %val10, i32 %val11
+
+  %val13 = load volatile i32, i32* %b
+  %tst7 = icmp sgt i32 %val12, %val13
+  %val14 = select i1 %tst7, i32 %val12, i32 %val13
+
+  %val15 = load volatile i32, i32* %b
+  %tst8 = icmp sge i32 %val14, %val15
+  %val16 = select i1 %tst8, i32 %val14, i32 %val15
+
+  %val17 = load volatile i32, i32* %b
+  %tst9 = icmp slt i32 %val16, %val17
+  %val18 = select i1 %tst9, i32 %val16, i32 %val17
+
+  %val19 = load volatile i32, i32* %b
+  %tst10 = icmp sle i32 %val18, %val19
+  %val20 = select i1 %tst10, i32 %val18, i32 %val19
+
+  ret i32 %val20
+}