diff --git a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp --- a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp +++ b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp @@ -3391,16 +3391,26 @@ return false; SDValue NBits; + bool NegateNBits; // If we have BMI2's BZHI, we are ok with muti-use patterns. // Else, if we only have BMI1's BEXTR, we require one-use. - const bool CanHaveExtraUses = Subtarget->hasBMI2(); - auto checkUses = [CanHaveExtraUses](SDValue Op, unsigned NUses) { - return CanHaveExtraUses || + const bool AllowExtraUsesByDefault = Subtarget->hasBMI2(); + auto checkUses = [AllowExtraUsesByDefault](SDValue Op, unsigned NUses, + Optional<bool> AllowExtraUses) { + if (AllowExtraUses == None) + AllowExtraUses = AllowExtraUsesByDefault; + return *AllowExtraUses || Op.getNode()->hasNUsesOfValue(NUses, Op.getResNo()); }; - auto checkOneUse = [checkUses](SDValue Op) { return checkUses(Op, 1); }; - auto checkTwoUse = [checkUses](SDValue Op) { return checkUses(Op, 2); }; + auto checkOneUse = [checkUses](SDValue Op, + Optional<bool> AllowExtraUses = None) { + return checkUses(Op, 1, AllowExtraUses); + }; + auto checkTwoUse = [checkUses](SDValue Op, + Optional<bool> AllowExtraUses = None) { + return checkUses(Op, 2, AllowExtraUses); + }; auto peekThroughOneUseTruncation = [checkOneUse](SDValue V) { if (V->getOpcode() == ISD::TRUNCATE && checkOneUse(V)) { @@ -3413,8 +3423,8 @@ }; // a) x & ((1 << nbits) + (-1)) - auto matchPatternA = [checkOneUse, peekThroughOneUseTruncation, - &NBits](SDValue Mask) -> bool { + auto matchPatternA = [checkOneUse, peekThroughOneUseTruncation, &NBits, + &NegateNBits](SDValue Mask) -> bool { // Match `add`. Must only have one use! 
if (Mask->getOpcode() != ISD::ADD || !checkOneUse(Mask)) return false; @@ -3428,6 +3438,7 @@ if (!isOneConstant(M0->getOperand(0))) return false; NBits = M0->getOperand(1); + NegateNBits = false; return true; }; @@ -3440,7 +3451,7 @@ // b) x & ~(-1 << nbits) auto matchPatternB = [checkOneUse, isAllOnes, peekThroughOneUseTruncation, - &NBits](SDValue Mask) -> bool { + &NBits, &NegateNBits](SDValue Mask) -> bool { // Match `~()`. Must only have one use! if (Mask.getOpcode() != ISD::XOR || !checkOneUse(Mask)) return false; @@ -3455,32 +3466,35 @@ if (!isAllOnes(M0->getOperand(0))) return false; NBits = M0->getOperand(1); + NegateNBits = false; return true; }; - // Match potentially-truncated (bitwidth - y) - auto matchShiftAmt = [checkOneUse, &NBits](SDValue ShiftAmt, - unsigned Bitwidth) { - // Skip over a truncate of the shift amount. - if (ShiftAmt.getOpcode() == ISD::TRUNCATE) { - ShiftAmt = ShiftAmt.getOperand(0); - // The trunc should have been the only user of the real shift amount. - if (!checkOneUse(ShiftAmt)) - return false; - } - // Match the shift amount as: (bitwidth - y). It should go away, too. - if (ShiftAmt.getOpcode() != ISD::SUB) - return false; - auto *V0 = dyn_cast<ConstantSDNode>(ShiftAmt.getOperand(0)); + // Try to match potentially-truncated shift amount as `(bitwidth - y)`, + // or leave the shift amount as-is, but then we'll have to negate it. + auto canonicalizeShiftAmt = [&NBits, &NegateNBits](SDValue ShiftAmt, + unsigned Bitwidth) { + NBits = ShiftAmt; + NegateNBits = true; + // Skip over a truncate of the shift amount, if any. + if (NBits.getOpcode() == ISD::TRUNCATE) + NBits = NBits.getOperand(0); + // Try to match the shift amount as (bitwidth - y). It should go away, too. + // If it doesn't match, that's fine, we'll just negate it ourselves. 
+ if (NBits.getOpcode() != ISD::SUB) + return; + auto *V0 = dyn_cast<ConstantSDNode>(NBits.getOperand(0)); if (!V0 || V0->getZExtValue() != Bitwidth) - return false; - NBits = ShiftAmt.getOperand(1); - return true; + return; + NBits = NBits.getOperand(1); + NegateNBits = false; }; + // c) x & (-1 >> z) but then we'll have to subtract z from bitwidth + // or // c) x & (-1 >> (32 - y)) - auto matchPatternC = [checkOneUse, peekThroughOneUseTruncation, - matchShiftAmt](SDValue Mask) -> bool { + auto matchPatternC = [checkOneUse, peekThroughOneUseTruncation, &NegateNBits, + canonicalizeShiftAmt](SDValue Mask) -> bool { // The mask itself may be truncated. Mask = peekThroughOneUseTruncation(Mask); unsigned Bitwidth = Mask.getSimpleValueType().getSizeInBits(); @@ -3494,27 +3508,39 @@ // The shift amount should not be used externally. if (!checkOneUse(M1)) return false; - return matchShiftAmt(M1, Bitwidth); + canonicalizeShiftAmt(M1, Bitwidth); + // Pattern c. is non-canonical, and is expanded into pattern d. iff there + // is no extra use of the mask. Clearly, there was one since we are here. + // But at the same time, if we need to negate the shift amount, + // then we don't want the mask to stick around, else it's unprofitable. + return !NegateNBits; }; SDValue X; + // d) x << z >> z but then we'll have to subtract z from bitwidth + // or // d) x << (32 - y) >> (32 - y) - auto matchPatternD = [checkOneUse, checkTwoUse, matchShiftAmt, + auto matchPatternD = [checkOneUse, checkTwoUse, canonicalizeShiftAmt, + AllowExtraUsesByDefault, &NegateNBits, + &X](SDNode *Node) -> bool { if (Node->getOpcode() != ISD::SRL) return false; SDValue N0 = Node->getOperand(0); - if (N0->getOpcode() != ISD::SHL || !checkOneUse(N0)) + if (N0->getOpcode() != ISD::SHL) return false; unsigned Bitwidth = N0.getSimpleValueType().getSizeInBits(); SDValue N1 = Node->getOperand(1); SDValue N01 = N0->getOperand(1); // Both of the shifts must be by the exact same value. 
- // There should not be any uses of the shift amount outside of the pattern. - if (N1 != N01 || !checkTwoUse(N1)) + if (N1 != N01) return false; - if (!matchShiftAmt(N1, Bitwidth)) + canonicalizeShiftAmt(N1, Bitwidth); + // There should not be any external uses of the inner shift / shift amount. + // Note that while we are generally okay with external uses given BMI2, + // iff we need to negate the shift amount, we are not okay with extra uses. + const bool AllowExtraUses = AllowExtraUsesByDefault && !NegateNBits; + if (!checkOneUse(N0, AllowExtraUses) || !checkTwoUse(N1, AllowExtraUses)) return false; X = N0->getOperand(0); return true; @@ -3539,6 +3565,11 @@ } else if (!matchPatternD(Node)) return false; + // If we need to negate the shift amount, require BMI2 BZHI support. + // It's just too unprofitable for BMI1 BEXTR. + if (NegateNBits && !Subtarget->hasBMI2()) + return false; + SDLoc DL(Node); // Truncate the shift amount. @@ -3553,11 +3584,21 @@ SDValue SRIdxVal = CurDAG->getTargetConstant(X86::sub_8bit, DL, MVT::i32); insertDAGNode(*CurDAG, SDValue(Node, 0), SRIdxVal); - NBits = SDValue( - CurDAG->getMachineNode(TargetOpcode::INSERT_SUBREG, DL, MVT::i32, ImplDef, - NBits, SRIdxVal), 0); + NBits = SDValue(CurDAG->getMachineNode(TargetOpcode::INSERT_SUBREG, DL, + MVT::i32, ImplDef, NBits, SRIdxVal), + 0); insertDAGNode(*CurDAG, SDValue(Node, 0), NBits); + // We might have matched the amount of high bits to be cleared, + // but we want the amount of low bits to be kept, so negate it then. + if (NegateNBits) { + SDValue BitWidthC = CurDAG->getConstant(NVT.getSizeInBits(), DL, MVT::i32); + insertDAGNode(*CurDAG, SDValue(Node, 0), BitWidthC); + + NBits = CurDAG->getNode(ISD::SUB, DL, MVT::i32, BitWidthC, NBits); + insertDAGNode(*CurDAG, SDValue(Node, 0), NBits); + } + if (Subtarget->hasBMI2()) { // Great, just emit the the BZHI.. 
if (NVT != MVT::i32) { diff --git a/llvm/test/CodeGen/X86/clear-highbits.ll b/llvm/test/CodeGen/X86/clear-highbits.ll --- a/llvm/test/CodeGen/X86/clear-highbits.ll +++ b/llvm/test/CodeGen/X86/clear-highbits.ll @@ -335,8 +335,9 @@ ; X86-BMI2-LABEL: clear_highbits32_c0: ; X86-BMI2: # %bb.0: ; X86-BMI2-NEXT: movb {{[0-9]+}}(%esp), %al -; X86-BMI2-NEXT: shlxl %eax, {{[0-9]+}}(%esp), %ecx -; X86-BMI2-NEXT: shrxl %eax, %ecx, %eax +; X86-BMI2-NEXT: movl $32, %ecx +; X86-BMI2-NEXT: subl %eax, %ecx +; X86-BMI2-NEXT: bzhil %ecx, {{[0-9]+}}(%esp), %eax ; X86-BMI2-NEXT: retl ; ; X64-NOBMI2-LABEL: clear_highbits32_c0: @@ -350,8 +351,9 @@ ; ; X64-BMI2-LABEL: clear_highbits32_c0: ; X64-BMI2: # %bb.0: -; X64-BMI2-NEXT: shlxl %esi, %edi, %eax -; X64-BMI2-NEXT: shrxl %esi, %eax, %eax +; X64-BMI2-NEXT: movl $32, %eax +; X64-BMI2-NEXT: subl %esi, %eax +; X64-BMI2-NEXT: bzhil %eax, %edi, %eax ; X64-BMI2-NEXT: retq %mask = lshr i32 -1, %numhighbits %masked = and i32 %mask, %val @@ -370,8 +372,9 @@ ; X86-BMI2-LABEL: clear_highbits32_c1_indexzext: ; X86-BMI2: # %bb.0: ; X86-BMI2-NEXT: movb {{[0-9]+}}(%esp), %al -; X86-BMI2-NEXT: shlxl %eax, {{[0-9]+}}(%esp), %ecx -; X86-BMI2-NEXT: shrxl %eax, %ecx, %eax +; X86-BMI2-NEXT: movl $32, %ecx +; X86-BMI2-NEXT: subl %eax, %ecx +; X86-BMI2-NEXT: bzhil %ecx, {{[0-9]+}}(%esp), %eax ; X86-BMI2-NEXT: retl ; ; X64-NOBMI2-LABEL: clear_highbits32_c1_indexzext: @@ -385,8 +388,9 @@ ; ; X64-BMI2-LABEL: clear_highbits32_c1_indexzext: ; X64-BMI2: # %bb.0: -; X64-BMI2-NEXT: shlxl %esi, %edi, %eax -; X64-BMI2-NEXT: shrxl %esi, %eax, %eax +; X64-BMI2-NEXT: movl $32, %eax +; X64-BMI2-NEXT: subl %esi, %eax +; X64-BMI2-NEXT: bzhil %eax, %edi, %eax ; X64-BMI2-NEXT: retq %sh_prom = zext i8 %numhighbits to i32 %mask = lshr i32 -1, %sh_prom @@ -408,8 +412,9 @@ ; X86-BMI2: # %bb.0: ; X86-BMI2-NEXT: movl {{[0-9]+}}(%esp), %eax ; X86-BMI2-NEXT: movb {{[0-9]+}}(%esp), %cl -; X86-BMI2-NEXT: shlxl %ecx, (%eax), %eax -; X86-BMI2-NEXT: shrxl %ecx, %eax, %eax +; X86-BMI2-NEXT: 
movl $32, %edx +; X86-BMI2-NEXT: subl %ecx, %edx +; X86-BMI2-NEXT: bzhil %edx, (%eax), %eax ; X86-BMI2-NEXT: retl ; ; X64-NOBMI2-LABEL: clear_highbits32_c2_load: @@ -423,8 +428,9 @@ ; ; X64-BMI2-LABEL: clear_highbits32_c2_load: ; X64-BMI2: # %bb.0: -; X64-BMI2-NEXT: shlxl %esi, (%rdi), %eax -; X64-BMI2-NEXT: shrxl %esi, %eax, %eax +; X64-BMI2-NEXT: movl $32, %eax +; X64-BMI2-NEXT: subl %esi, %eax +; X64-BMI2-NEXT: bzhil %eax, (%rdi), %eax ; X64-BMI2-NEXT: retq %val = load i32, i32* %w %mask = lshr i32 -1, %numhighbits @@ -446,8 +452,9 @@ ; X86-BMI2: # %bb.0: ; X86-BMI2-NEXT: movl {{[0-9]+}}(%esp), %eax ; X86-BMI2-NEXT: movb {{[0-9]+}}(%esp), %cl -; X86-BMI2-NEXT: shlxl %ecx, (%eax), %eax -; X86-BMI2-NEXT: shrxl %ecx, %eax, %eax +; X86-BMI2-NEXT: movl $32, %edx +; X86-BMI2-NEXT: subl %ecx, %edx +; X86-BMI2-NEXT: bzhil %edx, (%eax), %eax ; X86-BMI2-NEXT: retl ; ; X64-NOBMI2-LABEL: clear_highbits32_c3_load_indexzext: @@ -461,8 +468,9 @@ ; ; X64-BMI2-LABEL: clear_highbits32_c3_load_indexzext: ; X64-BMI2: # %bb.0: -; X64-BMI2-NEXT: shlxl %esi, (%rdi), %eax -; X64-BMI2-NEXT: shrxl %esi, %eax, %eax +; X64-BMI2-NEXT: movl $32, %eax +; X64-BMI2-NEXT: subl %esi, %eax +; X64-BMI2-NEXT: bzhil %eax, (%rdi), %eax ; X64-BMI2-NEXT: retq %val = load i32, i32* %w %sh_prom = zext i8 %numhighbits to i32 @@ -483,8 +491,9 @@ ; X86-BMI2-LABEL: clear_highbits32_c4_commutative: ; X86-BMI2: # %bb.0: ; X86-BMI2-NEXT: movb {{[0-9]+}}(%esp), %al -; X86-BMI2-NEXT: shlxl %eax, {{[0-9]+}}(%esp), %ecx -; X86-BMI2-NEXT: shrxl %eax, %ecx, %eax +; X86-BMI2-NEXT: movl $32, %ecx +; X86-BMI2-NEXT: subl %eax, %ecx +; X86-BMI2-NEXT: bzhil %ecx, {{[0-9]+}}(%esp), %eax ; X86-BMI2-NEXT: retl ; ; X64-NOBMI2-LABEL: clear_highbits32_c4_commutative: @@ -498,8 +507,9 @@ ; ; X64-BMI2-LABEL: clear_highbits32_c4_commutative: ; X64-BMI2: # %bb.0: -; X64-BMI2-NEXT: shlxl %esi, %edi, %eax -; X64-BMI2-NEXT: shrxl %esi, %eax, %eax +; X64-BMI2-NEXT: movl $32, %eax +; X64-BMI2-NEXT: subl %esi, %eax +; X64-BMI2-NEXT: bzhil 
%eax, %edi, %eax ; X64-BMI2-NEXT: retq %mask = lshr i32 -1, %numhighbits %masked = and i32 %val, %mask ; swapped order @@ -574,8 +584,9 @@ ; ; X64-BMI2-LABEL: clear_highbits64_c0: ; X64-BMI2: # %bb.0: -; X64-BMI2-NEXT: shlxq %rsi, %rdi, %rax -; X64-BMI2-NEXT: shrxq %rsi, %rax, %rax +; X64-BMI2-NEXT: movl $64, %eax +; X64-BMI2-NEXT: subl %esi, %eax +; X64-BMI2-NEXT: bzhiq %rax, %rdi, %rax ; X64-BMI2-NEXT: retq %mask = lshr i64 -1, %numhighbits %masked = and i64 %mask, %val @@ -646,9 +657,9 @@ ; ; X64-BMI2-LABEL: clear_highbits64_c1_indexzext: ; X64-BMI2: # %bb.0: -; X64-BMI2-NEXT: # kill: def $esi killed $esi def $rsi -; X64-BMI2-NEXT: shlxq %rsi, %rdi, %rax -; X64-BMI2-NEXT: shrxq %rsi, %rax, %rax +; X64-BMI2-NEXT: movl $64, %eax +; X64-BMI2-NEXT: subl %esi, %eax +; X64-BMI2-NEXT: bzhiq %rax, %rdi, %rax ; X64-BMI2-NEXT: retq %sh_prom = zext i8 %numhighbits to i64 %mask = lshr i64 -1, %sh_prom @@ -729,8 +740,9 @@ ; ; X64-BMI2-LABEL: clear_highbits64_c2_load: ; X64-BMI2: # %bb.0: -; X64-BMI2-NEXT: shlxq %rsi, (%rdi), %rax -; X64-BMI2-NEXT: shrxq %rsi, %rax, %rax +; X64-BMI2-NEXT: movl $64, %eax +; X64-BMI2-NEXT: subl %esi, %eax +; X64-BMI2-NEXT: bzhiq %rax, (%rdi), %rax ; X64-BMI2-NEXT: retq %val = load i64, i64* %w %mask = lshr i64 -1, %numhighbits @@ -811,9 +823,9 @@ ; ; X64-BMI2-LABEL: clear_highbits64_c3_load_indexzext: ; X64-BMI2: # %bb.0: -; X64-BMI2-NEXT: # kill: def $esi killed $esi def $rsi -; X64-BMI2-NEXT: shlxq %rsi, (%rdi), %rax -; X64-BMI2-NEXT: shrxq %rsi, %rax, %rax +; X64-BMI2-NEXT: movl $64, %eax +; X64-BMI2-NEXT: subl %esi, %eax +; X64-BMI2-NEXT: bzhiq %rax, (%rdi), %rax ; X64-BMI2-NEXT: retq %val = load i64, i64* %w %sh_prom = zext i8 %numhighbits to i64 @@ -886,8 +898,9 @@ ; ; X64-BMI2-LABEL: clear_highbits64_c4_commutative: ; X64-BMI2: # %bb.0: -; X64-BMI2-NEXT: shlxq %rsi, %rdi, %rax -; X64-BMI2-NEXT: shrxq %rsi, %rax, %rax +; X64-BMI2-NEXT: movl $64, %eax +; X64-BMI2-NEXT: subl %esi, %eax +; X64-BMI2-NEXT: bzhiq %rax, %rdi, %rax ; 
X64-BMI2-NEXT: retq %mask = lshr i64 -1, %numhighbits %masked = and i64 %val, %mask ; swapped order @@ -1217,8 +1230,9 @@ ; X86-BMI2: # %bb.0: ; X86-BMI2-NEXT: movb $16, %al ; X86-BMI2-NEXT: subb {{[0-9]+}}(%esp), %al -; X86-BMI2-NEXT: shlxl %eax, {{[0-9]+}}(%esp), %ecx -; X86-BMI2-NEXT: shrxl %eax, %ecx, %eax +; X86-BMI2-NEXT: movl $32, %ecx +; X86-BMI2-NEXT: subl %eax, %ecx +; X86-BMI2-NEXT: bzhil %ecx, {{[0-9]+}}(%esp), %eax ; X86-BMI2-NEXT: retl ; ; X64-NOBMI2-LABEL: clear_highbits32_16: @@ -1234,8 +1248,9 @@ ; X64-BMI2: # %bb.0: ; X64-BMI2-NEXT: movb $16, %al ; X64-BMI2-NEXT: subb %sil, %al -; X64-BMI2-NEXT: shlxl %eax, %edi, %ecx -; X64-BMI2-NEXT: shrxl %eax, %ecx, %eax +; X64-BMI2-NEXT: movl $32, %ecx +; X64-BMI2-NEXT: subl %eax, %ecx +; X64-BMI2-NEXT: bzhil %ecx, %edi, %eax ; X64-BMI2-NEXT: retq %numhighbits = sub i32 16, %numlowbits %mask = lshr i32 -1, %numhighbits @@ -1256,8 +1271,9 @@ ; X86-BMI2: # %bb.0: ; X86-BMI2-NEXT: movb $48, %al ; X86-BMI2-NEXT: subb {{[0-9]+}}(%esp), %al -; X86-BMI2-NEXT: shlxl %eax, {{[0-9]+}}(%esp), %ecx -; X86-BMI2-NEXT: shrxl %eax, %ecx, %eax +; X86-BMI2-NEXT: movl $32, %ecx +; X86-BMI2-NEXT: subl %eax, %ecx +; X86-BMI2-NEXT: bzhil %ecx, {{[0-9]+}}(%esp), %eax ; X86-BMI2-NEXT: retl ; ; X64-NOBMI2-LABEL: clear_highbits32_48: @@ -1273,8 +1289,9 @@ ; X64-BMI2: # %bb.0: ; X64-BMI2-NEXT: movb $48, %al ; X64-BMI2-NEXT: subb %sil, %al -; X64-BMI2-NEXT: shlxl %eax, %edi, %ecx -; X64-BMI2-NEXT: shrxl %eax, %ecx, %eax +; X64-BMI2-NEXT: movl $32, %ecx +; X64-BMI2-NEXT: subl %eax, %ecx +; X64-BMI2-NEXT: bzhil %ecx, %edi, %eax ; X64-BMI2-NEXT: retq %numhighbits = sub i32 48, %numlowbits %mask = lshr i32 -1, %numhighbits