diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1068,6 +1068,8 @@
   setTargetDAGCombine(ISD::AND);
   setTargetDAGCombine(ISD::OR);
   setTargetDAGCombine(ISD::XOR);
+  setTargetDAGCombine(ISD::ROTL);
+  setTargetDAGCombine(ISD::ROTR);
   setTargetDAGCombine(ISD::ANY_EXTEND);
   if (Subtarget.hasStdExtF()) {
     setTargetDAGCombine(ISD::ZERO_EXTEND);
@@ -7269,6 +7271,47 @@
   return DAG.getNode(ISD::SHL, DL, VT, NA1, DAG.getConstant(Bits, DL, VT));
 }
 
+// Combine
+// ROTR ((GREV x, 24), 16) -> (GREVI x, 8)
+// ROTL ((GREV x, 24), 16) -> (GREVI x, 8)
+// RORW ((GREVW x, 24), 16) -> (GREVIW x, 8)
+// ROLW ((GREVW x, 24), 16) -> (GREVIW x, 8)
+static SDValue combineROTR_ROTL_RORW_ROLW(SDNode *N, SelectionDAG &DAG) {
+  SDValue Src = N->getOperand(0);
+  SDLoc DL(N);
+  unsigned Opc;
+
+  if ((N->getOpcode() == ISD::ROTR || N->getOpcode() == ISD::ROTL) &&
+      Src.getOpcode() == RISCVISD::GREV)
+    Opc = RISCVISD::GREV;
+  else if ((N->getOpcode() == RISCVISD::RORW ||
+            N->getOpcode() == RISCVISD::ROLW) &&
+           Src.getOpcode() == RISCVISD::GREVW)
+    Opc = RISCVISD::GREVW;
+  else
+    return SDValue();
+
+  // The fold is only valid for 32-bit rotates: a 16-bit rotate of a 64-bit
+  // value moves bits between the 32-bit halves, which GREV 24/8 never does.
+  if (Opc == RISCVISD::GREV && N->getValueType(0) != MVT::i32)
+    return SDValue();
+
+  if (!isa<ConstantSDNode>(N->getOperand(1)) ||
+      !isa<ConstantSDNode>(Src.getOperand(1)))
+    return SDValue();
+
+  // Only (rot (grev x, 24), 16) may be folded; bail out if either shift
+  // amount differs.
+  unsigned ShAmt1 = N->getConstantOperandVal(1);
+  unsigned ShAmt2 = Src.getConstantOperandVal(1);
+  if (ShAmt1 != 16 || ShAmt2 != 24)
+    return SDValue();
+
+  Src = Src.getOperand(0);
+  return DAG.getNode(Opc, DL, N->getValueType(0), Src,
+                     DAG.getConstant(8, DL, N->getOperand(1).getValueType()));
+}
+
 // Combine (GREVI (GREVI x, C2), C1) -> (GREVI x, C1^C2) when C1^C2 is
 // non-zero, and to x when it is. Any repeated GREVI stage undoes itself.
 // Combine (GORCI (GORCI x, C2), C1) -> (GORCI x, C1|C2). Repeated stage does
@@ -7973,8 +8016,12 @@
     if (SimplifyDemandedLowBitsHelper(0, 32) ||
         SimplifyDemandedLowBitsHelper(1, 5))
       return SDValue(N, 0);
-    break;
+
+    return combineROTR_ROTL_RORW_ROLW(N, DAG);
   }
+  case ISD::ROTR:
+  case ISD::ROTL:
+    return combineROTR_ROTL_RORW_ROLW(N, DAG);
   case RISCVISD::CLZW:
   case RISCVISD::CTZW: {
     // Only the lower 32 bits of the first operand are read
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td
@@ -880,10 +880,6 @@
 def : PatGprImm<riscv_grevw, GREVIW, uimm5>;
 def : PatGprImm<riscv_gorcw, GORCIW, uimm5>;
 
-// FIXME: Move to DAG combine.
-def : Pat<(riscv_rorw (riscv_grevw GPR:$rs1, 24), 16), (GREVIW GPR:$rs1, 8)>;
-def : Pat<(riscv_rolw (riscv_grevw GPR:$rs1, 24), 16), (GREVIW GPR:$rs1, 8)>;
-
 def : PatGprGpr<riscv_shflw, SHFLW>;
 def : PatGprGpr<riscv_unshflw, UNSHFLW>;
 } // Predicates = [HasStdExtZbp, IsRV64]
@@ -892,10 +888,6 @@
 def : PatGprGpr<riscv_unshfl, UNSHFL>;
 
 let Predicates = [HasStdExtZbp, IsRV32] in {
-// FIXME : Move to DAG combine.
-def : Pat<(i32 (rotr (riscv_grev GPR:$rs1, 24), (i32 16))), (GREVI GPR:$rs1, 8)>;
-def : Pat<(i32 (rotl (riscv_grev GPR:$rs1, 24), (i32 16))), (GREVI GPR:$rs1, 8)>;
-
 // We treat rev8 as a separate instruction, so match it directly.
 def : Pat<(i32 (riscv_grev GPR:$rs1, 24)), (REV8_RV32 GPR:$rs1)>;
 } // Predicates = [HasStdExtZbp, IsRV32]