diff --git a/llvm/include/llvm/Analysis/DemandedBits.h b/llvm/include/llvm/Analysis/DemandedBits.h
--- a/llvm/include/llvm/Analysis/DemandedBits.h
+++ b/llvm/include/llvm/Analysis/DemandedBits.h
@@ -61,6 +61,20 @@
 
   void print(raw_ostream &OS);
 
+  /// Compute alive bits of one addition operand (OperandNo 0 = LHS, 1 = RHS)
+  /// from the alive output bits AOut and the known bits of both operands.
+  static APInt determineLiveOperandBitsAdd(unsigned OperandNo,
+                                           const APInt &AOut,
+                                           const KnownBits &LHS,
+                                           const KnownBits &RHS);
+
+  /// Compute alive bits of one subtraction operand (OperandNo 0 = LHS, 1 = RHS)
+  /// from the alive output bits AOut and the known bits of both operands.
+  static APInt determineLiveOperandBitsSub(unsigned OperandNo,
+                                           const APInt &AOut,
+                                           const KnownBits &LHS,
+                                           const KnownBits &RHS);
+
 private:
   void performAnalysis();
   void determineLiveOperandBits(const Instruction *UserI,
diff --git a/llvm/lib/Analysis/DemandedBits.cpp b/llvm/lib/Analysis/DemandedBits.cpp
--- a/llvm/lib/Analysis/DemandedBits.cpp
+++ b/llvm/lib/Analysis/DemandedBits.cpp
@@ -173,7 +173,21 @@
     }
     break;
   case Instruction::Add:
+    if (AOut.isMask()) {
+      AB = AOut;
+    } else {
+      ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1));
+      AB = determineLiveOperandBitsAdd(OperandNo, AOut, Known, Known2);
+    }
+    break;
   case Instruction::Sub:
+    if (AOut.isMask()) {
+      AB = AOut;
+    } else {
+      ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1));
+      AB = determineLiveOperandBitsSub(OperandNo, AOut, Known, Known2);
+    }
+    break;
   case Instruction::Mul:
     // Find the highest live output bit. We don't need any more input
     // bits than that (adds, and thus subtracts, ripple only to the
@@ -469,6 +483,91 @@
   }
 }
 
