diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -3408,6 +3408,21 @@
     return true;
   }
 
+  /// Return true if it's profitable to replace
+  ///
+  ///   shift x, non-constant
+  ///
+  /// with two instances of
+  ///
+  ///   shift x, constant
+  ///
+  /// where `shift` is a shift or rotate operation (not including funnel shift ops).
+  virtual bool
+  shiftOrRotateIsFasterWithConstantShiftAmount(const SDNode *N,
+                                               CombineLevel Level) const {
+    return false;
+  }
+
   // Return true if it is profitable to combine a BUILD_VECTOR with a stride-pattern
   // to a shuffle and a truncate.
   // Example of such a combine:
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -497,6 +497,9 @@
     SDValue reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
                            SDValue N1, SDNodeFlags Flags);
 
+    // SHL, SRA, SRL, ROTL, ROTR, but not FSHL or FSHR.
+    SDValue visitShiftOrRotate(SDNode *N);
+
     SDValue visitShiftByConstant(SDNode *N);
 
     SDValue foldSelectOfConstants(SDNode *N);
@@ -7299,6 +7302,39 @@
   return DAG.getNode(LogicOpcode, DL, VT, NewShift1, NewShift2);
 }
 
+SDValue DAGCombiner::visitShiftOrRotate(SDNode *N) {
+  auto ShiftOpcode = N->getOpcode();
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
+
+  // On some targets, shifting/rotating by a constant is faster than
+  // shifting/rotating by a register, so we fold:
+  //
+  //   shift lhs, (select cond, constant1, constant2) -->
+  //   select cond, (shift lhs, constant1), (shift lhs, constant2)
+  //
+  // Only do this after legalizing types. A tree which initially fits this
+  // pattern might be legalized to something quite different. E.g. `select
+  // cond, 2, 3` can be simplified to `add cond, 2`, and we don't want to
+  // transform that!
+  //
+  // TODO: This logic could be extended to ops other than shift/rotate.
+  if (OptLevel != CodeGenOpt::None &&
+      Level >= AfterLegalizeTypes &&
+      RHS.getOpcode() == ISD::SELECT && RHS.hasOneUse() &&
+      isa<ConstantSDNode>(RHS.getOperand(1)) &&
+      isa<ConstantSDNode>(RHS.getOperand(2)) &&
+      TLI.shiftOrRotateIsFasterWithConstantShiftAmount(N, Level)) {
+    SDLoc DL(N);
+    EVT VT = N->getValueType(0);
+    return DAG.getNode(
+        ISD::SELECT, DL, VT, RHS.getOperand(0),
+        DAG.getNode(ShiftOpcode, DL, VT, LHS, RHS.getOperand(1)),
+        DAG.getNode(ShiftOpcode, DL, VT, LHS, RHS.getOperand(2)));
+  }
+  return SDValue();
+}
+
 /// Handle transforms common to the three shifts, when the shift amount is a
 /// constant.
 /// We are looking for: (shift being one of shl/sra/srl)
@@ -7406,6 +7442,9 @@
   EVT VT = N->getValueType(0);
   unsigned Bitsize = VT.getScalarSizeInBits();
 
+  if (SDValue V = visitShiftOrRotate(N))
+    return V;
+
   // fold (rot x, 0) -> x
   if (isNullOrNullSplat(N1))
     return N0;
@@ -7466,6 +7505,9 @@
   if (SDValue V = DAG.simplifyShift(N0, N1))
     return V;
 
+  if (SDValue V = visitShiftOrRotate(N))
+    return V;
+
   EVT VT = N0.getValueType();
   EVT ShiftVT = N1.getValueType();
   unsigned OpSizeInBits = VT.getScalarSizeInBits();
@@ -7714,6 +7756,9 @@
   if (SDValue V = DAG.simplifyShift(N0, N1))
     return V;
 
+  if (SDValue V = visitShiftOrRotate(N))
+    return V;
+
   EVT VT = N0.getValueType();
   unsigned OpSizeInBits = VT.getScalarSizeInBits();
 
@@ -7904,6 +7949,9 @@
   if (SDValue V = DAG.simplifyShift(N0, N1))
     return V;
 
+  if (SDValue V = visitShiftOrRotate(N))
+    return V;
+
   EVT VT = N0.getValueType();
   unsigned OpSizeInBits = VT.getScalarSizeInBits();
 
diff --git a/llvm/lib/Target/X86/X86ISelLowering.h b/llvm/lib/Target/X86/X86ISelLowering.h
--- a/llvm/lib/Target/X86/X86ISelLowering.h
+++ b/llvm/lib/Target/X86/X86ISelLowering.h
@@ -789,6 +789,9 @@
     SDValue PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI) const override;
 
+    bool shiftOrRotateIsFasterWithConstantShiftAmount(
+        const SDNode *N, CombineLevel Level) const override;
+
     // Return true if it is profitable to combine a BUILD_VECTOR with a
     // stride-pattern to a shuffle and a truncate.
     // Example of such a combine:
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -46539,6 +46539,18 @@
   return true;
 }
 
+bool X86TargetLowering::shiftOrRotateIsFasterWithConstantShiftAmount(
+    const SDNode *N, CombineLevel Level) const {
+  // On most x86 chips, shifts/rotates by a constant are faster than
+  // shifts/rotates by a register.
+  auto Opcode = N->getOpcode();
+  assert(Opcode == ISD::SHL || Opcode == ISD::SRA || Opcode == ISD::SRL ||
+         Opcode == ISD::ROTL || Opcode == ISD::ROTR);
+  // Scalar shifts by an immediate are faster than scalar shifts by a register.
+  // But vector shifts have no such preference.
+  return !N->getValueType(0).isVector();
+}
+
 bool X86TargetLowering::
 isDesirableToCombineBuildVectorToShuffleTruncate(
     ArrayRef<int> ShuffleMask, EVT SrcVT, EVT TruncVT) const {
diff --git a/llvm/test/CodeGen/X86/dagcombine-shifts.ll b/llvm/test/CodeGen/X86/dagcombine-shifts.ll
--- a/llvm/test/CodeGen/X86/dagcombine-shifts.ll
+++ b/llvm/test/CodeGen/X86/dagcombine-shifts.ll
@@ -215,3 +215,168 @@
 
 declare void @f(i64)
 
+; The *_select tests below check that we do the following transformation:
+;
+;   shift lhs, (select cond, constant1, constant2) -->
+;   select cond, (shift lhs, constant1), (shift lhs, constant2)
+;
+; When updating these testcases, ensure that there are two shift instructions
+; in the result and that they take immediates rather than registers.
+define i32 @shl_select(i32 %x, i1 %cond) {
+; CHECK-LABEL: shl_select:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    movl %edi, %eax
+; CHECK-NEXT:    movl %edi, %ecx
+; CHECK-NEXT:    shrl $3, %ecx
+; CHECK-NEXT:    shrl $6, %eax
+; CHECK-NEXT:    testb $1, %sil
+; CHECK-NEXT:    cmovnel %ecx, %eax
+; CHECK-NEXT:    retq
+  %shift_amnt = select i1 %cond, i32 3, i32 6
+  %ret = lshr i32 %x, %shift_amnt
+  ret i32 %ret
+}
+
+define i32 @ashr_select(i32 %x, i1 %cond) {
+; CHECK-LABEL: ashr_select:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    movl %edi, %eax
+; CHECK-NEXT:    movl %edi, %ecx
+; CHECK-NEXT:    sarl $3, %ecx
+; CHECK-NEXT:    sarl $6, %eax
+; CHECK-NEXT:    testb $1, %sil
+; CHECK-NEXT:    cmovnel %ecx, %eax
+; CHECK-NEXT:    retq
+  %shift_amnt = select i1 %cond, i32 3, i32 6
+  %ret = ashr i32 %x, %shift_amnt
+  ret i32 %ret
+}
+
+define i32 @lshr_select(i32 %x, i1 %cond) {
+; CHECK-LABEL: lshr_select:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    movl %edi, %eax
+; CHECK-NEXT:    movl %edi, %ecx
+; CHECK-NEXT:    shrl $3, %ecx
+; CHECK-NEXT:    shrl $6, %eax
+; CHECK-NEXT:    testb $1, %sil
+; CHECK-NEXT:    cmovnel %ecx, %eax
+; CHECK-NEXT:    retq
+  %shift_amnt = select i1 %cond, i32 3, i32 6
+  %ret = lshr i32 %x, %shift_amnt
+  ret i32 %ret
+}
+
+; Check that we don't perform the folding described in shl_select when the
+; shift width is used other than as an input to the shift instruction.
+;
+; When updating this testcase, check that there's exactly one shrl instruction
+; generated.
+declare void @i32_foo(i32)
+define i32 @shl_select_not_folded_if_shift_amnt_is_used(i32 %x, i1 %cond) {
+; CHECK-LABEL: shl_select_not_folded_if_shift_amnt_is_used:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    pushq %rbp
+; CHECK-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-NEXT:    pushq %rbx
+; CHECK-NEXT:    .cfi_def_cfa_offset 24
+; CHECK-NEXT:    pushq %rax
+; CHECK-NEXT:    .cfi_def_cfa_offset 32
+; CHECK-NEXT:    .cfi_offset %rbx, -24
+; CHECK-NEXT:    .cfi_offset %rbp, -16
+; CHECK-NEXT:    movl %edi, %ebx
+; CHECK-NEXT:    notb %sil
+; CHECK-NEXT:    movzbl %sil, %eax
+; CHECK-NEXT:    andl $1, %eax
+; CHECK-NEXT:    leal 3(%rax,%rax,2), %ebp
+; CHECK-NEXT:    movl %ebp, %edi
+; CHECK-NEXT:    callq i32_foo
+; CHECK-NEXT:    movl %ebp, %ecx
+; CHECK-NEXT:    shrl %cl, %ebx
+; CHECK-NEXT:    movl %ebx, %eax
+; CHECK-NEXT:    addq $8, %rsp
+; CHECK-NEXT:    .cfi_def_cfa_offset 24
+; CHECK-NEXT:    popq %rbx
+; CHECK-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-NEXT:    popq %rbp
+; CHECK-NEXT:    .cfi_def_cfa_offset 8
+; CHECK-NEXT:    retq
+  %shift_amnt = select i1 %cond, i32 3, i32 6
+  call void @i32_foo(i32 %shift_amnt)
+  %ret = lshr i32 %x, %shift_amnt
+  ret i32 %ret
+}
+
+; Check that we don't perform the folding described in shl_select when one of
+; the shift widths is not a constant.
+;
+; When updating these testcases, check that there's exactly one shrl
+; instruction generated in each.
+define i32 @shl_select_not_folded_if_shift_amnt_is_nonconstant_1(i32 %x, i32 %a, i1 %cond) {
+; CHECK-LABEL: shl_select_not_folded_if_shift_amnt_is_nonconstant_1:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    movl %edi, %eax
+; CHECK-NEXT:    testb $1, %dl
+; CHECK-NEXT:    movl $6, %ecx
+; CHECK-NEXT:    cmovnel %esi, %ecx
+; CHECK-NEXT:    # kill: def $cl killed $cl killed $ecx
+; CHECK-NEXT:    shrl %cl, %eax
+; CHECK-NEXT:    retq
+  %shift_amnt = select i1 %cond, i32 %a, i32 6
+  %ret = lshr i32 %x, %shift_amnt
+  ret i32 %ret
+}
+
+define i32 @shl_select_not_folded_if_shift_amnt_is_nonconstant_2(i32 %x, i32 %a, i1 %cond) {
+; CHECK-LABEL: shl_select_not_folded_if_shift_amnt_is_nonconstant_2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    movl %edi, %eax
+; CHECK-NEXT:    testb $1, %dl
+; CHECK-NEXT:    movl $3, %ecx
+; CHECK-NEXT:    cmovel %esi, %ecx
+; CHECK-NEXT:    # kill: def $cl killed $cl killed $ecx
+; CHECK-NEXT:    shrl %cl, %eax
+; CHECK-NEXT:    retq
+  %shift_amnt = select i1 %cond, i32 3, i32 %a
+  %ret = lshr i32 %x, %shift_amnt
+  ret i32 %ret
+}
+
+define i32 @shl_select_not_folded_if_shift_amnt_is_nonconstant_3(i32 %x, i32 %a, i32 %b, i1 %cond) {
+; CHECK-LABEL: shl_select_not_folded_if_shift_amnt_is_nonconstant_3:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    movl %edi, %eax
+; CHECK-NEXT:    testb $1, %cl
+; CHECK-NEXT:    cmovel %edx, %esi
+; CHECK-NEXT:    movl %esi, %ecx
+; CHECK-NEXT:    shrl %cl, %eax
+; CHECK-NEXT:    retq
+  %shift_amnt = select i1 %cond, i32 %a, i32 %b
+  %ret = lshr i32 %x, %shift_amnt
+  ret i32 %ret
+}
+
+; Check that we don't perform the folding described in shl_select when the
+; operand is a vector, because x86 vector shifts don't go faster when the shift
+; width is a known constant.
+;
+; When updating this testcase, check that there's exactly one shift instruction
+; generated.
+define <4 x i32> @shr_select_not_folded_for_vector_shifts1(<4 x i32> %x, i1 %cond) {
+; CHECK-LABEL: shr_select_not_folded_for_vector_shifts1:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    notb %dil
+; CHECK-NEXT:    movzbl %dil, %eax
+; CHECK-NEXT:    andl $1, %eax
+; CHECK-NEXT:    leal 3(%rax,%rax,2), %eax
+; CHECK-NEXT:    movd %eax, %xmm1
+; CHECK-NEXT:    psrld %xmm1, %xmm0
+; CHECK-NEXT:    retq
+  %shift_amnt = select i1 %cond, i32 3, i32 6
+  %vshift_amnt0 = insertelement <4 x i32> undef, i32 %shift_amnt, i32 0
+  %vshift_amnt1 = insertelement <4 x i32> %vshift_amnt0, i32 %shift_amnt, i32 1
+  %vshift_amnt2 = insertelement <4 x i32> %vshift_amnt1, i32 %shift_amnt, i32 2
+  %vshift_amnt3 = insertelement <4 x i32> %vshift_amnt2, i32 %shift_amnt, i32 3
+  %ret = lshr <4 x i32> %x, %vshift_amnt3
+  ret <4 x i32> %ret
+}
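For illustration, here is a minimal standalone C++ sketch of the source-level rewrite this combine performs, using the same 3-vs-6 shift amounts as the tests above. It is not part of the patch, and the function names are hypothetical.

    #include <cstdint>

    // Before the combine: the shift amount is a select between two constants,
    // so x86 codegen must materialize the amount in %cl and emit a
    // register-amount shift (shrl %cl, ...).
    uint32_t shift_by_selected_amount(uint32_t x, bool cond) {
      uint32_t amount = cond ? 3u : 6u;  // select cond, 3, 6
      return x >> amount;                // lshr x, amount
    }

    // After the combine: each arm shifts by an immediate and a cmov picks the
    // result, matching the two shrl-$imm instructions checked for in
    // lshr_select above.
    uint32_t shift_each_arm(uint32_t x, bool cond) {
      return cond ? (x >> 3) : (x >> 6);  // select cond, (lshr x, 3), (lshr x, 6)
    }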