diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -583,11 +583,11 @@
                                bool DemandHighBits = true);
     SDValue MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1);
     SDValue MatchRotatePosNeg(SDValue Shifted, SDValue Pos, SDValue Neg,
-                              SDValue InnerPos, SDValue InnerNeg,
+                              SDValue InnerPos, SDValue InnerNeg, bool HasPos,
                               unsigned PosOpcode, unsigned NegOpcode,
                               const SDLoc &DL);
     SDValue MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos, SDValue Neg,
-                              SDValue InnerPos, SDValue InnerNeg,
+                              SDValue InnerPos, SDValue InnerNeg, bool HasPos,
                               unsigned PosOpcode, unsigned NegOpcode,
                               const SDLoc &DL);
     SDValue MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL);
@@ -7031,8 +7031,9 @@
 // Neg with outer conversions stripped away.
 SDValue DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos,
                                        SDValue Neg, SDValue InnerPos,
-                                       SDValue InnerNeg, unsigned PosOpcode,
-                                       unsigned NegOpcode, const SDLoc &DL) {
+                                       SDValue InnerNeg, bool HasPos,
+                                       unsigned PosOpcode, unsigned NegOpcode,
+                                       const SDLoc &DL) {
   // fold (or (shl x, (*ext y)),
   //          (srl x, (*ext (sub 32, y)))) ->
   //   (rotl x, y) or (rotr x, (sub 32, y))
@@ -7043,7 +7044,6 @@
   EVT VT = Shifted.getValueType();
   if (matchRotateSub(InnerPos, InnerNeg, VT.getScalarSizeInBits(), DAG,
                      /*IsRotate*/ true)) {
-    bool HasPos = TLI.isOperationLegalOrCustom(PosOpcode, VT);
     return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, Shifted,
                        HasPos ? Pos : Neg);
   }
@@ -7059,8 +7059,9 @@
 // TODO: Merge with MatchRotatePosNeg.
 SDValue DAGCombiner::MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos,
                                        SDValue Neg, SDValue InnerPos,
-                                       SDValue InnerNeg, unsigned PosOpcode,
-                                       unsigned NegOpcode, const SDLoc &DL) {
+                                       SDValue InnerNeg, bool HasPos,
+                                       unsigned PosOpcode, unsigned NegOpcode,
+                                       const SDLoc &DL) {
   EVT VT = N0.getValueType();
   unsigned EltBits = VT.getScalarSizeInBits();
 
@@ -7072,7 +7073,6 @@
   //          (srl x1, (*ext y))) ->
   //   (fshr x0, x1, y) or (fshl x0, x1, (sub 32, y))
   if (matchRotateSub(InnerPos, InnerNeg, EltBits, DAG, /*IsRotate*/ N0 == N1)) {
-    bool HasPos = TLI.isOperationLegalOrCustom(PosOpcode, VT);
     return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, N0, N1,
                        HasPos ? Pos : Neg);
   }
@@ -7134,6 +7134,16 @@
   bool HasROTR = hasOperation(ISD::ROTR, VT);
   bool HasFSHL = hasOperation(ISD::FSHL, VT);
   bool HasFSHR = hasOperation(ISD::FSHR, VT);
+
+  // If the type is going to be promoted and the target has enabled custom
+  // lowering for rotate, allow matching rotate by non-constants. Only allow
+  // this for scalar types.
+  if (VT.isScalarInteger() && TLI.getTypeAction(*DAG.getContext(), VT) ==
+                                  TargetLowering::TypePromoteInteger) {
+    HasROTL |= TLI.getOperationAction(ISD::ROTL, VT) == TargetLowering::Custom;
+    HasROTR |= TLI.getOperationAction(ISD::ROTR, VT) == TargetLowering::Custom;
+  }
+
   if (LegalOperations && !HasROTL && !HasROTR && !HasFSHL && !HasFSHR)
     return SDValue();
 