+/// Compute the alive bits of operand OperandNo (0 = LHS, 1 = RHS) of the
+/// sum LHS + RHS + carry-in, given the alive output bits AOut and the known
+/// bits of both operands. CarryZero / CarryOne state that the carry-in is
+/// known to be zero / one; at most one of them may be set. Subtraction is
+/// expressed by the caller as LHS + ~RHS + 1.
+static APInt determineLiveOperandBitsAddCarry(unsigned OperandNo,
+                                              const APInt &AOut,
+                                              const KnownBits &LHS,
+                                              const KnownBits &RHS,
+                                              bool CarryZero, bool CarryOne) {
+  assert(!(CarryZero && CarryOne) &&
+         "Carry can't be zero and one at the same time");
+
+  // The following check should be done by the caller, as it also indicates
+  // that LHS and RHS don't need to be computed.
+  //
+  // if (AOut.isMask())
+  //   return AOut;
+
+  // Boundary bits' carry out is unaffected by their carry in.
+  APInt Bound = (LHS.Zero & RHS.Zero) | (LHS.One & RHS.One);
+
+  // First, the alive carry bits are determined from the alive output bits:
+  // Let demand ripple to the right but only up to any set bit in Bound.
+  //   AOut         = -1----
+  //   Bound        = ----1-
+  //   ACarry&~AOut = --111-
+  APInt RBound = Bound.reverseBits();
+  APInt RAOut = AOut.reverseBits();
+  APInt RProp = RAOut + (RAOut | ~RBound);
+  APInt RACarry = RProp ^ ~RBound;
+  APInt ACarry = RACarry.reverseBits();
+
+  // Then, the alive input bits are determined from the alive carry bits:
+  APInt NeededToMaintainCarryZero;
+  APInt NeededToMaintainCarryOne;
+  if (OperandNo == 0) {
+    NeededToMaintainCarryZero = LHS.Zero | ~RHS.Zero;
+    NeededToMaintainCarryOne = LHS.One | ~RHS.One;
+  } else {
+    NeededToMaintainCarryZero = RHS.Zero | ~LHS.Zero;
+    NeededToMaintainCarryOne = RHS.One | ~LHS.One;
+  }
+
+  // As in computeForAddCarry
+  APInt PossibleSumZero = ~LHS.Zero + ~RHS.Zero + !CarryZero;
+  APInt PossibleSumOne = LHS.One + RHS.One + CarryOne;
+
+  // The below is simplified from
+  //
+  // APInt CarryKnownZero = ~(PossibleSumZero ^ LHS.Zero ^ RHS.Zero);
+  // APInt CarryKnownOne = PossibleSumOne ^ LHS.One ^ RHS.One;
+  // APInt CarryUnknown = ~(CarryKnownZero | CarryKnownOne);
+  //
+  // APInt NeededToMaintainCarry =
+  //   (CarryKnownZero & NeededToMaintainCarryZero) |
+  //   (CarryKnownOne  & NeededToMaintainCarryOne) |
+  //   CarryUnknown;
+
+  APInt NeededToMaintainCarry = (~PossibleSumZero | NeededToMaintainCarryZero) &
+                                (PossibleSumOne | NeededToMaintainCarryOne);
+
+  APInt AB = AOut | (ACarry & NeededToMaintainCarry);
+  return AB;
+}
+
+APInt DemandedBits::determineLiveOperandBitsAdd(unsigned OperandNo,
+                                                const APInt &AOut,
+                                                const KnownBits &LHS,
+                                                const KnownBits &RHS) {
+  return determineLiveOperandBitsAddCarry(OperandNo, AOut, LHS, RHS, true,
+                                          false);
+}
+
+APInt DemandedBits::determineLiveOperandBitsSub(unsigned OperandNo,
+                                                const APInt &AOut,
+                                                const KnownBits &LHS,
+                                                const KnownBits &RHS) {
+  KnownBits NRHS;
+  NRHS.Zero = RHS.One;
+  NRHS.One = RHS.Zero;
+  return determineLiveOperandBitsAddCarry(OperandNo, AOut, LHS, NRHS, false,
+                                          true);
+}
+
 FunctionPass *llvm::createDemandedBitsWrapperPass() {
   return new DemandedBitsWrapperPass();
 }
diff --git a/llvm/test/Analysis/DemandedBits/add.ll b/llvm/test/Analysis/DemandedBits/add.ll
--- a/llvm/test/Analysis/DemandedBits/add.ll
+++ b/llvm/test/Analysis/DemandedBits/add.ll
@@ -1,22 +1,22 @@
-; RUN: opt -S -demanded-bits -analyze < %s | FileCheck %s
-; RUN: opt -S -disable-output -passes="print<demanded-bits>" < %s 2>&1 | FileCheck %s
-
-; CHECK-DAG: DemandedBits: 0x1f for %1 = and i32 %a, 9
-; CHECK-DAG: DemandedBits: 0x1f for %2 = and i32 %b, 9
-; CHECK-DAG: DemandedBits: 0x1f for %3 = and i32 %c, 13
-; CHECK-DAG: DemandedBits: 0x1f for %4 = and i32 %d, 4
-; CHECK-DAG: DemandedBits: 0x1f for %5 = or i32 %2, %3
-; CHECK-DAG: DemandedBits: 0x1f for %6 = or i32 %4, %5
+; RUN: opt -S -demanded-bits -analyze < %s | FileCheck %s
+; RUN: opt -S -disable-output -passes="print<demanded-bits>" < %s 2>&1 | FileCheck %s
+
+; CHECK-DAG: DemandedBits: 0x1e for %1 = and i32 %a, 9
+; CHECK-DAG: DemandedBits: 0x1a for %2 = and i32 %b, 9
+; CHECK-DAG: DemandedBits: 0x1a for %3 = and i32 %c, 13
+; CHECK-DAG: DemandedBits: 0x1a for %4 = and i32 %d, 4
+; CHECK-DAG: DemandedBits: 0x1a for %5 = or i32 %2, %3
+; CHECK-DAG: DemandedBits: 0x1a for %6 = or i32 %4, %5
 ; CHECK-DAG: DemandedBits: 0x10 for %7 = add i32 %1, %6
 ; CHECK-DAG: DemandedBits: 0xffffffff for %8 = and i32 %7, 16
