Index: include/llvm/CodeGen/SelectionDAGNodes.h =================================================================== --- include/llvm/CodeGen/SelectionDAGNodes.h +++ include/llvm/CodeGen/SelectionDAGNodes.h @@ -1589,6 +1589,11 @@ /// Returns true if \p V is a constant integer one. bool isOneConstant(SDValue V); +/// Return the non-bitcasted source operand of \p V if it is bitcasted. +/// If \p V is not a bitcasted value, it is returned as-is. If \p OneUseOnly is +/// true, stop peeking at a bitcast whose source operand has multiple uses. +SDValue peekThroughBitcasts(SDValue V, bool OneUseOnly = false); + /// Returns true if \p V is a bitwise not operation. Assumes that an all ones /// constant is canonicalized to be operand 1. bool isBitwiseNot(SDValue V); Index: lib/CodeGen/SelectionDAG/DAGCombiner.cpp =================================================================== --- lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -866,12 +866,6 @@ return false; } -static SDValue peekThroughBitcast(SDValue V) { - while (V.getOpcode() == ISD::BITCAST) - V = V.getOperand(0); - return V; -} - // Returns the SDNode if it is a constant float BuildVector // or constant float. static SDNode *isConstantFPBuildVectorOrConstantFP(SDValue N) { @@ -927,7 +921,7 @@ // constant integer of all ones (with no undefs). // Do not permit build vector implicit truncation. static bool isAllOnesConstantOrAllOnesSplatConstant(SDValue N) { - N = peekThroughBitcast(N); + N = peekThroughBitcasts(N); unsigned BitWidth = N.getScalarValueSizeInBits(); if (ConstantSDNode *Splat = isConstOrConstSplat(N)) return Splat->isAllOnesValue() && @@ -13856,7 +13850,7 @@ SDValue Val = St->getValue(); // If constant is of the wrong type, convert it now. if (MemVT != Val.getValueType()) { - Val = peekThroughBitcast(Val); + Val = peekThroughBitcasts(Val); // Deal with constants of wrong size. 
if (ElementSizeBits != Val.getValueSizeInBits()) { EVT IntMemVT = @@ -13882,7 +13876,7 @@ SmallVector Ops; for (unsigned i = 0; i < NumStores; ++i) { StoreSDNode *St = cast(StoreNodes[i].MemNode); - SDValue Val = peekThroughBitcast(St->getValue()); + SDValue Val = peekThroughBitcasts(St->getValue()); // All operands of BUILD_VECTOR / CONCAT_VECTOR must be of // type MemVT. If the underlying value is not the correct // type, but it is an extraction of an appropriate vector we @@ -13938,7 +13932,7 @@ StoreSDNode *St = cast(StoreNodes[Idx].MemNode); SDValue Val = St->getValue(); - Val = peekThroughBitcast(Val); + Val = peekThroughBitcasts(Val); StoreInt <<= ElementSizeBits; if (ConstantSDNode *C = dyn_cast(Val)) { StoreInt |= C->getAPIntValue() @@ -14001,7 +13995,7 @@ BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG); EVT MemVT = St->getMemoryVT(); - SDValue Val = peekThroughBitcast(St->getValue()); + SDValue Val = peekThroughBitcasts(St->getValue()); // We must have a base and an offset. if (!BasePtr.getBase().getNode()) return; @@ -14035,7 +14029,7 @@ int64_t &Offset) -> bool { if (Other->isVolatile() || Other->isIndexed()) return false; - SDValue Val = peekThroughBitcast(Other->getValue()); + SDValue Val = peekThroughBitcasts(Other->getValue()); // Allow merging constants of different types as integers. bool NoTypeMatch = (MemVT.isInteger()) ? !MemVT.bitsEq(Other->getMemoryVT()) : Other->getMemoryVT() != MemVT; @@ -14200,7 +14194,7 @@ // Perform an early exit check. Do not bother looking at stored values that // are not constants, loads, or extracted vector elements. 
- SDValue StoredVal = peekThroughBitcast(St->getValue()); + SDValue StoredVal = peekThroughBitcasts(St->getValue()); bool IsLoadSrc = isa(StoredVal); bool IsConstantSrc = isa(StoredVal) || isa(StoredVal); @@ -14469,7 +14463,7 @@ for (unsigned i = 0; i < NumConsecutiveStores; ++i) { StoreSDNode *St = cast(StoreNodes[i].MemNode); - SDValue Val = peekThroughBitcast(St->getValue()); + SDValue Val = peekThroughBitcasts(St->getValue()); LoadSDNode *Ld = cast(Val); BaseIndexOffset LdPtr = BaseIndexOffset::match(Ld, DAG); @@ -16132,7 +16126,7 @@ // TODO: Maybe this is useful for non-splat too? if (!LegalOperations) { if (SDValue Splat = cast(N)->getSplatValue()) { - Splat = peekThroughBitcast(Splat); + Splat = peekThroughBitcasts(Splat); EVT SrcVT = Splat.getValueType(); if (SrcVT.isVector()) { unsigned NumElts = N->getNumOperands() * SrcVT.getVectorNumElements(); @@ -16267,8 +16261,7 @@ SmallVector Mask; for (SDValue Op : N->ops()) { - // Peek through any bitcast. - Op = peekThroughBitcast(Op); + Op = peekThroughBitcasts(Op); // UNDEF nodes convert to UNDEF shuffle mask values. if (Op.isUndef()) { @@ -16285,9 +16278,7 @@ // We want the EVT of the original extraction to correctly scale the // extraction index. EVT ExtVT = ExtVec.getValueType(); - - // Peek through any bitcast. - ExtVec = peekThroughBitcast(ExtVec); + ExtVec = peekThroughBitcasts(ExtVec); // UNDEF nodes convert to UNDEF shuffle mask values. if (ExtVec.isUndef()) { @@ -16516,7 +16507,7 @@ // We are looking for an optionally bitcasted wide vector binary operator // feeding an extract subvector. - SDValue BinOp = peekThroughBitcast(Extract->getOperand(0)); + SDValue BinOp = peekThroughBitcasts(Extract->getOperand(0)); // TODO: The motivating case for this transform is an x86 AVX1 target. 
That // target has temptingly almost legal versions of bitwise logic ops in 256-bit @@ -16539,9 +16530,8 @@ if (!TLI.isOperationLegalOrCustomOrPromote(BOpcode, NarrowBVT)) return SDValue(); - // Peek through bitcasts of the binary operator operands if needed. - SDValue LHS = peekThroughBitcast(BinOp.getOperand(0)); - SDValue RHS = peekThroughBitcast(BinOp.getOperand(1)); + SDValue LHS = peekThroughBitcasts(BinOp.getOperand(0)); + SDValue RHS = peekThroughBitcasts(BinOp.getOperand(1)); // We need at least one concatenation operation of a binop operand to make // this transform worthwhile. The concat must double the input vector sizes. @@ -16639,8 +16629,7 @@ return V->getOperand(Idx / NumElems); } - // Skip bitcasting - V = peekThroughBitcast(V); + V = peekThroughBitcasts(V); // If the input is a build vector. Try to make a smaller build vector. if (V->getOpcode() == ISD::BUILD_VECTOR) { @@ -16936,7 +16925,7 @@ if (!VT.isInteger() || IsBigEndian) return SDValue(); - SDValue N0 = peekThroughBitcast(SVN->getOperand(0)); + SDValue N0 = peekThroughBitcasts(SVN->getOperand(0)); unsigned Opcode = N0.getOpcode(); if (Opcode != ISD::ANY_EXTEND_VECTOR_INREG && 
@@ -17636,7 +17625,7 @@ EVT VT = N->getValueType(0); SDValue LHS = N->getOperand(0); - SDValue RHS = peekThroughBitcast(N->getOperand(1)); + SDValue RHS = peekThroughBitcasts(N->getOperand(1)); SDLoc DL(N); // Make sure we're not running after operation legalization where it Index: lib/CodeGen/SelectionDAG/SelectionDAG.cpp =================================================================== --- lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -8190,6 +8190,15 @@ return Const != nullptr && Const->isOne(); } +SDValue llvm::peekThroughBitcasts(SDValue V, bool OneUseOnly) { + // Stop on an empty SDValue before querying its opcode. When OneUseOnly is + // set, only step through a bitcast whose source operand has a single use. + while (V.getNode() && V.getOpcode() == ISD::BITCAST && + (!OneUseOnly || V.getOperand(0).hasOneUse())) + V = V.getOperand(0); + return V; +} + bool llvm::isBitwiseNot(SDValue V) { if (V.getOpcode() != ISD::XOR) return false; Index: lib/Target/X86/X86ISelLowering.cpp =================================================================== --- lib/Target/X86/X86ISelLowering.cpp +++ lib/Target/X86/X86ISelLowering.cpp @@ -5513,19 +5513,6 @@ return DAG.getVectorShuffle(VT, SDLoc(V2), V1, V2, MaskVec); } -static SDValue peekThroughBitcasts(SDValue V) { - while (V.getNode() && V.getOpcode() == ISD::BITCAST) - V = V.getOperand(0); - return V; -} - -static SDValue peekThroughOneUseBitcasts(SDValue V) { - while (V.getNode() && V.getOpcode() == ISD::BITCAST && - V.getOperand(0).hasOneUse()) - V = V.getOperand(0); - return V; -} - // Peek through EXTRACT_SUBVECTORs - typically used for AVX1 256-bit intops. static SDValue peekThroughEXTRACT_SUBVECTORs(SDValue V) { while (V.getOpcode() == ISD::EXTRACT_SUBVECTOR) @@ -6277,8 +6264,8 @@ case ISD::OR: { // Handle OR(SHUFFLE,SHUFFLE) case where one source is zero and the other // is a valid shuffle index. 
- SDValue N0 = peekThroughOneUseBitcasts(N.getOperand(0)); - SDValue N1 = peekThroughOneUseBitcasts(N.getOperand(1)); + SDValue N0 = peekThroughBitcasts(N.getOperand(0), true); + SDValue N1 = peekThroughBitcasts(N.getOperand(1), true); if (!N0.getValueType().isVector() || !N1.getValueType().isVector()) return false; SmallVector SrcMask0, SrcMask1; @@ -30646,7 +30633,7 @@ // Directly rip through bitcasts to find the underlying operand. SDValue Op = SrcOps[SrcOpIndex]; - Op = peekThroughOneUseBitcasts(Op); + Op = peekThroughBitcasts(Op, true); MVT VT = Op.getSimpleValueType(); if (!VT.isVector()) @@ -31346,7 +31333,7 @@ V.getOpcode() == X86ISD::PSHUFHW) && V.getOpcode() != N.getOpcode() && V.hasOneUse()) { - SDValue D = peekThroughOneUseBitcasts(V.getOperand(0)); + SDValue D = peekThroughBitcasts(V.getOperand(0), true); if (D.getOpcode() == X86ISD::PSHUFD && D.hasOneUse()) { SmallVector VMask = getPSHUFShuffleMask(V); SmallVector DMask = getPSHUFShuffleMask(D); @@ -31881,7 +31868,7 @@ EVT OriginalVT = InVec.getValueType(); // Peek through bitcasts, don't duplicate a load with other uses. - InVec = peekThroughOneUseBitcasts(InVec); + InVec = peekThroughBitcasts(InVec, true); EVT CurrentVT = InVec.getValueType(); if (!CurrentVT.isVector() || @@ -37156,7 +37143,7 @@ return true; // See if this is a single use constant which can be constant folded. - SDValue BC = peekThroughOneUseBitcasts(Op); + SDValue BC = peekThroughBitcasts(Op, true); return ISD::isBuildVectorOfConstantSDNodes(BC.getNode()); }; @@ -40232,7 +40219,7 @@ } // If lower/upper loads are the same and there's no other use of the lower // load, then splat the loaded value with a broadcast. - if (auto *Ld = dyn_cast(peekThroughOneUseBitcasts(SubVec2))) + if (auto *Ld = dyn_cast(peekThroughBitcasts(SubVec2, true))) if (SubVec2 == SubVec && ISD::isNormalLoad(Ld) && Vec.hasOneUse()) return DAG.getNode(X86ISD::SUBV_BROADCAST, dl, OpVT, SubVec);