Index: lib/CodeGen/SelectionDAG/DAGCombiner.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -540,11 +540,13 @@
     SDValue MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
                                bool DemandHighBits = true);
     SDValue MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1);
-    SDNode *MatchRotatePosNeg(SDValue Shifted, SDValue Pos, SDValue Neg,
+    SDValue MatchRotatePosNeg(SDValue Shifted, SDValue Pos, SDValue Neg,
                               SDValue InnerPos, SDValue InnerNeg,
                               unsigned PosOpcode, unsigned NegOpcode,
                               const SDLoc &DL);
-    SDNode *MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL);
+    SDValue MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL);
+    SDValue MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL,
+                        const APInt &DemandedBits);
     SDValue MatchLoadCombine(SDNode *N);
     SDValue MatchStoreCombine(StoreSDNode *N);
     SDValue ReduceLoadWidth(SDNode *N);
@@ -6062,8 +6064,8 @@
     return V;
 
   // See if this is some rotate idiom.
-  if (SDNode *Rot = MatchRotate(N0, N1, SDLoc(N)))
-    return SDValue(Rot, 0);
+  if (SDValue Rot = MatchRotate(N0, N1, SDLoc(N)))
+    return Rot;
 
   if (SDValue Load = MatchLoadCombine(N))
     return Load;
@@ -6348,7 +6350,7 @@
 // to both (PosOpcode Shifted, Pos) and (NegOpcode Shifted, Neg), with the
 // former being preferred if supported. InnerPos and InnerNeg are Pos and
 // Neg with outer conversions stripped away.
-SDNode *DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos,
+SDValue DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos,
                                        SDValue Neg, SDValue InnerPos,
                                        SDValue InnerNeg, unsigned PosOpcode,
                                        unsigned NegOpcode, const SDLoc &DL) {
@@ -6363,32 +6365,39 @@
   if (matchRotateSub(InnerPos, InnerNeg, VT.getScalarSizeInBits(), DAG)) {
     bool HasPos = TLI.isOperationLegalOrCustom(PosOpcode, VT);
     return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, Shifted,
-                       HasPos ? Pos : Neg).getNode();
+                       HasPos ? Pos : Neg);
   }
 
-  return nullptr;
+  return SDValue();
+}
+
+SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) {
+  return MatchRotate(LHS, RHS, DL,
+                     APInt::getMaxValue(LHS.getValueType().getSizeInBits()));
 }
 
 // MatchRotate - Handle an 'or' of two operands. If this is one of the many
 // idioms for rotate, and if the target supports rotation instructions, generate
 // a rot[lr].
