diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -272,6 +272,7 @@
          "LHS and RHS should have the same type");
   assert(LHS->getType()->isIntOrIntVectorTy() &&
          "LHS and RHS should be integers");
+
   // Look for an inverted mask: (X & ~M) op (Y & M).
   {
     Value *M;
@@ -306,7 +307,48 @@
         match(LHS, m_Not(m_c_Or(m_Specific(A), m_Specific(B)))))
       return true;
   }
-  IntegerType *IT = cast<IntegerType>(LHS->getType()->getScalarType());
+
+  auto *IT = cast<IntegerType>(LHS->getType()->getScalarType());
+
+  // Look for an assume that directly states LHS and RHS are disjoint. This
+  // needs the cache to enumerate assumes and CxtI to tell which ones apply.
+  if (AC && CxtI) {
+    for (auto &AssumeVH : AC->assumptions()) {
+      if (!AssumeVH)
+        continue;
+      CallInst *I = cast<CallInst>(AssumeVH);
+      assert(I->getFunction() == CxtI->getFunction() &&
+             "Got assumption for the wrong function!");
+
+      // Warning: This loop can end up being somewhat performance sensitive.
+      // We're running this loop once for each value queried resulting in a
+      // runtime of ~O(#assumes * #values).
+
+      assert(I->getCalledFunction()->getIntrinsicID() == Intrinsic::assume &&
+             "must be an assume intrinsic");
+
+      Value *Src = I->getArgOperand(0);
+
+      // Check for assume(and(lhs, rhs) == 0).
+      CmpInst::Predicate Pred;
+      bool Disjoint =
+          match(Src, m_c_ICmp(Pred, m_c_And(m_Specific(LHS), m_Specific(RHS)),
+                              m_Zero())) &&
+          Pred == CmpInst::ICMP_EQ;
+
+      // Check for boolean assume(!and(lhs, rhs)).
+      if (!Disjoint && IT->getBitWidth() == 1)
+        Disjoint =
+            match(Src, m_Not(m_c_And(m_Specific(LHS), m_Specific(RHS))));
+
+      // An assume only constrains the program points it is valid for; one
+      // that does not hold at CxtI (e.g. it sits on another path) proves
+      // nothing here. Run the cheap pattern match before this check.
+      if (Disjoint && isValidAssumeForContext(I, CxtI, DT))
+        return true;
+    }
+  }
+
   KnownBits LHSKnown(IT->getBitWidth());
   KnownBits RHSKnown(IT->getBitWidth());
   computeKnownBits(LHS, LHSKnown, DL, 0, AC, CxtI, DT, nullptr, UseInstrInfo);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3795,6 +3795,10 @@
   if (Instruction *Xor = visitMaskedMerge(I, Builder))
     return Xor;
 