@@ -7276,26 +7286,26 @@
   if (IsRotate && (HasROTL || HasROTR)) {
     SDValue TryL =
         MatchRotatePosNeg(LHSShiftArg, LHSShiftAmt, RHSShiftAmt, LExtOp0,
-                          RExtOp0, ISD::ROTL, ISD::ROTR, DL);
+                          RExtOp0, HasROTL, ISD::ROTL, ISD::ROTR, DL);
     if (TryL)
       return TryL;
 
     SDValue TryR =
         MatchRotatePosNeg(RHSShiftArg, RHSShiftAmt, LHSShiftAmt, RExtOp0,
-                          LExtOp0, ISD::ROTR, ISD::ROTL, DL);
+                          LExtOp0, HasROTR, ISD::ROTR, ISD::ROTL, DL);
     if (TryR)
       return TryR;
   }
 
   SDValue TryL =
       MatchFunnelPosNeg(LHSShiftArg, RHSShiftArg, LHSShiftAmt, RHSShiftAmt,
-                        LExtOp0, RExtOp0, ISD::FSHL, ISD::FSHR, DL);
+                        LExtOp0, RExtOp0, HasFSHL, ISD::FSHL, ISD::FSHR, DL);
   if (TryL)
     return TryL;
 
   SDValue TryR =
       MatchFunnelPosNeg(LHSShiftArg, RHSShiftArg, RHSShiftAmt, LHSShiftAmt,
-                        RExtOp0, LExtOp0, ISD::FSHR, ISD::FSHL, DL);
+                        RExtOp0, LExtOp0, HasFSHR, ISD::FSHR, ISD::FSHL, DL);
   if (TryR)
     return TryR;
 
diff --git a/llvm/test/CodeGen/RISCV/rotl-rotr.ll b/llvm/test/CodeGen/RISCV/rotl-rotr.ll
--- a/llvm/test/CodeGen/RISCV/rotl-rotr.ll
+++ b/llvm/test/CodeGen/RISCV/rotl-rotr.ll
@@ -40,11 +40,7 @@
 ;
 ; RV64ZBB-LABEL: rotl_32:
 ; RV64ZBB:       # %bb.0:
-; RV64ZBB-NEXT:    li a2, 32
-; RV64ZBB-NEXT:    subw a2, a2, a1
-; RV64ZBB-NEXT:    sllw a1, a0, a1
-; RV64ZBB-NEXT:    srlw a0, a0, a2
-; RV64ZBB-NEXT:    or a0, a1, a0
+; RV64ZBB-NEXT:    rolw a0, a0, a1
 ; RV64ZBB-NEXT:    ret
   %z = sub i32 32, %y
   %b = shl i32 %x, %y
@@ -79,11 +75,7 @@
 ;
 ; RV64ZBB-LABEL: rotr_32:
 ; RV64ZBB:       # %bb.0:
-; RV64ZBB-NEXT:    li a2, 32
-; RV64ZBB-NEXT:    subw a2, a2, a1
-; RV64ZBB-NEXT:    srlw a1, a0, a1
-; RV64ZBB-NEXT:    sllw a0, a0, a2
-; RV64ZBB-NEXT:    or a0, a1, a0
+; RV64ZBB-NEXT:    rorw a0, a0, a1
 ; RV64ZBB-NEXT:    ret
   %z = sub i32 32, %y
   %b = lshr i32 %x, %y
@@ -322,10 +314,7 @@
 ;
 ; RV64ZBB-LABEL: rotl_32_mask:
 ; RV64ZBB:       # %bb.0:
-; RV64ZBB-NEXT:    negw a2, a1
-; RV64ZBB-NEXT:    sllw a1, a0, a1
-; RV64ZBB-NEXT:    srlw a0, a0, a2
-; RV64ZBB-NEXT:    or a0, a1, a0
+; RV64ZBB-NEXT:    rolw a0, a0, a1
 ; RV64ZBB-NEXT:    ret
   %z = sub i32 0, %y
   %and = and i32 %z, 31
@@ -359,10 +348,7 @@
 ;
 ; RV64ZBB-LABEL: rotr_32_mask:
 ; RV64ZBB:       # %bb.0:
-; RV64ZBB-NEXT:    negw a2, a1
-; RV64ZBB-NEXT:    srlw a1, a0, a1
-; RV64ZBB-NEXT:    sllw a0, a0, a2
-; RV64ZBB-NEXT:    or a0, a1, a0
+; RV64ZBB-NEXT:    rorw a0, a0, a1
 ; RV64ZBB-NEXT:    ret
   %z = sub i32 0, %y
   %and = and i32 %z, 31