[X86] Match SHLD/SHRD through modulo shift-amount masks (XOR form only)

Peek through AND(amt, Bits-1) on the shift amounts so that
  OR(SHL(X, AND(C, Bits-1)), SRL(SRL(Y, 1), XOR(AND(C, Bits-1), Bits-1)))
is folded to SHLD(X, Y, C), and the mirrored form to SHRD(X, Y, C).

The SUB-based "modulo" form
  OR(SHL(X, AND(C, Bits-1)), SRL(Y, AND(0 - C, Bits-1)))
is deliberately NOT folded. For C == 0 (mod Bits) both shifts are by zero,
so the IR evaluates to X | Y, while SHLD/SHRD by a zero count leave the
destination (X) unchanged. Folding it would miscompile shld_safe_i32 and
shrd_safe_i32 in test/CodeGen/X86/shift-double.ll, so their expected output
stays unchanged and that test file is not modified here. Tests for the XOR
form still need to be added, with checks generated by
utils/update_llc_test_checks.py.

Index: lib/Target/X86/X86ISelLowering.cpp
===================================================================
--- lib/Target/X86/X86ISelLowering.cpp
+++ lib/Target/X86/X86ISelLowering.cpp
@@ -36514,6 +36514,7 @@
 
   // fold (or (x << c) | (y >> (64 - c))) ==> (shld64 x, y, c)
   bool OptForSize = DAG.getMachineFunction().getFunction().optForSize();
+  unsigned Bits = VT.getScalarSizeInBits();
 
   // SHLD/SHRD instructions have lower register pressure, but on some
   // platforms they have higher latency than the equivalent
@@ -36536,6 +36537,25 @@
   SDValue ShAmt1 = N1.getOperand(1);
   if (ShAmt1.getValueType() != MVT::i8)
     return SDValue();
+
+  // Peek through any modulo shift masks (AND with Bits - 1). The unmasked
+  // amounts may only feed the XOR-based patterns below; the SUB-based
+  // patterns stop being funnel shifts once a mask is involved.
+  SDValue ShMsk0;
+  if (ShAmt0.getOpcode() == ISD::AND)
+    if (auto *ShMsk0Cst = dyn_cast<ConstantSDNode>(ShAmt0.getOperand(1)))
+      if (ShMsk0Cst->getAPIntValue() == (Bits - 1)) {
+        ShMsk0 = ShAmt0;
+        ShAmt0 = ShAmt0.getOperand(0);
+      }
+  SDValue ShMsk1;
+  if (ShAmt1.getOpcode() == ISD::AND)
+    if (auto *ShMsk1Cst = dyn_cast<ConstantSDNode>(ShAmt1.getOperand(1)))
+      if (ShMsk1Cst->getAPIntValue() == (Bits - 1)) {
+        ShMsk1 = ShAmt1;
+        ShAmt1 = ShAmt1.getOperand(0);
+      }
+
   if (ShAmt0.getOpcode() == ISD::TRUNCATE)
     ShAmt0 = ShAmt0.getOperand(0);
   if (ShAmt1.getOpcode() == ISD::TRUNCATE)
@@ -36550,24 +36570,33 @@
     Opc = X86ISD::SHRD;
     std::swap(Op0, Op1);
     std::swap(ShAmt0, ShAmt1);
+    std::swap(ShMsk0, ShMsk1);
   }
 
   // OR( SHL( X, C ), SRL( Y, 32 - C ) ) -> SHLD( X, Y, C )
   // OR( SRL( X, C ), SHL( Y, 32 - C ) ) -> SHRD( X, Y, C )
   // OR( SHL( X, C ), SRL( SRL( Y, 1 ), XOR( C, 31 ) ) ) -> SHLD( X, Y, C )
   // OR( SRL( X, C ), SHL( SHL( Y, 1 ), XOR( C, 31 ) ) ) -> SHRD( X, Y, C )
-  unsigned Bits = VT.getScalarSizeInBits();
+  // OR( SHL( X, AND( C, 31 ) ), SRL( SRL( Y, 1 ), XOR( AND( C, 31 ), 31 ) ) )
+  //   -> SHLD( X, Y, C )
+  // OR( SRL( X, AND( C, 31 ) ), SHL( SHL( Y, 1 ), XOR( AND( C, 31 ), 31 ) ) )
+  //   -> SHRD( X, Y, C )
+  // The SUB form OR( SHL( X, AND( C, 31 ) ), SRL( Y, AND( 0 - C, 31 ) ) ) is
+  // NOT matched: for C == 0 it evaluates to X | Y, but SHLD by 0 returns X.
   if (ShAmt1.getOpcode() == ISD::SUB) {
     SDValue Sum = ShAmt1.getOperand(0);
     if (auto *SumC = dyn_cast<ConstantSDNode>(Sum)) {
       SDValue ShAmt1Op1 = ShAmt1.getOperand(1);
       if (ShAmt1Op1.getOpcode() == ISD::TRUNCATE)
         ShAmt1Op1 = ShAmt1Op1.getOperand(0);
-      if (SumC->getSExtValue() == Bits && ShAmt1Op1 == ShAmt0)
-        return DAG.getNode(Opc, DL, VT,
-                           Op0, Op1,
-                           DAG.getNode(ISD::TRUNCATE, DL,
-                                       MVT::i8, ShAmt0));
+      // Only the unmasked form is a funnel shift: there C == 0 or C == Bits
+      // makes one of the two shifts out of range (undefined). With a peeled
+      // modulo mask both shifts are by 0 for C == 0 (mod Bits) and the OR
+      // yields X | Y, whereas SHLD/SHRD by 0 yields X.
+      if (SumC->getAPIntValue() == Bits && !ShMsk0 && !ShMsk1 &&
+          ShAmt1Op1 == ShAmt0)
+        return DAG.getNode(Opc, DL, VT, Op0, Op1,
+                           DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, ShAmt0));
     }
   } else if (auto *ShAmt1C = dyn_cast<ConstantSDNode>(ShAmt1)) {
     auto *ShAmt0C = dyn_cast<ConstantSDNode>(ShAmt0);
@@ -36583,7 +36612,8 @@
       SDValue ShAmt1Op0 = ShAmt1.getOperand(0);
       if (ShAmt1Op0.getOpcode() == ISD::TRUNCATE)
         ShAmt1Op0 = ShAmt1Op0.getOperand(0);
-      if (MaskC->getSExtValue() == (Bits - 1) && ShAmt1Op0 == ShAmt0) {
+      if (MaskC->getSExtValue() == (Bits - 1) &&
+          (ShAmt1Op0 == ShAmt0 || ShAmt1Op0 == ShMsk0)) {
         if (Op1.getOpcode() == InnerShift &&
             isa<ConstantSDNode>(Op1.getOperand(1)) &&
             Op1.getConstantOperandVal(1) == 1) {