Index: lib/Target/RISCV/RISCVISelLowering.h =================================================================== --- lib/Target/RISCV/RISCVISelLowering.h +++ lib/Target/RISCV/RISCVISelLowering.h @@ -25,7 +25,8 @@ enum NodeType : unsigned { FIRST_NUMBER = ISD::BUILTIN_OP_END, RET_FLAG, - CALL + CALL, + SELECT_CC }; } @@ -42,6 +43,10 @@ // This method returns the name of a target specific DAG node. const char *getTargetNodeName(unsigned Opcode) const override; + MachineBasicBlock * + EmitInstrWithCustomInserter(MachineInstr &MI, + MachineBasicBlock *BB) const override; + private: // Lower incoming arguments, copy physregs into vregs SDValue LowerFormalArguments(SDValue Chain, CallingConv::ID CallConv, @@ -60,6 +65,7 @@ return true; } SDValue lowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const; + SDValue lowerSELECT_CC(SDValue Op, SelectionDAG &DAG) const; }; } Index: lib/Target/RISCV/RISCVISelLowering.cpp =================================================================== --- lib/Target/RISCV/RISCVISelLowering.cpp +++ lib/Target/RISCV/RISCVISelLowering.cpp @@ -50,6 +50,7 @@ // TODO: add all necessary setOperationAction calls. setOperationAction(ISD::BR_CC, MVT::i32, Expand); + setOperationAction(ISD::SELECT_CC, MVT::i32, Custom); setBooleanContents(ZeroOrOneBooleanContent); setOperationAction(ISD::GlobalAddress, MVT::i32, Custom); @@ -64,11 +65,52 @@ MaxStoresPerMemmove = MaxStoresPerMemmoveOptSize = 128; } +// Changes the condition code and swaps operands if necessary, so the SetCC +// operation matches one of the comparisons supported directly in the RISC-V +// ISA. +static void normaliseSetCC(SDValue &LHS, SDValue &RHS, ISD::CondCode &CC) { + switch (CC) { + default: + break; + case ISD::SETGT: + case ISD::SETLE: + case ISD::SETUGT: + case ISD::SETULE: + CC = ISD::getSetCCSwappedOperands(CC); + std::swap(LHS, RHS); + break; + } +} + +// Return the RISC-V branch opcode that matches the given DAG integer +// condition code. 
The CondCode must be one of those supported by the RISC-V +// ISA (see normaliseSetCC). +static unsigned getBranchOpcodeForIntCondCode(ISD::CondCode CC) { + switch (CC) { + default: + llvm_unreachable("Unsupported CondCode"); + case ISD::SETEQ: + return RISCV::BEQ; + case ISD::SETNE: + return RISCV::BNE; + case ISD::SETLT: + return RISCV::BLT; + case ISD::SETGE: + return RISCV::BGE; + case ISD::SETULT: + return RISCV::BLTU; + case ISD::SETUGE: + return RISCV::BGEU; + } +} + SDValue RISCVTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { switch (Op.getOpcode()) { case ISD::GlobalAddress: return lowerGlobalAddress(Op, DAG); + case ISD::SELECT_CC: + return lowerSELECT_CC(Op, DAG); default: report_fatal_error("unimplemented operand"); } @@ -96,6 +138,90 @@ } } +SDValue RISCVTargetLowering::lowerSELECT_CC(SDValue Op, + SelectionDAG &DAG) const { + SDValue LHS = Op.getOperand(0); + SDValue RHS = Op.getOperand(1); + SDValue TrueV = Op.getOperand(2); + SDValue FalseV = Op.getOperand(3); + ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(4))->get(); + SDLoc DL(Op); + + normaliseSetCC(LHS, RHS, CC); + + SDValue TargetCC = DAG.getConstant(CC, DL, MVT::i32); + + SDVTList VTs = DAG.getVTList(Op.getValueType(), MVT::Glue); + SDValue Ops[] = {LHS, RHS, TargetCC, TrueV, FalseV}; + + return DAG.getNode(RISCVISD::SELECT_CC, DL, VTs, Ops); +} + +MachineBasicBlock * +RISCVTargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI, + MachineBasicBlock *BB) const { + const TargetInstrInfo &TII = *BB->getParent()->getSubtarget().getInstrInfo(); + DebugLoc DL = MI.getDebugLoc(); + + assert(MI.getOpcode() == RISCV::Select && "Unexpected instr type to insert"); + + // To "insert" a SELECT instruction, we actually have to insert the triangle + // control-flow pattern. The incoming instruction knows the destination vreg + // to set, the two registers to compare, the true/false values to + // select between, and the condcode to use to select the appropriate branch.
+ // + // We produce the following control flow: + // HeadMBB + // | \ + // | IfFalseMBB + // | / + // TailMBB + const BasicBlock *LLVM_BB = BB->getBasicBlock(); + MachineFunction::iterator I = ++BB->getIterator(); + + MachineBasicBlock *HeadMBB = BB; + MachineFunction *F = BB->getParent(); + MachineBasicBlock *TailMBB = F->CreateMachineBasicBlock(LLVM_BB); + MachineBasicBlock *IfFalseMBB = F->CreateMachineBasicBlock(LLVM_BB); + + F->insert(I, IfFalseMBB); + F->insert(I, TailMBB); + // Move all remaining instructions to TailMBB. + TailMBB->splice(TailMBB->begin(), HeadMBB, + std::next(MachineBasicBlock::iterator(MI)), HeadMBB->end()); + // Update machine-CFG edges by transferring all successors of the current + // block to the new block which will contain the Phi node for the select. + TailMBB->transferSuccessorsAndUpdatePHIs(HeadMBB); + // Set the successors for HeadMBB. + HeadMBB->addSuccessor(IfFalseMBB); + HeadMBB->addSuccessor(TailMBB); + + // Insert appropriate branch. + unsigned LHS = MI.getOperand(1).getReg(); + unsigned RHS = MI.getOperand(2).getReg(); + auto CC = static_cast<ISD::CondCode>(MI.getOperand(3).getImm()); + unsigned Opcode = getBranchOpcodeForIntCondCode(CC); + + BuildMI(HeadMBB, DL, TII.get(Opcode)) + .addReg(LHS) + .addReg(RHS) + .addMBB(TailMBB); + + // IfFalseMBB just falls through to TailMBB. + IfFalseMBB->addSuccessor(TailMBB); + + // %Result = phi [ %TrueValue, HeadMBB ], [ %FalseValue, IfFalseMBB ] + BuildMI(*TailMBB, TailMBB->begin(), DL, TII.get(RISCV::PHI), + MI.getOperand(0).getReg()) + .addReg(MI.getOperand(4).getReg()) + .addMBB(HeadMBB) + .addReg(MI.getOperand(5).getReg()) + .addMBB(IfFalseMBB); + + MI.eraseFromParent(); // The pseudo instruction is gone now. + return TailMBB; +} + // Calling Convention Implementation.
#include "RISCVGenCallingConv.inc" @@ -330,6 +456,8 @@ return "RISCVISD::RET_FLAG"; case RISCVISD::CALL: return "RISCVISD::CALL"; + case RISCVISD::SELECT_CC: + return "RISCVISD::SELECT_CC"; } return nullptr; } Index: lib/Target/RISCV/RISCVInstrInfo.td =================================================================== --- lib/Target/RISCV/RISCVInstrInfo.td +++ lib/Target/RISCV/RISCVInstrInfo.td @@ -18,6 +18,9 @@ def SDT_RISCVCallSeqEnd : SDCallSeqEnd<[SDTCisVT<0, i32>, SDTCisVT<1, i32>]>; def SDT_RISCVCall : SDTypeProfile<0, -1, [SDTCisVT<0, i32>]>; +def SDT_RISCVSelectCC : SDTypeProfile<1, 5, [SDTCisSameAs<1, 2>, + SDTCisSameAs<0, 4>, + SDTCisSameAs<4, 5>]>; def Call : SDNode<"RISCVISD::CALL", SDT_RISCVCall, @@ -29,6 +32,7 @@ [SDNPHasChain, SDNPOutGlue]>; def CallSeqEnd : SDNode<"ISD::CALLSEQ_END", SDT_RISCVCallSeqEnd, [SDNPHasChain, SDNPOptInGlue, SDNPOutGlue]>; +def SelectCC : SDNode<"RISCVISD::SELECT_CC", SDT_RISCVSelectCC, [SDNPInGlue]>; // Operands class ImmAsmOperand : AsmOperandClass { @@ -300,6 +304,12 @@ def : PatGprGpr; def : PatGprSimm12; +let usesCustomInserter = 1 in +def Select : Pseudo<(outs GPR:$dst), + (ins GPR:$lhs, GPR:$rhs, i32imm:$imm, GPR:$src, GPR:$src2), + [(set i32:$dst, + (SelectCC GPR:$lhs, GPR:$rhs, (i32 imm:$imm), GPR:$src, GPR:$src2))]>; + /// Branches and jumps // Match `(brcond (CondOp ..), ..)` and lower to the appropriate RISC-V branch Index: test/CodeGen/RISCV/select-cc.ll =================================================================== --- /dev/null +++ test/CodeGen/RISCV/select-cc.ll @@ -0,0 +1,56 @@ +; RUN: llc -mtriple=riscv32 -verify-machineinstrs < %s | FileCheck %s + +define i32 @foo(i32 %a, i32 *%b) { +; CHECK-LABEL: foo: +; CHECK: beq a0, a2, .LBB{{.+}} + %val1 = load volatile i32, i32* %b + %tst1 = icmp eq i32 %a, %val1 + %val2 = select i1 %tst1, i32 %a, i32 %val1 + +; CHECK: bne a0, a2, .LBB{{.+}} + %val3 = load volatile i32, i32* %b + %tst2 = icmp ne i32 %val2, %val3 + %val4 = select i1 %tst2, i32 %val2, i32 %val3 
+ +; CHECK: bltu a2, a0, .LBB{{.+}} + %val5 = load volatile i32, i32* %b + %tst3 = icmp ugt i32 %val4, %val5 + %val6 = select i1 %tst3, i32 %val4, i32 %val5 + +; CHECK: bgeu a0, a2, .LBB{{.+}} + %val7 = load volatile i32, i32* %b + %tst4 = icmp uge i32 %val6, %val7 + %val8 = select i1 %tst4, i32 %val6, i32 %val7 + +; CHECK: bltu a0, a2, .LBB{{.+}} + %val9 = load volatile i32, i32* %b + %tst5 = icmp ult i32 %val8, %val9 + %val10 = select i1 %tst5, i32 %val8, i32 %val9 + +; CHECK: bgeu a2, a0, .LBB{{.+}} + %val11 = load volatile i32, i32* %b + %tst6 = icmp ule i32 %val10, %val11 + %val12 = select i1 %tst6, i32 %val10, i32 %val11 + +; CHECK: blt a2, a0, .LBB{{.+}} + %val13 = load volatile i32, i32* %b + %tst7 = icmp sgt i32 %val12, %val13 + %val14 = select i1 %tst7, i32 %val12, i32 %val13 + +; CHECK: bge a0, a2, .LBB{{.+}} + %val15 = load volatile i32, i32* %b + %tst8 = icmp sge i32 %val14, %val15 + %val16 = select i1 %tst8, i32 %val14, i32 %val15 + +; CHECK: blt a0, a2, .LBB{{.+}} + %val17 = load volatile i32, i32* %b + %tst9 = icmp slt i32 %val16, %val17 + %val18 = select i1 %tst9, i32 %val16, i32 %val17 + +; CHECK: bge a1, a0, .LBB{{.+}} + %val19 = load volatile i32, i32* %b + %tst10 = icmp sle i32 %val18, %val19 + %val20 = select i1 %tst10, i32 %val18, i32 %val19 + + ret i32 %val20 +}