diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -14332,6 +14332,54 @@
     return true;
   }
 
+  case Instruction::And: {
+    // Pattern: Or(And(A, MaskValue), And(B, ~MaskValue)) with
+    // ~MaskValue = Xor(MaskValue, -1). Sinking ~MaskValue next to the And
+    // that uses it lets instruction selection see the whole tree and select
+    // the more performant NEON bitselect instruction.
+    //
+    // Matching starts from the And that uses ~MaskValue because the uses
+    // pushed into Ops are sunk just before I. Only operand uses of I itself
+    // may be pushed: sinking an operand of the other And to just before I
+    // could place it after that other And and produce malformed IR. If the
+    // Xor feeds the other And, the pattern is found when that And is visited.
+    if (!Subtarget->hasNEON() || !I->hasOneUse())
+      return false;
+
+    // I must be one of the two single-use And operands of an Or.
+    auto *OI = dyn_cast<Instruction>(*I->user_begin());
+    Value *And0_Op0 = nullptr, *And0_Op1 = nullptr, *And1_Op0 = nullptr,
+          *And1_Op1 = nullptr;
+    if (!OI ||
+        !match(OI, m_Or(m_And(m_Value(And0_Op0), m_Value(And0_Op1)),
+                        m_And(m_Value(And1_Op0), m_Value(And1_Op1)))) ||
+        !all_of(OI->operands(), [](Value *V) { return V->hasOneUser(); }))
+      return false;
+
+    // Operands of the And that is not I. A plain local array is used because
+    // an ArrayRef initialized from a braced list would dangle after its
+    // declaration.
+    Value *OtherAndOps[2] = {And1_Op0, And1_Op1};
+    if (OI->getOperand(1) == I) {
+      OtherAndOps[0] = And0_Op0;
+      OtherAndOps[1] = And0_Op1;
+    }
+
+    // Sink the first operand of I that is an instruction computing
+    // Xor(MaskValue, -1) (either operand order) where MaskValue also feeds
+    // the other And. ConstantExpr xors are skipped: only instructions can
+    // be sunk.
+    for (unsigned OpIdx = 0; OpIdx < 2; ++OpIdx) {
+      Value *Op = I->getOperand(OpIdx);
+      Value *MaskValue = nullptr;
+      if (isa<Instruction>(Op) && match(Op, m_Not(m_Value(MaskValue))) &&
+          is_contained(OtherAndOps, MaskValue)) {
+        Ops.push_back(&I->getOperandUse(OpIdx));
+        return true;
+      }
+    }
+    return false;
+  }
   case Instruction::Mul: {
     int NumZExts = 0, NumSExts = 0;
     for (auto &Op : I->operands()) {