-SDNode *DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) {
+SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL,
+                                 const APInt &DemandedBits) {
   // Must be a legal type. Expanded 'n promoted things won't work with rotates.
   EVT VT = LHS.getValueType();
-  if (!TLI.isTypeLegal(VT)) return nullptr;
+  if (!TLI.isTypeLegal(VT))
+    return SDValue();
 
   // The target must have at least one rotate flavor.
   bool HasROTL = hasOperation(ISD::ROTL, VT);
   bool HasROTR = hasOperation(ISD::ROTR, VT);
-  if (!HasROTL && !HasROTR) return nullptr;
+  if (!HasROTL && !HasROTR)
+    return SDValue();
 
   // Check for truncated rotate.
   if (LHS.getOpcode() == ISD::TRUNCATE && RHS.getOpcode() == ISD::TRUNCATE &&
       LHS.getOperand(0).getValueType() == RHS.getOperand(0).getValueType()) {
     assert(LHS.getValueType() == RHS.getValueType());
-    if (SDNode *Rot = MatchRotate(LHS.getOperand(0), RHS.getOperand(0), DL)) {
-      return DAG.getNode(ISD::TRUNCATE, SDLoc(LHS), LHS.getValueType(),
-                         SDValue(Rot, 0)).getNode();
+    if (SDValue Rot = MatchRotate(LHS.getOperand(0), RHS.getOperand(0), DL)) {
+      return DAG.getNode(ISD::TRUNCATE, SDLoc(LHS), LHS.getValueType(), Rot);
     }
   }
 
@@ -6403,7 +6412,7 @@
 
   // If neither side matched a rotate half, bail
   if (!LHSShift && !RHSShift)
-    return nullptr;
+    return SDValue();
 
   // InstCombine may have combined a constant shl, srl, mul, or udiv with one
   // side of the rotate, so try to handle that here. In all cases we need to
@@ -6426,15 +6435,15 @@
 
   // If a side is still missing, nothing else we can do.
   if (!RHSShift || !LHSShift)
-    return nullptr;
+    return SDValue();
 
   // At this point we've matched or extracted a shift op on each side.
 
   if (LHSShift.getOperand(0) != RHSShift.getOperand(0))
-    return nullptr; // Not shifting the same value.
+    return SDValue(); // Not shifting the same value.
 
   if (LHSShift.getOpcode() == RHSShift.getOpcode())
-    return nullptr; // Shifts must disagree.
+    return SDValue(); // Shifts must disagree.
 
   // Canonicalize shl to left side in a shl/srl pair.
   if (RHSShift.getOpcode() == ISD::SHL) {
@@ -6443,21 +6452,41 @@
-  unsigned EltSizeInBits = VT.getScalarSizeInBits();
   SDValue LHSShiftArg = LHSShift.getOperand(0);
   SDValue LHSShiftAmt = LHSShift.getOperand(1);
   SDValue RHSShiftArg = RHSShift.getOperand(0);
   SDValue RHSShiftAmt = RHSShift.getOperand(1);
 
+  EVT RotVT = VT;
+
   // fold (or (shl x, C1), (srl x, C2)) -> (rotl x, C1)
   // fold (or (shl x, C1), (srl x, C2)) -> (rotr x, C2)
-  auto MatchRotateSum = [EltSizeInBits](ConstantSDNode *LHS,
-                                        ConstantSDNode *RHS) {
-    return (LHS->getAPIntValue() + RHS->getAPIntValue()) == EltSizeInBits;
+  auto MatchRotateSum = [this, &DemandedBits, &RotVT](ConstantSDNode *LHS,
+                                                      ConstantSDNode *RHS) {
+    uint64_t RotAmount =
+        (LHS->getAPIntValue() + RHS->getAPIntValue()).getZExtValue();
+    // For vectors, only allow an exact match.
+    if (RotVT.isVector())
+      return RotAmount == RotVT.getScalarSizeInBits();
+    // For scalars, check that the type we use for rotation covers all demanded
+    // bits and is legal.
+    APInt RotMask =
+        APInt::getMaxValue(RotAmount).zextOrTrunc(DemandedBits.getBitWidth());
+    if ((DemandedBits & RotMask) != DemandedBits)
+      return false;
+    RotVT = EVT::getIntegerVT(*DAG.getContext(), RotAmount);
+    return TLI.isTypeLegal(RotVT);
   };
   if (ISD::matchBinaryPredicate(LHSShiftAmt, RHSShiftAmt, MatchRotateSum)) {
-    SDValue Rot = DAG.getNode(HasROTL ? ISD::ROTL : ISD::ROTR, DL, VT,
-                              LHSShiftArg, HasROTL ? LHSShiftAmt : RHSShiftAmt);
+    HasROTL = hasOperation(ISD::ROTL, RotVT);
+    HasROTR = hasOperation(ISD::ROTR, RotVT);
+    if (!HasROTL && !HasROTR)
+      return SDValue();
+
+    SDValue Rotated = DAG.getZExtOrTrunc(LHSShiftArg, DL, RotVT);
+    SDValue Rot = DAG.getNode(HasROTL ? ISD::ROTL : ISD::ROTR, DL, RotVT,
+                              Rotated, HasROTL ? LHSShiftAmt : RHSShiftAmt);
+    Rot = DAG.getAnyExtOrTrunc(Rot, DL, VT);
 
     // If there is an AND of either shifted operand, apply it to the result.
     if (LHSMask.getNode() || RHSMask.getNode()) {
@@ -6478,13 +6507,13 @@
       Rot = DAG.getNode(ISD::AND, DL, VT, Rot, Mask);
     }
 
-    return Rot.getNode();
+    return Rot;
   }
 
   // If there is a mask here, and we have a variable shift, we can't be sure
   // that we're masking out the right stuff.
   if (LHSMask.getNode() || RHSMask.getNode())
-    return nullptr;
+    return SDValue();
 
   // If the shift amount is sign/zext/any-extended just peel it off.
   SDValue LExtOp0 = LHSShiftAmt;
@@ -6501,17 +6530,17 @@
     RExtOp0 = RHSShiftAmt.getOperand(0);
   }
 
-  SDNode *TryL = MatchRotatePosNeg(LHSShiftArg, LHSShiftAmt, RHSShiftAmt,
+  SDValue TryL = MatchRotatePosNeg(LHSShiftArg, LHSShiftAmt, RHSShiftAmt,
                                    LExtOp0, RExtOp0, ISD::ROTL, ISD::ROTR, DL);
   if (TryL)
     return TryL;
 
-  SDNode *TryR = MatchRotatePosNeg(RHSShiftArg, RHSShiftAmt, LHSShiftAmt,
+  SDValue TryR = MatchRotatePosNeg(RHSShiftArg, RHSShiftAmt, LHSShiftAmt,
                                    RExtOp0, LExtOp0, ISD::ROTR, ISD::ROTL, DL);
   if (TryR)
     return TryR;
 
-  return nullptr;
+  return SDValue();
 }
 
 namespace {
@@ -10925,11 +10954,20 @@
   // because targets may prefer a wider type during later combines and invert
   // this transform.
   switch (N0.getOpcode()) {
+  case ISD::OR: {
+    // TODO: This would ideally be part of the SimplifyDemandedBits mechanism,
+    // but there is no way to easily plug it in at the moment, so it is
+    // limited to TRUNCATE.
+    SDLoc DL(N);
+    if (SDValue Rot = MatchRotate(N0.getOperand(0), N0.getOperand(1), DL,
+                                  APInt::getMaxValue(VT.getSizeInBits())))
+      return DAG.getNode(ISD::TRUNCATE, DL, VT, Rot);
+    LLVM_FALLTHROUGH;
+  }
   case ISD::ADD:
   case ISD::SUB:
   case ISD::MUL:
   case ISD::AND:
-  case ISD::OR:
   case ISD::XOR:
     if (!LegalOperations && N0.hasOneUse() &&
         (isConstantOrConstantVector(N0.getOperand(0), true) ||
Index: test/CodeGen/X86/rot16.ll
===================================================================
--- test/CodeGen/X86/rot16.ll
+++ test/CodeGen/X86/rot16.ll
@@ -208,21 +208,14 @@
 define i16 @rot16_trunc(i32 %x, i32 %y) nounwind {
 ; X32-LABEL: rot16_trunc:
 ; X32:       # %bb.0:
-; X32-NEXT:    movl {{[0-9]+}}(%esp), %eax
-; X32-NEXT:    movl %eax, %ecx
-; X32-NEXT:    shrl $11, %ecx
-; X32-NEXT:    shll $5, %eax
-; X32-NEXT:    orl %ecx, %eax
-; X32-NEXT:    # kill: def $ax killed $ax killed $eax
+; X32-NEXT:    movzwl {{[0-9]+}}(%esp), %eax
+; X32-NEXT:    rolw $5, %ax
 ; X32-NEXT:    retl
 ;
 ; X64-LABEL: rot16_trunc:
 ; X64:       # %bb.0:
 ; X64-NEXT:    movl %edi, %eax
-; X64-NEXT:    movl %edi, %ecx
-; X64-NEXT:    shrl $11, %ecx
-; X64-NEXT:    shll $5, %eax
-; X64-NEXT:    orl %ecx, %eax
+; X64-NEXT:    rolw $5, %ax
 ; X64-NEXT:    # kill: def $ax killed $ax killed $eax
 ; X64-NEXT:    retq
   %t0 = lshr i32 %x, 11