+  // A^B --> A|B iff A and B have no bits set in common.
+  if (llvm::haveNoCommonBitsSet(Op0, Op1, DL, &AC, &I, &DT))
+    return BinaryOperator::CreateOr(Op0, Op1);
+
   Value *X, *Y;
   Constant *C1;
   if (match(Op1, m_Constant(C1))) {
diff --git a/llvm/test/Transforms/InstCombine/xor.ll b/llvm/test/Transforms/InstCombine/xor.ll
--- a/llvm/test/Transforms/InstCombine/xor.ll
+++ b/llvm/test/Transforms/InstCombine/xor.ll
@@ -6,6 +6,7 @@
 
 declare i32 @llvm.ctlz.i32(i32, i1)
 declare <2 x i8> @llvm.cttz.v2i8(<2 x i8>, i1)
+declare void @llvm.assume(i1)
 declare void @use(i8)
 
 define i1 @test0(i1 %A) {
@@ -1161,7 +1162,7 @@
 
 define <2 x i32> @xor_andn_commute1(<2 x i32> %a, <2 x i32> %b) {
 ; CHECK-LABEL: @xor_andn_commute1(
-; CHECK-NEXT:    [[Z:%.*]] = or <2 x i32> [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    [[Z:%.*]] = or <2 x i32> [[B:%.*]], [[A:%.*]]
 ; CHECK-NEXT:    ret <2 x i32> [[Z]]
 ;
   %nota = xor <2 x i32> %a, <i32 -1, i32 -1>
@@ -1397,3 +1398,98 @@
   %r = xor i32 %z, 30
   ret i32 %r
 }
+
+;
+; (xor x, y) -> (or x, y) iff x and y have no common bits
+;
+
+define i32 @test_nocommonbits_i32(i32 %x, i32 %y) {
+; CHECK-LABEL: @test_nocommonbits_i32(
+; CHECK-NEXT:    [[COMMONBITS:%.*]] = and i32 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[NOCOMMONBITS:%.*]] = icmp eq i32 [[COMMONBITS]], 0
+; CHECK-NEXT:    tail call void @llvm.assume(i1 [[NOCOMMONBITS]])
+; CHECK-NEXT:    [[R:%.*]] = or i32 [[X]], [[Y]]
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %commonbits = and i32 %x, %y
+  %nocommonbits = icmp eq i32 %commonbits, 0
+  tail call void @llvm.assume(i1 %nocommonbits)
+  %r = xor i32 %x, %y
+  ret i32 %r
+}
+
+; handle icmpeq -> not case for bool types
+define i1 @test_nocommonbits_i1(i1 %x, i1 %y) {
+; CHECK-LABEL: @test_nocommonbits_i1(
+; CHECK-NEXT:    [[COMMONBITS:%.*]] = and i1 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[NOCOMMONBITS:%.*]] = xor i1 [[COMMONBITS]], true
+; CHECK-NEXT:    tail call void @llvm.assume(i1 [[NOCOMMONBITS]])
+; CHECK-NEXT:    [[R:%.*]] = or i1 [[X]], [[Y]]
+; CHECK-NEXT:    ret i1 [[R]]
+;
+  %commonbits = and i1 %x, %y
+  %nocommonbits = xor i1 %commonbits, 1
+  tail call void @llvm.assume(i1 %nocommonbits)
+  %r = xor i1 %x, %y
+  ret i1 %r
+}
+
+; negative test - wrong predicate
+define i32 @test_nocommonbits_i32_ne(i32 %x, i32 %y) {
+; CHECK-LABEL: @test_nocommonbits_i32_ne(
+; CHECK-NEXT:    [[COMMONBITS:%.*]] = and i32 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[NOCOMMONBITS:%.*]] = icmp ne i32 [[COMMONBITS]], 0
+; CHECK-NEXT:    tail call void @llvm.assume(i1 [[NOCOMMONBITS]])
+; CHECK-NEXT:    [[R:%.*]] = xor i32 [[X]], [[Y]]
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %commonbits = and i32 %x, %y
+  %nocommonbits = icmp ne i32 %commonbits, 0
+  tail call void @llvm.assume(i1 %nocommonbits)
+  %r = xor i32 %x, %y
+  ret i32 %r
+}
+
+; negative test - wrong binop
+define i32 @test_nocommonbits_i32_mul(i32 %x, i32 %y) {
+; CHECK-LABEL: @test_nocommonbits_i32_mul(
+; CHECK-NEXT:    [[COMMONBITS:%.*]] = mul i32 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[NOCOMMONBITS:%.*]] = icmp eq i32 [[COMMONBITS]], 0
+; CHECK-NEXT:    tail call void @llvm.assume(i1 [[NOCOMMONBITS]])
+; CHECK-NEXT:    [[R:%.*]] = xor i32 [[X]], [[Y]]
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %commonbits = mul i32 %x, %y
+  %nocommonbits = icmp eq i32 %commonbits, 0
+  tail call void @llvm.assume(i1 %nocommonbits)
+  %r = xor i32 %x, %y
+  ret i32 %r
+}
+
+; negative test - the assume is on a path that does not dominate the xor
+define i32 @test_nocommonbits_i32_nondominating(i32 %x, i32 %y, i1 %c) {
+; CHECK-LABEL: @test_nocommonbits_i32_nondominating(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 [[C:%.*]], label [[ASSUMED:%.*]], label [[EXIT:%.*]]
+; CHECK:       assumed:
+; CHECK-NEXT:    [[COMMONBITS:%.*]] = and i32 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[NOCOMMONBITS:%.*]] = icmp eq i32 [[COMMONBITS]], 0
+; CHECK-NEXT:    tail call void @llvm.assume(i1 [[NOCOMMONBITS]])
+; CHECK-NEXT:    br label [[EXIT]]
+; CHECK:       exit:
+; CHECK-NEXT:    [[R:%.*]] = xor i32 [[X]], [[Y]]
+; CHECK-NEXT:    ret i32 [[R]]
+;
+entry:
+  br i1 %c, label %assumed, label %exit
+
+assumed:
+  %commonbits = and i32 %x, %y
+  %nocommonbits = icmp eq i32 %commonbits, 0
+  tail call void @llvm.assume(i1 %nocommonbits)
+  br label %exit
+
+exit:
+  %r = xor i32 %x, %y
+  ret i32 %r
+}