Index: lib/Transforms/InstCombine/InstCombineAndOrXor.cpp =================================================================== --- lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -764,9 +764,37 @@ } /// Fold (icmp)&(icmp) if possible. -Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) { +Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, + Instruction *CxtI) { ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); + // Fold (!iszero(A & K1) & !iszero(A & K2)) -> (A & (K1 | K2)) == (K1 | K2) + // if K1 and K2 are each a one-bit mask. + ConstantInt *LHSC = dyn_cast<ConstantInt>(LHS->getOperand(1)); + ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS->getOperand(1)); + + // TODO: Support vector splats. + if (LHS->getPredicate() == ICmpInst::ICMP_NE && LHSC && LHSC->isZero() && + RHS->getPredicate() == ICmpInst::ICMP_NE && RHSC && RHSC->isZero()) { + + Value *A, *B, *C, *D; + if (match(LHS->getOperand(0), m_And(m_Value(A), m_Value(B))) && + match(RHS->getOperand(0), m_And(m_Value(C), m_Value(D)))) { + if (A == D || B == D) + std::swap(C, D); + if (B == C) + std::swap(A, B); + + if (A == C && + isKnownToBeAPowerOfTwo(B, false, 0, CxtI) && + isKnownToBeAPowerOfTwo(D, false, 0, CxtI)) { + Value *Mask = Builder->CreateOr(B, D); + Value *Masked = Builder->CreateAnd(A, Mask); + return Builder->CreateICmp(ICmpInst::ICMP_EQ, Masked, Mask); + } + } + } + // (icmp1 A, B) & (icmp2 A, B) --> (icmp3 A, B) if (PredicatesFoldable(PredL, PredR)) { if (LHS->getOperand(0) == RHS->getOperand(1) && @@ -798,8 +826,6 @@ // This only handles icmp of constants: (icmp1 A, C1) & (icmp2 B, C2). 
Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0); - ConstantInt *LHSC = dyn_cast<ConstantInt>(LHS->getOperand(1)); - ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS->getOperand(1)); if (!LHSC || !RHSC) return nullptr; @@ -1127,7 +1153,7 @@ ICmpInst *ICmp0 = dyn_cast<ICmpInst>(Cast0Src); ICmpInst *ICmp1 = dyn_cast<ICmpInst>(Cast1Src); if (ICmp0 && ICmp1) { - Value *Res = LogicOpc == Instruction::And ? foldAndOfICmps(ICmp0, ICmp1) + Value *Res = LogicOpc == Instruction::And ? foldAndOfICmps(ICmp0, ICmp1, &I) : foldOrOfICmps(ICmp0, ICmp1, &I); if (Res) return CastInst::Create(CastOpcode, Res, DestTy); @@ -1426,7 +1452,7 @@ ICmpInst *LHS = dyn_cast<ICmpInst>(Op0); ICmpInst *RHS = dyn_cast<ICmpInst>(Op1); if (LHS && RHS) - if (Value *Res = foldAndOfICmps(LHS, RHS)) + if (Value *Res = foldAndOfICmps(LHS, RHS, &I)) return replaceInstUsesWith(I, Res); // TODO: Make this recursive; it's a little tricky because an arbitrary @@ -1434,18 +1460,18 @@ Value *X, *Y; if (LHS && match(Op1, m_OneUse(m_And(m_Value(X), m_Value(Y))))) { if (auto *Cmp = dyn_cast<ICmpInst>(X)) - if (Value *Res = foldAndOfICmps(LHS, Cmp)) + if (Value *Res = foldAndOfICmps(LHS, Cmp, &I)) return replaceInstUsesWith(I, Builder->CreateAnd(Res, Y)); if (auto *Cmp = dyn_cast<ICmpInst>(Y)) - if (Value *Res = foldAndOfICmps(LHS, Cmp)) + if (Value *Res = foldAndOfICmps(LHS, Cmp, &I)) return replaceInstUsesWith(I, Builder->CreateAnd(Res, X)); } if (RHS && match(Op0, m_OneUse(m_And(m_Value(X), m_Value(Y))))) { if (auto *Cmp = dyn_cast<ICmpInst>(X)) - if (Value *Res = foldAndOfICmps(Cmp, RHS)) + if (Value *Res = foldAndOfICmps(Cmp, RHS, &I)) return replaceInstUsesWith(I, Builder->CreateAnd(Res, Y)); if (auto *Cmp = dyn_cast<ICmpInst>(Y)) - if (Value *Res = foldAndOfICmps(Cmp, RHS)) + if (Value *Res = foldAndOfICmps(Cmp, RHS, &I)) return replaceInstUsesWith(I, Builder->CreateAnd(Res, X)); } } Index: lib/Transforms/InstCombine/InstCombineInternal.h =================================================================== --- lib/Transforms/InstCombine/InstCombineInternal.h +++ 
lib/Transforms/InstCombine/InstCombineInternal.h @@ -447,7 +447,7 @@ Instruction::CastOps isEliminableCastPair(const CastInst *CI1, const CastInst *CI2); - Value *foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS); + Value *foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, Instruction *CxtI); Value *foldAndOfFCmps(FCmpInst *LHS, FCmpInst *RHS); Value *foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, Instruction *CxtI); Value *foldOrOfFCmps(FCmpInst *LHS, FCmpInst *RHS); Index: test/Transforms/InstCombine/onehot_merge.ll =================================================================== --- test/Transforms/InstCombine/onehot_merge.ll +++ test/Transforms/InstCombine/onehot_merge.ll @@ -55,3 +55,57 @@ ret i1 %or } +define i1 @or_consts(i32 %k, i32 %c1, i32 %c2) { +; CHECK-LABEL: @or_consts( +; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[K:%.*]], 12 +; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i32 [[TMP1]], 12 +; CHECK-NEXT: ret i1 [[TMP2]] +; + %tmp1 = and i32 4, %k + %tmp2 = icmp ne i32 %tmp1, 0 + %tmp5 = and i32 8, %k + %tmp6 = icmp ne i32 %tmp5, 0 + %or = and i1 %tmp2, %tmp6 + ret i1 %or +} + +define i1 @foo1_or(i32 %k, i32 %c1, i32 %c2) { +; CHECK-LABEL: @foo1_or( +; CHECK-NEXT: [[TMP:%.*]] = shl i32 1, [[C1:%.*]] +; CHECK-NEXT: [[TMP4:%.*]] = lshr i32 -2147483648, [[C2:%.*]] +; CHECK-NEXT: [[TMP1:%.*]] = or i32 [[TMP]], [[TMP4]] +; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[TMP1]], [[K:%.*]] +; CHECK-NEXT: [[TMP3:%.*]] = icmp eq i32 [[TMP2]], [[TMP1]] +; CHECK-NEXT: ret i1 [[TMP3]] +; + %tmp = shl i32 1, %c1 + %tmp4 = lshr i32 -2147483648, %c2 + %tmp1 = and i32 %tmp, %k + %tmp2 = icmp ne i32 %tmp1, 0 + %tmp5 = and i32 %tmp4, %k + %tmp6 = icmp ne i32 %tmp5, 0 + %or = and i1 %tmp2, %tmp6 + ret i1 %or +} + +; Same as above but with operands commuted in one of the ands, but not the other. 
+define i1 @foo1_or_commuted(i32 %k, i32 %c1, i32 %c2) { +; CHECK-LABEL: @foo1_or_commuted( +; CHECK-NEXT: [[K2:%.*]] = mul i32 [[K:%.*]], [[K]] +; CHECK-NEXT: [[TMP:%.*]] = shl i32 1, [[C1:%.*]] +; CHECK-NEXT: [[TMP4:%.*]] = lshr i32 -2147483648, [[C2:%.*]] +; CHECK-NEXT: [[TMP1:%.*]] = or i32 [[TMP]], [[TMP4]] +; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[K2]], [[TMP1]] +; CHECK-NEXT: [[TMP3:%.*]] = icmp eq i32 [[TMP2]], [[TMP1]] +; CHECK-NEXT: ret i1 [[TMP3]] +; + %k2 = mul i32 %k, %k ; to trick the complexity sorting + %tmp = shl i32 1, %c1 + %tmp4 = lshr i32 -2147483648, %c2 + %tmp1 = and i32 %k2, %tmp + %tmp2 = icmp ne i32 %tmp1, 0 + %tmp5 = and i32 %tmp4, %k2 + %tmp6 = icmp ne i32 %tmp5, 0 + %or = and i1 %tmp2, %tmp6 + ret i1 %or +}