diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -3254,13 +3254,20 @@
   /// constant integer. If so, check to see if there are any bits set in the
   /// constant that are not demanded. If so, shrink the constant and return
   /// true.
-  bool ShrinkDemandedConstant(SDValue Op, const APInt &Demanded,
+  bool ShrinkDemandedConstant(SDValue Op, const APInt &DemandedBits,
+                              const APInt &DemandedElts,
+                              TargetLoweringOpt &TLO) const;
+
+  /// Helper wrapper around ShrinkDemandedConstant, demanding all elements.
+  bool ShrinkDemandedConstant(SDValue Op, const APInt &DemandedBits,
                               TargetLoweringOpt &TLO) const;
 
   // Target hook to do target-specific const optimization, which is called by
   // ShrinkDemandedConstant. This function should return true if the target
   // doesn't want ShrinkDemandedConstant to further optimize the constant.
-  virtual bool targetShrinkDemandedConstant(SDValue Op, const APInt &Demanded,
+  virtual bool targetShrinkDemandedConstant(SDValue Op,
+                                            const APInt &DemandedBits,
+                                            const APInt &DemandedElts,
                                             TargetLoweringOpt &TLO) const {
     return false;
   }
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -483,13 +483,15 @@
 /// If the specified instruction has a constant integer operand and there are
 /// bits set in that constant that are not demanded, then clear those bits and
 /// return true.
-bool TargetLowering::ShrinkDemandedConstant(SDValue Op, const APInt &Demanded,
+bool TargetLowering::ShrinkDemandedConstant(SDValue Op,
+                                            const APInt &DemandedBits,
+                                            const APInt &DemandedElts,
                                             TargetLoweringOpt &TLO) const {
   SDLoc DL(Op);
   unsigned Opcode = Op.getOpcode();
 
   // Do target-specific constant optimization.
-  if (targetShrinkDemandedConstant(Op, Demanded, TLO))
+  if (targetShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO))
     return TLO.New.getNode();
 
   // FIXME: ISD::SELECT, ISD::SELECT_CC
@@ -505,12 +507,12 @@
     // If this is a 'not' op, don't touch it because that's a canonical form.
     const APInt &C = Op1C->getAPIntValue();
-    if (Opcode == ISD::XOR && Demanded.isSubsetOf(C))
+    if (Opcode == ISD::XOR && DemandedBits.isSubsetOf(C))
       return false;
 
-    if (!C.isSubsetOf(Demanded)) {
+    if (!C.isSubsetOf(DemandedBits)) {
       EVT VT = Op.getValueType();
-      SDValue NewC = TLO.DAG.getConstant(Demanded & C, DL, VT);
+      SDValue NewC = TLO.DAG.getConstant(DemandedBits & C, DL, VT);
       SDValue NewOp = TLO.DAG.getNode(Opcode, DL, VT, Op.getOperand(0), NewC);
       return TLO.CombineTo(Op, NewOp);
     }
@@ -522,6 +524,16 @@
   return false;
 }
 
+bool TargetLowering::ShrinkDemandedConstant(SDValue Op,
+                                            const APInt &DemandedBits,
+                                            TargetLoweringOpt &TLO) const {
+  EVT VT = Op.getValueType();
+  APInt DemandedElts = VT.isVector()
+                           ? APInt::getAllOnesValue(VT.getVectorNumElements())
+                           : APInt(1, 1);
+  return ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO);
+}
+
 /// Convert x+y to (VT)((SmallVT)x+(SmallVT)y) if the casts are free.
 /// This uses isZExtFree and ZERO_EXTEND for the widening cast, but it could be
 /// generalized for targets with other types of implicit widening casts.
@@ -1171,7 +1183,8 @@
       // If any of the set bits in the RHS are known zero on the LHS, shrink
       // the constant.
-      if (ShrinkDemandedConstant(Op, ~LHSKnown.Zero & DemandedBits, TLO))
+      if (ShrinkDemandedConstant(Op, ~LHSKnown.Zero & DemandedBits,
+                                 DemandedElts, TLO))
         return true;
 
       // Bitwise-not (xor X, -1) is a special case: we don't usually shrink its
@@ -1219,7 +1232,8 @@
     if (DemandedBits.isSubsetOf(Known.Zero | Known2.Zero))
       return TLO.CombineTo(Op, TLO.DAG.getConstant(0, dl, VT));
     // If the RHS is a constant, see if we can simplify it.
-    if (ShrinkDemandedConstant(Op, ~Known2.Zero & DemandedBits, TLO))
+    if (ShrinkDemandedConstant(Op, ~Known2.Zero & DemandedBits, DemandedElts,
+                               TLO))
       return true;
     // If the operation can be done in a smaller type, do so.
     if (ShrinkDemandedOp(Op, BitWidth, DemandedBits, TLO))
@@ -1262,7 +1276,7 @@
     if (DemandedBits.isSubsetOf(Known.One | Known2.Zero))
       return TLO.CombineTo(Op, Op1);
     // If the RHS is a constant, see if we can simplify it.
-    if (ShrinkDemandedConstant(Op, DemandedBits, TLO))
+    if (ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO))
       return true;
     // If the operation can be done in a smaller type, do so.
     if (ShrinkDemandedOp(Op, BitWidth, DemandedBits, TLO))
@@ -1314,7 +1328,8 @@
     if (DemandedBits.isSubsetOf(Known.Zero | Known2.Zero))
       return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::OR, dl, VT, Op0, Op1));
 
-    if (ConstantSDNode *C = isConstOrConstSplat(Op1)) {
+    ConstantSDNode* C = isConstOrConstSplat(Op1, DemandedElts);
+    if (C) {
       // If one side is a constant, and all of the known set bits on the other
       // side are also set in the constant, turn this into an AND, as we know
       // the bits will be cleared.
@@ -1329,18 +1344,19 @@
       // If the RHS is a constant, see if we can change it. Don't alter a -1
       // constant because that's a 'not' op, and that is better for combining
      // and codegen.
-      if (!C->isAllOnesValue()) {
-        if (DemandedBits.isSubsetOf(C->getAPIntValue())) {
-          // We're flipping all demanded bits. Flip the undemanded bits too.
-          SDValue New = TLO.DAG.getNOT(dl, Op0, VT);
-          return TLO.CombineTo(Op, New);
-        }
-        // If we can't turn this into a 'not', try to shrink the constant.
-        if (ShrinkDemandedConstant(Op, DemandedBits, TLO))
-          return true;
+      if (!C->isAllOnesValue() &&
+          DemandedBits.isSubsetOf(C->getAPIntValue())) {
+        // We're flipping all demanded bits. Flip the undemanded bits too.
+        SDValue New = TLO.DAG.getNOT(dl, Op0, VT);
+        return TLO.CombineTo(Op, New);
       }
     }
 
+    // If we can't turn this into a 'not', try to shrink the constant.
+    if (!C || !C->isAllOnesValue())
+      if (ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO))
+        return true;
+
     Known ^= Known2;
     break;
   }
@@ -1355,7 +1371,7 @@
     assert(!Known2.hasConflict() && "Bits known to be one AND zero?");
 
     // If the operands are constants, see if we can simplify them.
-    if (ShrinkDemandedConstant(Op, DemandedBits, TLO))
+    if (ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO))
       return true;
 
     // Only known if known in both the LHS and RHS.
@@ -1373,7 +1389,7 @@
     assert(!Known2.hasConflict() && "Bits known to be one AND zero?");
 
     // If the operands are constants, see if we can simplify them.
-    if (ShrinkDemandedConstant(Op, DemandedBits, TLO))
+    if (ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO))
       return true;
 
     // Only known if known in both the LHS and RHS.
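Note (illustrative, not part of the patch): the new ShrinkDemandedConstant overload threads a per-element DemandedElts mask alongside DemandedBits, and the helper wrapper above simply demands every element. A minimal sketch of how that all-elements mask can be formed, assuming LLVM's APInt API; allDemandedElts is a hypothetical name used only for this example:

    // Sketch only: mirrors the "demanding all elements" wrapper added above.
    #include "llvm/ADT/APInt.h"

    llvm::APInt allDemandedElts(bool IsVector, unsigned NumElts) {
      // One mask bit per vector element; a scalar is modelled as one element.
      return IsVector ? llvm::APInt::getAllOnesValue(NumElts)
                      : llvm::APInt(1, 1);
    }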
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -401,7 +401,8 @@
     return MVT::getIntegerVT(64);
   }
 
-  bool targetShrinkDemandedConstant(SDValue Op, const APInt &Demanded,
+  bool targetShrinkDemandedConstant(SDValue Op, const APInt &DemandedBits,
+                                    const APInt &DemandedElts,
                                     TargetLoweringOpt &TLO) const override;
 
   MVT getScalarShiftAmountTy(const DataLayout &DL, EVT) const override;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1125,7 +1125,8 @@
 }
 
 bool AArch64TargetLowering::targetShrinkDemandedConstant(
-    SDValue Op, const APInt &Demanded, TargetLoweringOpt &TLO) const {
+    SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
+    TargetLoweringOpt &TLO) const {
   // Delay this optimization to as late as possible.
   if (!TLO.LegalOps)
     return false;
@@ -1142,7 +1143,7 @@
          "i32 or i64 is expected after legalization.");
 
   // Exit early if we demand all bits.
-  if (Demanded.countPopulation() == Size)
+  if (DemandedBits.countPopulation() == Size)
     return false;
 
   unsigned NewOpc;
@@ -1163,7 +1164,7 @@
   if (!C)
     return false;
   uint64_t Imm = C->getZExtValue();
-  return optimizeLogicalImm(Op, Size, Imm, Demanded, TLO, NewOpc);
+  return optimizeLogicalImm(Op, Size, Imm, DemandedBits, TLO, NewOpc);
 }
 
 /// computeKnownBitsForTargetNode - Determine which of the bits specified in
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.h b/llvm/lib/Target/ARM/ARMISelLowering.h
--- a/llvm/lib/Target/ARM/ARMISelLowering.h
+++ b/llvm/lib/Target/ARM/ARMISelLowering.h
@@ -449,10 +449,10 @@
                                        const SelectionDAG &DAG,
                                        unsigned Depth) const override;
 
-    bool targetShrinkDemandedConstant(SDValue Op, const APInt &Demanded,
+    bool targetShrinkDemandedConstant(SDValue Op, const APInt &DemandedBits,
+                                      const APInt &DemandedElts,
                                       TargetLoweringOpt &TLO) const override;
 
-
     bool ExpandInlineAsm(CallInst *CI) const override;
 
     ConstraintType getConstraintType(StringRef Constraint) const override;
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -16864,10 +16864,9 @@
   }
 }
 
-bool
-ARMTargetLowering::targetShrinkDemandedConstant(SDValue Op,
-                                                const APInt &DemandedAPInt,
-                                                TargetLoweringOpt &TLO) const {
+bool ARMTargetLowering::targetShrinkDemandedConstant(
+    SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
+    TargetLoweringOpt &TLO) const {
   // Delay optimization, so we don't have to deal with illegal types, or block
   // optimizations.
   if (!TLO.LegalOps)
@@ -16892,7 +16891,7 @@
   unsigned Mask = C->getZExtValue();
 
-  unsigned Demanded = DemandedAPInt.getZExtValue();
+  unsigned Demanded = DemandedBits.getZExtValue();
   unsigned ShrunkMask = Mask & Demanded;
   unsigned ExpandedMask = Mask | ~Demanded;
 
diff --git a/llvm/lib/Target/X86/X86ISelLowering.h b/llvm/lib/Target/X86/X86ISelLowering.h
--- a/llvm/lib/Target/X86/X86ISelLowering.h
+++ b/llvm/lib/Target/X86/X86ISelLowering.h
@@ -1036,7 +1036,8 @@
     EVT getSetCCResultType(const DataLayout &DL, LLVMContext &Context,
                            EVT VT) const override;
 
-    bool targetShrinkDemandedConstant(SDValue Op, const APInt &Demanded,
+    bool targetShrinkDemandedConstant(SDValue Op, const APInt &DemandedBits,
+                                      const APInt &DemandedElts,
                                       TargetLoweringOpt &TLO) const override;
 
     /// Determine which of the bits specified in Mask are known to be either
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -33225,20 +33225,52 @@
 bool X86TargetLowering::targetShrinkDemandedConstant(SDValue Op,
-                                                     const APInt &Demanded,
+                                                     const APInt &DemandedBits,
+                                                     const APInt &DemandedElts,
                                                      TargetLoweringOpt &TLO) const {
-  // Only optimize Ands to prevent shrinking a constant that could be
-  // matched by movzx.
-  if (Op.getOpcode() != ISD::AND)
-    return false;
-
   EVT VT = Op.getValueType();
+  unsigned Opcode = Op.getOpcode();
+  unsigned EltSize = VT.getScalarSizeInBits();
 
-  // Ignore vectors.
-  if (VT.isVector())
+  if (VT.isVector()) {
+    // If the constant is only all signbits in the active bits, then we should
+    // extend it to the entire constant to allow it act as a boolean constant
+    // vector.
+    auto NeedsSignExtension = [&](SDValue V, unsigned ActiveBits) {
+      if (!ISD::isBuildVectorOfConstantSDNodes(V.getNode()))
+        return false;
+      for (unsigned i = 0, e = V.getNumOperands(); i != e; ++i) {
+        if (!DemandedElts[i] || V.getOperand(i).isUndef())
+          continue;
+        const APInt &Val = V.getConstantOperandAPInt(i);
+        if (Val.getBitWidth() > Val.getNumSignBits() &&
+            Val.trunc(ActiveBits).getNumSignBits() == ActiveBits)
+          return true;
+      }
+      return false;
+    };
+    // For vectors - if we have a constant, then try to sign extend.
+    // TODO: Handle AND/ANDN cases.
+    unsigned ActiveBits = DemandedBits.getActiveBits();
+    if (EltSize > ActiveBits && EltSize > 1 && isTypeLegal(VT) &&
+        (Opcode == ISD::OR || Opcode == ISD::XOR) &&
+        NeedsSignExtension(Op.getOperand(1), ActiveBits)) {
+      EVT BoolVT = EVT::getVectorVT(*TLO.DAG.getContext(), MVT::i1,
+                                    VT.getVectorNumElements());
+      SDValue NewC =
+          TLO.DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(Op), VT,
+                          Op.getOperand(1), TLO.DAG.getValueType(BoolVT));
+      SDValue NewOp =
+          TLO.DAG.getNode(Opcode, SDLoc(Op), VT, Op.getOperand(0), NewC);
+      return TLO.CombineTo(Op, NewOp);
+    }
     return false;
+  }
 
-  unsigned Size = VT.getSizeInBits();
+  // Only optimize Ands to prevent shrinking a constant that could be
+  // matched by movzx.
+  if (Opcode != ISD::AND)
+    return false;
 
   // Make sure the RHS really is a constant.
   ConstantSDNode *C = dyn_cast<ConstantSDNode>(Op.getOperand(1));
@@ -33248,7 +33280,7 @@
   const APInt &Mask = C->getAPIntValue();
 
   // Clear all non-demanded bits initially.
-  APInt ShrunkMask = Mask & Demanded;
+  APInt ShrunkMask = Mask & DemandedBits;
 
   // Find the width of the shrunk mask.
   unsigned Width = ShrunkMask.getActiveBits();
@@ -33260,10 +33292,10 @@
   // Find the next power of 2 width, rounding up to a byte.
   Width = PowerOf2Ceil(std::max(Width, 8U));
   // Truncate the width to size to handle illegal types.
-  Width = std::min(Width, Size);
+  Width = std::min(Width, EltSize);
 
   // Calculate a possible zero extend mask for this constant.
-  APInt ZeroExtendMask = APInt::getLowBitsSet(Size, Width);
+  APInt ZeroExtendMask = APInt::getLowBitsSet(EltSize, Width);
 
   // If we aren't changing the mask, just return true to keep it and prevent
   // the caller from optimizing.
@@ -33272,7 +33304,7 @@
 
   // Make sure the new mask can be represented by a combination of mask bits
   // and non-demanded bits.
-  if (!ZeroExtendMask.isSubsetOf(Mask | ~Demanded))
+  if (!ZeroExtendMask.isSubsetOf(Mask | ~DemandedBits))
     return false;
 
   // Replace the constant with the zero extend mask.
diff --git a/llvm/test/CodeGen/X86/promote-cmp.ll b/llvm/test/CodeGen/X86/promote-cmp.ll
--- a/llvm/test/CodeGen/X86/promote-cmp.ll
+++ b/llvm/test/CodeGen/X86/promote-cmp.ll
@@ -30,19 +30,20 @@
 ; SSE2-NEXT:    pshufd {{.*#+}} xmm5 = xmm7[1,1,3,3]
 ; SSE2-NEXT:    por %xmm4, %xmm5
 ; SSE2-NEXT:    shufps {{.*#+}} xmm5 = xmm5[0,2],xmm6[0,2]
-; SSE2-NEXT:    movaps {{.*#+}} xmm4 = <1,1,u,u>
-; SSE2-NEXT:    xorps %xmm5, %xmm4
-; SSE2-NEXT:    shufps {{.*#+}} xmm5 = xmm5[2,1,3,3]
+; SSE2-NEXT:    pcmpeqd %xmm4, %xmm4
+; SSE2-NEXT:    movaps %xmm5, %xmm6
+; SSE2-NEXT:    shufps {{.*#+}} xmm6 = xmm6[2,1],xmm5[3,3]
+; SSE2-NEXT:    psllq $63, %xmm6
+; SSE2-NEXT:    psrad $31, %xmm6
+; SSE2-NEXT:    pshufd {{.*#+}} xmm6 = xmm6[1,1,3,3]
+; SSE2-NEXT:    pand %xmm6, %xmm1
+; SSE2-NEXT:    pandn %xmm3, %xmm6
+; SSE2-NEXT:    por %xmm6, %xmm1
+; SSE2-NEXT:    shufps {{.*#+}} xmm5 = xmm5[0,1,1,3]
+; SSE2-NEXT:    xorps %xmm4, %xmm5
 ; SSE2-NEXT:    psllq $63, %xmm5
 ; SSE2-NEXT:    psrad $31, %xmm5
-; SSE2-NEXT:    pshufd {{.*#+}} xmm5 = xmm5[1,1,3,3]
-; SSE2-NEXT:    pand %xmm5, %xmm1
-; SSE2-NEXT:    pandn %xmm3, %xmm5
-; SSE2-NEXT:    por %xmm5, %xmm1
-; SSE2-NEXT:    pshufd {{.*#+}} xmm3 = xmm4[0,1,1,3]
-; SSE2-NEXT:    psllq $63, %xmm3
-; SSE2-NEXT:    psrad $31, %xmm3
-; SSE2-NEXT:    pshufd {{.*#+}} xmm3 = xmm3[1,1,3,3]
+; SSE2-NEXT:    pshufd {{.*#+}} xmm3 = xmm5[1,1,3,3]
 ; SSE2-NEXT:    pand %xmm3, %xmm0
 ; SSE2-NEXT:    pandn %xmm2, %xmm3
 ; SSE2-NEXT:    por %xmm3, %xmm0
@@ -56,10 +57,11 @@
 ; SSE4-NEXT:    movdqa %xmm4, %xmm5
 ; SSE4-NEXT:    pcmpgtq %xmm2, %xmm5
 ; SSE4-NEXT:    pshufd {{.*#+}} xmm5 = xmm5[0,2,2,3]
-; SSE4-NEXT:    pxor {{.*}}(%rip), %xmm5
+; SSE4-NEXT:    pcmpeqd %xmm6, %xmm6
+; SSE4-NEXT:    pxor %xmm5, %xmm6
 ; SSE4-NEXT:    psllq $63, %xmm0
 ; SSE4-NEXT:    blendvpd %xmm0, %xmm1, %xmm3
-; SSE4-NEXT:    pmovzxdq {{.*#+}} xmm0 = xmm5[0],zero,xmm5[1],zero
+; SSE4-NEXT:    pmovzxdq {{.*#+}} xmm0 = xmm6[0],zero,xmm6[1],zero
 ; SSE4-NEXT:    psllq $63, %xmm0
 ; SSE4-NEXT:    blendvpd %xmm0, %xmm4, %xmm2
 ; SSE4-NEXT:    movapd %xmm2, %xmm0
@@ -72,9 +74,8 @@
 ; AVX1-NEXT:    vextractf128 $1, %ymm0, %xmm3
 ; AVX1-NEXT:    vpcmpgtq %xmm2, %xmm3, %xmm2
 ; AVX1-NEXT:    vpcmpgtq %xmm1, %xmm0, %xmm3
-; AVX1-NEXT:    vpcmpeqd %xmm4, %xmm4, %xmm4
-; AVX1-NEXT:    vpxor %xmm4, %xmm3, %xmm3
 ; AVX1-NEXT:    vinsertf128 $1, %xmm2, %ymm3, %ymm2
+; AVX1-NEXT:    vxorpd {{.*}}(%rip), %ymm2, %ymm2
 ; AVX1-NEXT:    vblendvpd %ymm2, %ymm0, %ymm1, %ymm0
 ; AVX1-NEXT:    retq
 ;
@@ -82,7 +83,6 @@
 ; AVX2:       # %bb.0:
 ; AVX2-NEXT:    vpcmpgtq %ymm1, %ymm0, %ymm2
 ; AVX2-NEXT:    vpxor {{.*}}(%rip), %ymm2, %ymm2
-; AVX2-NEXT:    vpsllq $63, %ymm2, %ymm2
 ; AVX2-NEXT:    vblendvpd %ymm2, %ymm0, %ymm1, %ymm0
 ; AVX2-NEXT:    retq
   %3 = icmp sgt <4 x i64> %0, %1
diff --git a/llvm/test/CodeGen/X86/setcc-lowering.ll b/llvm/test/CodeGen/X86/setcc-lowering.ll
--- a/llvm/test/CodeGen/X86/setcc-lowering.ll
+++ b/llvm/test/CodeGen/X86/setcc-lowering.ll
@@ -16,8 +16,6 @@
 ; AVX-NEXT:    vpcmpeqd %xmm2, %xmm0, %xmm0
 ; AVX-NEXT:    vpackssdw %xmm1, %xmm0, %xmm0
 ; AVX-NEXT:    vpor {{.*}}(%rip), %xmm0, %xmm0
-; AVX-NEXT:    vpsllw $15, %xmm0, %xmm0
-; AVX-NEXT:    vpsraw $15, %xmm0, %xmm0
 ; AVX-NEXT:    vzeroupper
 ; AVX-NEXT:    retq
 ;
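Note (illustrative, not part of the patch): for vector OR/XOR the X86 hook now widens a constant whose demanded bits are all sign bits into a full per-lane boolean mask via SIGN_EXTEND_INREG, which is why the sign-splat shift pairs (psllw/psraw, psllq/psrad) drop out of the tests above. A standalone sketch of the per-lane test, assuming LLVM's APInt API:

    // Sketch only: one 16-bit lane holding 0x0001 with a single demanded bit.
    // The value truncated to the active bits is all sign bits, so it may be
    // sign-extended in place to the boolean-vector value 0xFFFF.
    #include "llvm/ADT/APInt.h"
    #include <cassert>

    int main() {
      llvm::APInt Val(16, 0x0001);
      unsigned ActiveBits = 1; // stands in for DemandedBits.getActiveBits()
      bool NeedsSext = Val.getBitWidth() > Val.getNumSignBits() &&
                       Val.trunc(ActiveBits).getNumSignBits() == ActiveBits;
      assert(NeedsSext && "lane qualifies for sign extension");
      llvm::APInt Widened = Val.trunc(ActiveBits).sext(16); // 0xFFFF
      assert(Widened.isAllOnesValue());
      return 0;
    }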