-define i32 @test_add(i32 %a, i32 %b, i32 %c, i32 %d) {
-  %1 = and i32 %a, 9
-  %2 = and i32 %b, 9
-  %3 = and i32 %c, 13
-  %4 = and i32 %d, 4 ; no bit of %d alive, %4 simplifies to zero
-  %5 = or i32 %2, %3
-  %6 = or i32 %4, %5
-  %7 = add i32 %1, %6
-  %8 = and i32 %7, 16
-  ret i32 %8
-}
\ No newline at end of file
+define i32 @test_add(i32 %a, i32 %b, i32 %c, i32 %d) {
+  %1 = and i32 %a, 9
+  %2 = and i32 %b, 9
+  %3 = and i32 %c, 13
+  %4 = and i32 %d, 4 ; no bit of %d alive, %4 simplifies to zero
+  %5 = or i32 %2, %3
+  %6 = or i32 %4, %5
+  %7 = add i32 %1, %6
+  %8 = and i32 %7, 16
+  ret i32 %8
+}
diff --git a/llvm/unittests/IR/CMakeLists.txt b/llvm/unittests/IR/CMakeLists.txt
--- a/llvm/unittests/IR/CMakeLists.txt
+++ b/llvm/unittests/IR/CMakeLists.txt
@@ -18,6 +18,7 @@
   DataLayoutTest.cpp
   DebugInfoTest.cpp
   DebugTypeODRUniquingTest.cpp
+  DemandedBitsTest.cpp
   DominatorTreeTest.cpp
   DominatorTreeBatchUpdatesTest.cpp
   FunctionTest.cpp
