Index: lib/Target/AArch64/AArch64ISelLowering.cpp
===================================================================
--- lib/Target/AArch64/AArch64ISelLowering.cpp
+++ lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -7891,6 +7891,212 @@
   return true;
 }
 
+/// Return true if \p N is a logical shift (ISD::SHL or ISD::SRL).
+static bool isLogicalShift(SDNode *N) {
+  unsigned Opc = N->getOpcode();
+  return Opc == ISD::SHL || Opc == ISD::SRL;
+}
+
+/// A potential constituent of a rev expression. See collectBitParts for a
+/// fuller explanation.
+struct BitPart {
+  BitPart(const SDValue *P, unsigned BW) : Provider(P) {
+    Provenance.resize(BW);
+  }
+
+  /// The SDValue that this is a rev16/rev32/rev64 of.
+  const SDValue *Provider;
+
+  /// The "provenance" of each bit. Provenance[A] = B means that bit A of the
+  /// result of this expression is taken from bit B of Provider. Unset means
+  /// that the result bit is not taken from Provider at all.
+  SmallVector<int8_t, 32> Provenance; // int8_t means max size is i128.
+
+  enum { Unset = -1 };
+};
+
+/// Analyze the specified subexpression and see if it is capable of providing
+/// pieces of a rev. The subexpression provides a potential piece of a rev if it
+/// can be proven that each non-zero bit in the output of the expression came
+/// from a corresponding bit in some other value. This function is recursive,
+/// and the end result is a mapping of bitnumber to bitnumber. It is the
+/// caller's responsibility to validate that the bitnumber to bitnumber mapping
+/// is correct for a rev instruction.
+///
+/// To avoid revisiting values, the BitPart results are memoized into the
+/// provided map. To avoid unnecessary copying of BitParts, BitParts are
+/// constructed in-place in the \c BPS map. Because of this \c BPS needs to
+/// store BitParts objects, not pointers. As we need the concept of a nullptr
+/// BitParts (SDValue has been analyzed and the analysis failed), we use an
+/// Optional type instead to provide the same functionality.
+///
+/// Because we pass around references into \c BPS, we must use a container that
+/// does not invalidate internal references (std::map instead of DenseMap).
+///
+/// \c BPS is keyed on the address of \p N, and that address is also recorded
+/// as the Provider of a leaf, so \p N must refer to storage that stays alive
+/// for as long as \c BPS and the returned BitPart are in use.
+static const Optional<BitPart> &
+collectBitParts(const SDValue &N,
+                std::map<const SDValue *, Optional<BitPart>> &BPS) {
+  auto I = BPS.find(&N);
+  if (I != BPS.end())
+    return I->second;
+
+  auto &Result = BPS[&N] = None;
+  EVT VT = N.getValueType();
+  unsigned BitWidth = VT.getSizeInBits();
+
+  // If this is an or instruction, it may be an inner node of the rev.
+  if (N.getOpcode() == ISD::OR) {
+    auto &A = collectBitParts(N.getOperand(0), BPS);
+    if (!A || !A->Provider)
+      return Result;
+
+    auto &B = collectBitParts(N.getOperand(1), BPS);
+    if (!B || !B->Provider)
+      return Result;
+
+    // Try and merge the two together. Both halves must read the very same
+    // value, i.e. the same node and the same result number.
+    if (*A->Provider != *B->Provider)
+      return Result;
+
+    Result = BitPart(A->Provider, BitWidth);
+    for (unsigned i = 0; i < A->Provenance.size(); ++i) {
+      if (A->Provenance[i] != BitPart::Unset &&
+          B->Provenance[i] != BitPart::Unset &&
+          A->Provenance[i] != B->Provenance[i])
+        return Result = None;
+
+      if (A->Provenance[i] == BitPart::Unset)
+        Result->Provenance[i] = B->Provenance[i];
+      else
+        Result->Provenance[i] = A->Provenance[i];
+    }
+    return Result;
+  }
+
+  // If this is a logical shift by a constant, recurse then shift the result.
+  if (isLogicalShift(N.getNode()) && isa<ConstantSDNode>(N.getOperand(1))) {
+    uint64_t BitShift = N.getConstantOperandVal(1);
+    // Ensure the shift amount is defined: shifting by the full bit width (or
+    // more) is not.
+    if (BitShift >= BitWidth)
+      return Result;
+
+    auto &Res = collectBitParts(N.getOperand(0), BPS);
+    if (!Res)
+      return Result;
+
+    Result = Res;
+
+    // Perform the "shift" on BitProvenance.
+    auto &P = Result->Provenance;
+    if (N.getOpcode() == ISD::SHL) {
+      P.erase(std::prev(P.end(), BitShift), P.end());
+      P.insert(P.begin(), BitShift, BitPart::Unset);
+    } else {
+      P.erase(P.begin(), std::next(P.begin(), BitShift));
+      P.insert(P.end(), BitShift, BitPart::Unset);
+    }
+    return Result;
+  }
+
+  // If this is a logical 'and' with a mask that clears bits, recurse then
+  // unset the appropriate bits.
+  if (N.getOpcode() == ISD::AND && isa<ConstantSDNode>(N.getOperand(1))) {
+    uint64_t Bit = 1;
+    uint64_t AndMask = N.getConstantOperandVal(1);
+
+    auto &Res = collectBitParts(N.getOperand(0), BPS);
+    if (!Res)
+      return Result;
+    Result = Res;
+
+    for (unsigned i = 0; i < BitWidth; ++i, Bit <<= 1)
+      // If the AndMask is zero for this bit, clear the bit.
+      if ((AndMask & Bit) == 0)
+        Result->Provenance[i] = BitPart::Unset;
+
+    return Result;
+  }
+
+  // Okay, we got to something that isn't a shift, 'or' or 'and'. This must be
+  // the input value to the rev.
+  Result = BitPart(&N, BitWidth);
+  for (unsigned i = 0; i < BitWidth; ++i)
+    Result->Provenance[i] = i;
+  return Result;
+}
+
+/// Return true if moving bit \p From of the source to bit \p To of the result
+/// is consistent with reversing the order of the \p WordWidth-bit units of a
+/// \p BitWidth-bit value while keeping the bit order inside each unit.
+/// \p WordWidth is expected to be a power of two.
+static bool bitTransformIsCorrectForRev(unsigned From, unsigned To,
+                                        unsigned BitWidth, unsigned WordWidth) {
+  if (From % WordWidth != To % WordWidth)
+    return false;
+  // Convert from bit indices to byte, halfword, or word indices.
+  From >>= Log2_32(WordWidth);
+  To >>= Log2_32(WordWidth);
+  BitWidth >>= Log2_32(WordWidth);
+  return From == BitWidth - To - 1;
+}
+
+/// Try to recognise an OR-rooted tree of constant shifts and masks over a
+/// single i32/i64 value that amounts to a byte swap or a bit reversal of that
+/// value, and replace it with ISD::BSWAP or ISD::BITREVERSE respectively.
+/// Returns an empty SDValue if \p N does not match.
+///
+/// TODO: REV16/REV32 (byte reversal inside each halfword/word) are not matched
+/// yet; they need a per-lane check rather than a whole-register reversal.
+static SDValue tryCombineToREV(SDNode *N,
+                               TargetLowering::DAGCombinerInfo &DCI) {
+  EVT VT = N->getValueType(0);
+  if (VT != MVT::i32 && VT != MVT::i64)
+    return SDValue();
+  unsigned BitWidth = VT.getSizeInBits();
+
+  // collectBitParts keys its memoization map on the address of the SDValue it
+  // is given, so the root has to live in a named local, not a temporary.
+  SDValue Root(N, 0);
+
+  // Try to find all the pieces corresponding to the bswap.
+  std::map<const SDValue *, Optional<BitPart>> BPS;
+  const Optional<BitPart> &Res = collectBitParts(Root, BPS);
+  if (!Res)
+    return SDValue();
+
+  // Now, is the bit permutation correct for a bswap or a bitreverse? We can
+  // only byteswap values with an even number of bytes.
+  const auto &BitProvenance = Res->Provenance;
+  bool IsRBit = true;
+  bool IsRev = true;
+  for (unsigned i = 0; i < BitWidth; ++i) {
+    // Every result bit must come from the provider for either transform.
+    if (BitProvenance[i] == BitPart::Unset)
+      return SDValue();
+    IsRBit &= bitTransformIsCorrectForRev(BitProvenance[i], i, BitWidth, 1);
+    IsRev &= bitTransformIsCorrectForRev(BitProvenance[i], i, BitWidth, 8);
+  }
+
+  // The two permutations are mutually exclusive, so the order is irrelevant.
+  unsigned NewOpc;
+  if (IsRBit)
+    NewOpc = ISD::BITREVERSE;
+  else if (IsRev)
+    NewOpc = ISD::BSWAP;
+  else
+    return SDValue();
+
+  SelectionDAG &DAG = DCI.DAG;
+  if (!DAG.getTargetLoweringInfo().isOperationLegalOrCustom(NewOpc, VT))
+    return SDValue();
+  return DAG.getNode(NewOpc, SDLoc(N), VT, *Res->Provider);
+}
+
 /// EXTR instruction extracts a contiguous chunk of bits from two existing
 /// registers viewed as a high/low pair. This function looks for the pattern:
 /// (or (shl VAL1, #N), (srl VAL2, #RegWidth-N)) and replaces it with an
@@ -7992,6 +8198,9 @@
   if (!DAG.getTargetLoweringInfo().isTypeLegal(VT))
     return SDValue();
 
+  if (SDValue Res = tryCombineToREV(N, DCI))
+    return Res;
+
   if (SDValue Res = tryCombineToEXTR(N, DCI))
     return Res;
 