diff --git a/llvm/lib/Analysis/DemandedBits.cpp b/llvm/lib/Analysis/DemandedBits.cpp --- a/llvm/lib/Analysis/DemandedBits.cpp +++ b/llvm/lib/Analysis/DemandedBits.cpp @@ -173,7 +173,30 @@ } break; case Instruction::Add: - case Instruction::Sub: + case Instruction::Sub: { + ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1)); + // When Bound == 0, this should behave just like + // AB = APInt::getLowBitsSet(BitWidth, AOut.getActiveBits()); + APInt Bound1; + if (UserI->getOpcode() == Instruction::Add) + Bound1 = Known.Zero; + else + Bound1 = Known.One; + APInt Bound2 = Known2.Zero; + APInt Bound = Bound1 & Bound2; + APInt RBound = Bound.reverseBits(); + APInt RAOut = AOut.reverseBits(); + APInt RProp = RAOut + (RAOut | ~RBound); + APInt RQ = (RProp ^ ~(RAOut | RBound)); + APInt Q = RQ.reverseBits(); + APInt U; + if (OperandNo == 0) + U = Bound1 | ~Bound2; + else + U = Bound2 | ~Bound1; + AB = AOut | (Q & (U | (Bound1 + Bound2 + 1))); + break; + } case Instruction::Mul: // Find the highest live output bit. We don't need any more input // bits than that (adds, and thus subtracts, ripple only to the diff --git a/llvm/test/Analysis/DemandedBits/add.ll b/llvm/test/Analysis/DemandedBits/add.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Analysis/DemandedBits/add.ll @@ -0,0 +1,18 @@ +; RUN: opt -S -demanded-bits -analyze < %s | FileCheck %s +; RUN: opt -S -disable-output -passes="print" < %s 2>&1 | FileCheck %s + +; CHECK-DAG: DemandedBits: 0x1e for %1 = and i32 %a, 9 +; CHECK-DAG: DemandedBits: 0x1a for %2 = and i32 %b, 9 +; CHECK-DAG: DemandedBits: 0x1a for %3 = and i32 %c, 13 +; CHECK-DAG: DemandedBits: 0x1a for %4 = and i32 %d, 4 +define i32 @test_add(i32 %a, i32 %b, i32 %c, i32 %d) { + %1 = and i32 %a, 9 + %2 = and i32 %b, 9 + %3 = and i32 %c, 13 + %4 = and i32 %d, 4 ; dead + %5 = or i32 %2, %3 + %6 = or i32 %4, %5 + %7 = add i32 %1, %6 + %8 = and i32 %7, 16 + ret i32 %8 +} diff --git a/llvm/unittests/IR/CMakeLists.txt b/llvm/unittests/IR/CMakeLists.txt --- a/llvm/unittests/IR/CMakeLists.txt +++ b/llvm/unittests/IR/CMakeLists.txt @@ -16,6 +16,7 @@ DataLayoutTest.cpp DebugInfoTest.cpp DebugTypeODRUniquingTest.cpp + DemandedBitsTest.cpp DominatorTreeTest.cpp DominatorTreeBatchUpdatesTest.cpp FunctionTest.cpp diff --git a/llvm/unittests/IR/DemandedBitsTest.cpp b/llvm/unittests/IR/DemandedBitsTest.cpp new file mode 100644 --- /dev/null +++ b/llvm/unittests/IR/DemandedBitsTest.cpp @@ -0,0 +1,154 @@ +//===- DemandedBitsTest.cpp - DemandedBits tests --------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/DemandedBits.h" +#include "llvm/ADT/BitVector.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Operator.h" +#include "llvm/Support/KnownBits.h" +#include "gtest/gtest.h" + +using namespace llvm; + +namespace { + +template +static void EnumerateKnownBits(unsigned Bits, Fn TestFn) { + unsigned Max = 1 << Bits; + for (unsigned Zero = 0; Zero < Max; Zero++) { + for (unsigned One = 0; One < Max; One = ((One | Zero) + 1) & ~Zero) { + KnownBits Known; + Known.Zero = APInt(Bits, Zero); + Known.One = APInt(Bits, One); + TestFn(Known); + } + } +} + +template +static void EnumerateTwoKnownBits(unsigned Bits, Fn TestFn) { + EnumerateKnownBits(Bits, [&](const KnownBits &Known1) { + EnumerateKnownBits( + Bits, [&](const KnownBits &Known2) { TestFn(Known1, Known2); }); + }); +} + +template +static void EnumerateRemainingBits(const KnownBits &Known, Fn TestFn) { + unsigned Max = 1 << Known.getBitWidth(); + unsigned Mask = (Known.Zero | Known.One).getLimitedValue(); + for (unsigned Remaining = 0; Remaining < Max; + Remaining = ((Remaining | Mask) + 1) & ~Mask) { + TestFn(Known.One | Remaining); + } +} + +template +static void PropagateBinOp(const KnownBits &Known1, const KnownBits &Known2, + const APInt &AOut, APInt &AB1, APInt &AB2, + Fn BuildOp) { + unsigned Bits = AOut.getBitWidth(); + + LLVMContext C; + Module M("test", C); + Type *NumericTy = Type::getIntNTy(C, Bits); + + Type *ArgTypes[] = {NumericTy, NumericTy}; + Function *F = Function::Create( + FunctionType::get(NumericTy, ArrayRef(ArgTypes, 2), false), + GlobalValue::ExternalLinkage, "F", &M); + BasicBlock *BB = BasicBlock::Create(C, "", F); + IRBuilder<> Builder(BB); + + Value *Known1NZero = Builder.getInt(~Known1.Zero); + Value *Known1One = Builder.getInt(Known1.One); + Value *Known2NZero = Builder.getInt(~Known2.Zero); + Value *Known2One = Builder.getInt(Known2.One); + Value *AOutImm = Builder.getInt(AOut); + + Argument *Arg1 = F->getArg(0); + Argument *Arg2 = F->getArg(1); + Value *Dummy1 = Builder.CreateNot(Arg1); + Value *Dummy2 = Builder.CreateNot(Arg2); + Value *Op1 = + Builder.CreateOr(Builder.CreateAnd(Dummy1, Known1NZero), Known1One); + Value *Op2 = + Builder.CreateOr(Builder.CreateAnd(Dummy2, Known2NZero), Known2One); + + Value *IOp = BuildOp(Builder, Op1, Op2); + + Value *IMasked = Builder.CreateAnd(IOp, AOutImm); + Builder.CreateRet(IMasked); + + AssumptionCache AC(*F); + DominatorTree DT(*F); + DemandedBits DB(*F, AC, DT); + AB1 = DB.getDemandedBits(dyn_cast(Op1)); + AB2 = DB.getDemandedBits(dyn_cast(Op2)); +} + +template +static void TestBinOpExhaustive(Fn1 BuildOp, Fn2 EvalFn) { + unsigned Bits = 4; + unsigned Max = 1 << Bits; + EnumerateTwoKnownBits( + Bits, [&](const KnownBits &Known1, const KnownBits &Known2) { + for (unsigned AOut_ = 0; AOut_ < Max; AOut_++) { + APInt AOut(Bits, AOut_); + APInt AB1; + APInt AB2; + PropagateBinOp(Known1, Known2, AOut, AB1, AB2, BuildOp); + { + // If the propagator claims that certain known bits + // didn't matter, check that the result doesn't + // change when they become unknown + KnownBits Known1_; + KnownBits Known2_; + Known1_.Zero = Known1.Zero & AB1; + Known1_.One = Known1.One & AB1; + Known2_.Zero = Known2.Zero & AB2; + Known2_.One = Known2.One & AB2; + + APInt AB1_; + APInt AB2_; + PropagateBinOp(Known1_, Known2_, AOut, AB1_, AB2_, BuildOp); + EXPECT_EQ(AB1, AB1_); + EXPECT_EQ(AB1, AB1_); + } + APInt Z = EvalFn(Known1.One, Known2.One); + EnumerateRemainingBits(Known1, [&](APInt Value1) { + EnumerateRemainingBits(Known2, [&](APInt Value2) { + APInt ReferenceResult = EvalFn((Value1 & AB1), (Value2 & AB2)); + APInt Result = EvalFn(Value1, Value2); + EXPECT_EQ(Result & AOut, ReferenceResult & AOut); + }); + }); + } + }); +} + +TEST(DemandedBitsTest, Add) { + TestBinOpExhaustive( + [](IRBuilder<> &Builder, Value *Op1, Value *Op2) -> Value * { + return Builder.CreateAdd(Op1, Op2); + }, + [](APInt N1, APInt N2) -> APInt { return N1 + N2; }); +} + +TEST(DemandedBitsTest, Sub) { + TestBinOpExhaustive( + [](IRBuilder<> &Builder, Value *Op1, Value *Op2) -> Value * { + return Builder.CreateSub(Op1, Op2); + }, + [](APInt N1, APInt N2) -> APInt { return N1 - N2; }); +} + +} // anonymous namespace