diff --git a/llvm/unittests/IR/DemandedBitsTest.cpp b/llvm/unittests/IR/DemandedBitsTest.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/unittests/IR/DemandedBitsTest.cpp
@@ -0,0 +1,66 @@
+//===- DemandedBitsTest.cpp - DemandedBits tests --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/DemandedBits.h"
+#include "../Support/KnownBitsTest.h"
+#include "llvm/Support/KnownBits.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+namespace {
+
+template <typename Fn1, typename Fn2>
+static void TestBinOpExhaustive(Fn1 PropagateFn, Fn2 EvalFn) {
+  unsigned Bits = 4;
+  unsigned Max = 1 << Bits;
+  ForeachKnownBits(Bits, [&](const KnownBits &Known1) {
+    ForeachKnownBits(Bits, [&](const KnownBits &Known2) {
+      for (unsigned AOut_ = 0; AOut_ < Max; AOut_++) {
+        APInt AOut(Bits, AOut_);
+        APInt AB1 = PropagateFn(0, AOut, Known1, Known2);
+        APInt AB2 = PropagateFn(1, AOut, Known1, Known2);
+        {
+          // If the propagator claims that certain known bits
+          // didn't matter, check it doesn't change its mind
+          // when they become unknown.
+          KnownBits Known1Redacted;
+          KnownBits Known2Redacted;
+          Known1Redacted.Zero = Known1.Zero & AB1;
+          Known1Redacted.One = Known1.One & AB1;
+          Known2Redacted.Zero = Known2.Zero & AB2;
+          Known2Redacted.One = Known2.One & AB2;
+
+          APInt AB1R = PropagateFn(0, AOut, Known1Redacted, Known2Redacted);
+          APInt AB2R = PropagateFn(1, AOut, Known1Redacted, Known2Redacted);
+          EXPECT_EQ(AB1, AB1R);
+          EXPECT_EQ(AB2, AB2R);
+        }
+        ForeachNumInKnownBits(Known1, [&](APInt Value1) {
+          ForeachNumInKnownBits(Known2, [&](APInt Value2) {
+            APInt ReferenceResult = EvalFn((Value1 & AB1), (Value2 & AB2));
+            APInt Result = EvalFn(Value1, Value2);
+            EXPECT_EQ(Result & AOut, ReferenceResult & AOut);
+          });
+        });
+      }
+    });
+  });
+}
+
+TEST(DemandedBitsTest, Add) {
+  TestBinOpExhaustive(DemandedBits::determineLiveOperandBitsAdd,
+                      [](APInt N1, APInt N2) -> APInt { return N1 + N2; });
+}
+
+TEST(DemandedBitsTest, Sub) {
+  TestBinOpExhaustive(DemandedBits::determineLiveOperandBitsSub,
+                      [](APInt N1, APInt N2) -> APInt { return N1 - N2; });
+}
+
+} // anonymous namespace
diff --git a/llvm/unittests/Support/KnownBitsTest.h b/llvm/unittests/Support/KnownBitsTest.h
new file mode 100644
--- /dev/null
+++ b/llvm/unittests/Support/KnownBitsTest.h
@@ -0,0 +1,52 @@
+//===- llvm/unittest/Support/KnownBitsTest.h - KnownBits tests ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements helpers for KnownBits and DemandedBits unit tests.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_UNITTESTS_SUPPORT_KNOWNBITSTEST_H
+#define LLVM_UNITTESTS_SUPPORT_KNOWNBITSTEST_H
+
+#include "llvm/Support/KnownBits.h"
+
+namespace {
+
+using namespace llvm;
+
+template <typename FnTy> void ForeachKnownBits(unsigned Bits, FnTy Fn) {
+  unsigned Max = 1 << Bits;
+  KnownBits Known(Bits);
+  for (unsigned Zero = 0; Zero < Max; ++Zero) {
+    for (unsigned One = 0; One < Max; ++One) {
+      Known.Zero = Zero;
+      Known.One = One;
+      if (Known.hasConflict())
+        continue;
+
+      Fn(Known);
+    }
+  }
+}
+
+template <typename FnTy>
+void ForeachNumInKnownBits(const KnownBits &Known, FnTy Fn) {
+  unsigned Bits = Known.getBitWidth();
+  unsigned Max = 1 << Bits;
+  for (unsigned N = 0; N < Max; ++N) {
+    APInt Num(Bits, N);
+    if ((Num & Known.Zero) != 0 || (~Num & Known.One) != 0)
+      continue;
+
+    Fn(Num);
+  }
+}
+
+} // end anonymous namespace
+
+#endif
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -11,41 +11,13 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Support/KnownBits.h"
+#include "KnownBitsTest.h"
 #include "gtest/gtest.h"
 
 using namespace llvm;
 
 namespace {
 
-template <typename FnTy>
-void ForeachKnownBits(unsigned Bits, FnTy Fn) {
-  unsigned Max = 1 << Bits;
-  KnownBits Known(Bits);
-  for (unsigned Zero = 0; Zero < Max; ++Zero) {
-    for (unsigned One = 0; One < Max; ++One) {
-      Known.Zero = Zero;
-      Known.One = One;
-      if (Known.hasConflict())
-        continue;
-
-      Fn(Known);
-    }
-  }
-}
-
-template <typename FnTy>
-void ForeachNumInKnownBits(const KnownBits &Known, FnTy Fn) {
-  unsigned Bits = Known.getBitWidth();
-  unsigned Max = 1 << Bits;
-  for (unsigned N = 0; N < Max; ++N) {
-    APInt Num(Bits, N);
-    if ((Num & Known.Zero) != 0 || (~Num & Known.One) != 0)
-      continue;
-
-    Fn(Num);
-  }
-}
-
 TEST(KnownBitsTest, AddCarryExhaustive) {
   unsigned Bits = 4;
   ForeachKnownBits(Bits, [&](const KnownBits &Known1) {