Index: lib/Transforms/Scalar/StraightLineStrengthReduce.cpp =================================================================== --- lib/Transforms/Scalar/StraightLineStrengthReduce.cpp +++ lib/Transforms/Scalar/StraightLineStrengthReduce.cpp @@ -15,36 +15,34 @@ // // There are many optimizations we can perform in the domain of SLSR. This file // for now contains only an initial step. Specifically, we look for strength -// reduction candidates in two forms: +// reduction candidates in the following forms: // // Form 1: (B + i) * S // Form 2: &B[i * S] +// Form 3: (B1 + S) + B2 // -// where S is an integer variable, and i is a constant integer. If we found two -// candidates +// where B, B1, B2, and S are integer variables, and i is a constant integer. If +// we found two candidates S1 and S2 in the same form and S1 dominates S2, we +// may be able to rewrite S2 in a simpler way with respect to S1. For example, // // S1: X = (B + i) * S -// S2: Y = (B + i') * S -// -// or +// S2: Y = (B + i') * S => Y = X + (i' - i) * S // // S1: X = &B[i * S] -// S2: Y = &B[i' * S] -// -// and S1 dominates S2, we call S1 a basis of S2, and can replace S2 with -// -// Y = X + (i' - i) * S +// S2: Y = &B[i' * S] => Y = &X[(i' - i) * S] // -// or +// S1: X = B1 + B2 +// S2: Y = (B1 + S) + B2 => Y = X + S // -// Y = &X[(i' - i) * S] -// -// where (i' - i) * S is folded to the extent possible. When S2 has multiple -// bases, we pick the one that is closest to S2, or S2's "immediate" basis. +// When such rewriting is possible, we call S1 a "basis" of S2. When S2 has +// multiple bases, we choose to rewrite S2 with respect to the closest basis or +// the "immediate" basis. // // TODO: // -// - Handle candidates in the form of B + i * S +// - Handle candidates in the form of B + i * S. +// +// - Handle integer vectors. // // - Floating point arithmetics when fast math is enabled. 
// @@ -62,12 +60,15 @@ #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Module.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" using namespace llvm; using namespace PatternMatch; +#define DEBUG_TYPE "slsr" + namespace { class StraightLineStrengthReduce : public FunctionPass { @@ -78,9 +79,10 @@ // Base[..][Index * Stride][..] struct Candidate : public ilist_node { enum Kind { - Invalid, // reserved for the default constructor - Mul, // (B + i) * S - GEP, // &B[..][i * S][..] + Invalid, // reserved for the default constructor + Mul, // (B + i) * S + GEP, // &B[..][i * S][..] + TernaryAdd, // B + i * S where B is an addition too }; Candidate() @@ -90,12 +92,16 @@ Instruction *I) : CandidateKind(CT), Base(B), Index(Idx), Stride(S), Ins(I), Basis(nullptr) {} + void print(raw_ostream &OS) const; + Kind CandidateKind; const SCEV *Base; // Note that Index and Stride of a GEP candidate may not have the same // integer type. In that case, during rewriting, Stride will be // sign-extended or truncated to Index's type. ConstantInt *Index; + // Stride can be nullptr indicating the wildcard. In that case, Index must + // be zero. Value *Stride; // The instruction this candidate corresponds to. It helps us to rewrite a // candidate with respect to its immediate basis. Note that one instruction @@ -106,11 +112,11 @@ // // can be treated as // - // + // [slsr-mul] base: a, index: 1, stride: b + 2 // // or // - // + // [slsr-mul] base: b, index: 2, stride: a + 1 Instruction *Ins; // Points to the immediate basis of this candidate, or nullptr if we cannot // find any basis for this candidate. 
@@ -125,6 +131,7 @@ } void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addPreserved(); AU.addRequired(); AU.addRequired(); AU.addRequired(); @@ -139,6 +146,10 @@ bool runOnFunction(Function &F) override; +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) + void dumpAllCandidates() const; +#endif + private: // Returns true if Basis is a basis for C, i.e., Basis dominates C and they // share the same base and stride. @@ -146,6 +157,8 @@ // Checks whether I is in a candidate form. If so, adds all the matching forms // to Candidates, and tries to find the immediate basis for each of them. void allocateCandidateAndFindBasis(Instruction *I); + // Allocate candidates and find bases for Add instructions. + void allocateCandidateAndFindBasisForAdd(Instruction *I); // Allocate candidates and find bases for Mul instructions. void allocateCandidateAndFindBasisForMul(Instruction *I); // Splits LHS into Base + Index and, if succeeds, calls @@ -183,6 +196,7 @@ DominatorTree *DT; ScalarEvolution *SE; TargetTransformInfo *TTI; + // TODO: we should split this list into per-candidate-kind lists. ilist Candidates; // Temporarily holds all instructions that are unlinked (but not deleted) by // rewriteCandidateWithBasis. 
These instructions will be actually removed @@ -191,6 +205,14 @@ }; } // anonymous namespace +namespace llvm { +inline raw_ostream &operator<<(raw_ostream &OS, + const StraightLineStrengthReduce::Candidate &C) { + C.print(OS); + return OS; +} +} + char StraightLineStrengthReduce::ID = 0; INITIALIZE_PASS_BEGIN(StraightLineStrengthReduce, "slsr", "Straight line strength reduction", false, false) @@ -204,15 +226,60 @@ return new StraightLineStrengthReduce(); } +void StraightLineStrengthReduce::Candidate::print(raw_ostream &OS) const { + switch (CandidateKind) { + case Mul: OS << "[slsr-mul]"; break; + case GEP: OS << "[slsr-gep]"; break; + case TernaryAdd: OS << "[slsr-ternary-add]"; break; + default: llvm_unreachable("CandidateKind is invalid"); + } + OS << " base: " << *Base << ";"; + OS << " index: " << *Index << ";"; + OS << " stride: "; + if (Stride == nullptr) + OS << ""; + else + OS << *Stride; +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +void StraightLineStrengthReduce::dumpAllCandidates() const { + for (auto &C : Candidates) + dbgs() << C << "\n"; +} +#endif + bool StraightLineStrengthReduce::isBasisFor(const Candidate &Basis, const Candidate &C) { - return (Basis.Ins != C.Ins && // skip the same instruction - // Basis must dominate C in order to rewrite C with respect to Basis. - DT->dominates(Basis.Ins->getParent(), C.Ins->getParent()) && - // They share the same base, stride, and candidate kind. - Basis.Base == C.Base && - Basis.Stride == C.Stride && - Basis.CandidateKind == C.CandidateKind); + // Skip the same instruction. + if (Basis.Ins == C.Ins) + return false; + + // Basis must dominate C in order to rewrite C with respect to Basis. + if (!DT->dominates(Basis.Ins->getParent(), C.Ins->getParent())) + return false; + + // They must share the same candidate kind. + if (Basis.CandidateKind != C.CandidateKind) + return false; + + // They share the same base. 
+ if (Basis.Base != C.Base) + return false; + + // They share the same stride considering wildcards. + if (Basis.Stride != nullptr && C.Stride != nullptr && + Basis.Stride != C.Stride) + return false; + + // For ternary-add candidates, we expect Basis has index zero and C has index + // one; otherwise, the rewriting won't be beneficial. + if (C.CandidateKind == Candidate::TernaryAdd) { + if (!(Basis.Index->isZero() && C.Index->isOne())) + return false; + } + + return true; } static bool isCompletelyFoldable(GetElementPtrInst *GEP, @@ -262,13 +329,6 @@ void StraightLineStrengthReduce::allocateCandidateAndFindBasis( Candidate::Kind CT, const SCEV *B, ConstantInt *Idx, Value *S, Instruction *I) { - if (GetElementPtrInst *GEP = dyn_cast(I)) { - // If &B[Idx * S] fits into an addressing mode, do not turn it into - // non-free computation. - if (isCompletelyFoldable(GEP, TTI, DL)) - return; - } - Candidate C(CT, B, Idx, S, I); // Try to compute the immediate basis of C. unsigned NumIterations = 0; @@ -293,11 +353,55 @@ allocateCandidateAndFindBasisForMul(I); break; case Instruction::GetElementPtr: - allocateCandidateAndFindBasisForGEP(cast(I)); + // If &B[Idx * S] fits into an addressing mode, do not turn it into + // non-free computation. + if (!isCompletelyFoldable(cast(I), TTI, DL)) + allocateCandidateAndFindBasisForGEP(cast(I)); + break; + case Instruction::Add: + allocateCandidateAndFindBasisForAdd(I); break; } } +void StraightLineStrengthReduce::allocateCandidateAndFindBasisForAdd( + Instruction *I) { + // TODO: skip vector types for now. + if (I->getType()->isVectorTy()) + return; + + Value *LHS = I->getOperand(0), *RHS = I->getOperand(1); + const SCEV *LHSExpr = SE->getSCEV(LHS), *RHSExpr = SE->getSCEV(RHS); + ConstantInt *One = ConstantInt::get(cast(I->getType()), 1); + ConstantInt *Zero = ConstantInt::get(cast(I->getType()), 0); + Value *A = nullptr, *B = nullptr; + // To be conservative, we reassociate I only when it is the only user of + // (A + B). 
+ if (LHS->hasOneUse() && match(LHS, m_Add(m_Value(A), m_Value(B)))) { + // I = (A + B) + RHS + // = (A + RHS) + B or (B + RHS) + A + const SCEV *AExpr = SE->getSCEV(A), *BExpr = SE->getSCEV(B); + allocateCandidateAndFindBasis( + Candidate::TernaryAdd, SE->getAddExpr(AExpr, RHSExpr), One, B, I); + allocateCandidateAndFindBasis( + Candidate::TernaryAdd, SE->getAddExpr(BExpr, RHSExpr), One, A, I); + } else if (RHS->hasOneUse() && match(RHS, m_Add(m_Value(A), m_Value(B)))) { + // I = LHS + (A + B) + // = (A + LHS) + B or (B + LHS) + A + const SCEV *AExpr = SE->getSCEV(A), *BExpr = SE->getSCEV(B); + allocateCandidateAndFindBasis( + Candidate::TernaryAdd, SE->getAddExpr(AExpr, LHSExpr), One, B, I); + allocateCandidateAndFindBasis( + Candidate::TernaryAdd, SE->getAddExpr(BExpr, LHSExpr), One, A, I); + } + // Even if I is not a ternary add, it can serve as the basis of other + // candidates. Therefore, we add the form + // I = (LHS + RHS) + 0 * <wildcard> + // to the candidate list. + allocateCandidateAndFindBasis(Candidate::TernaryAdd, SE->getSCEV(I), Zero, + nullptr, I); +} + void StraightLineStrengthReduce::allocateCandidateAndFindBasisForMul( Value *LHS, Value *RHS, Instruction *I) { Value *B = nullptr; @@ -429,6 +533,12 @@ unifyBitWidth(Idx, BasisIdx); APInt IndexOffset = Idx - BasisIdx; + Value *Stride = C.Stride; + if (Stride == nullptr) + Stride = Basis.Stride; + assert(Stride != nullptr && + "at least one of the two strides should be non-wildcard"); + BumpWithUglyGEP = false; if (Basis.CandidateKind == Candidate::GEP) { APInt ElementSize( @@ -445,21 +555,18 @@ // Compute Bump = C - Basis = (i' - i) * S. // Common case 1: if (i' - i) is 1, Bump = S. if (IndexOffset.getSExtValue() == 1) - return C.Stride; + return Stride; // Common case 2: if (i' - i) is -1, Bump = -S. if (IndexOffset.getSExtValue() == -1) - return Builder.CreateNeg(C.Stride); + return Builder.CreateNeg(Stride); // Otherwise, Bump = (i' - i) * sext/trunc(S). 
ConstantInt *Delta = ConstantInt::get(Basis.Ins->getContext(), IndexOffset); - Value *ExtendedStride = Builder.CreateSExtOrTrunc(C.Stride, Delta->getType()); + Value *ExtendedStride = Builder.CreateSExtOrTrunc(Stride, Delta->getType()); return Builder.CreateMul(ExtendedStride, Delta); } void StraightLineStrengthReduce::rewriteCandidateWithBasis( const Candidate &C, const Candidate &Basis) { - assert(C.CandidateKind == Basis.CandidateKind && C.Base == Basis.Base && - C.Stride == Basis.Stride); - // An instruction can correspond to multiple candidates. Therefore, instead of // simply deleting an instruction when we rewrite it, we mark its parent as // nullptr (i.e. unlink it) so that we can skip the candidates whose @@ -472,6 +579,7 @@ Value *Bump = emitBump(Basis, C, Builder, DL, BumpWithUglyGEP); Value *Reduced = nullptr; // equivalent to but weaker than C.Ins switch (C.CandidateKind) { + case Candidate::TernaryAdd: case Candidate::Mul: Reduced = Builder.CreateAdd(Basis.Ins, Bump); break; @@ -528,6 +636,8 @@ allocateCandidateAndFindBasis(&I); } + DEBUG(dumpAllCandidates()); + // Rewrite candidates in the reverse depth-first order. This order makes sure // a candidate being rewritten is not a basis for any other candidate. 
while (!Candidates.empty()) { Index: test/Transforms/StraightLineStrengthReduce/slsr-add.ll =================================================================== --- /dev/null +++ test/Transforms/StraightLineStrengthReduce/slsr-add.ll @@ -0,0 +1,53 @@ +; RUN: opt < %s -slsr -gvn -dce -S | FileCheck %s + +target datalayout = "e-i64:64-v16:16-v32:32-n16:32:64" + +declare void @foo(i32 %a) + +; foo(a + c); +; foo((a + b) + c); +; foo(a + (b + c)); +; => +; t1 = a + c; +; foo(t1); +; t2 = t1 + b; +; foo(t2); +; foo(t2); +define void @reassociate(i32 %a, i32 %b, i32 %c) { +; CHECK-LABEL: @reassociate( + %1 = add i32 %a, %c +; CHECK: [[BASE:%[a-zA-Z0-9]+]] = add i32 %a, %c + call void @foo(i32 %1) + %2 = add i32 %a, %b + %3 = add i32 %2, %c +; CHECK: [[COMMON:%[a-zA-Z0-9]+]] = add i32 [[BASE]], %b + call void @foo(i32 %3) +; CHECK: call void @foo(i32 [[COMMON]]) + %4 = add i32 %b, %c + %5 = add i32 %a, %4 + call void @foo(i32 %5) +; CHECK: call void @foo(i32 [[COMMON]]) + ret void +} + +; t1 = a + c; +; foo(t1); +; t2 = a + b; +; foo(t2); +; t3 = t2 + c; +; foo(t3); +; +; Do not rewrite t3 into t1 + b because t2 is used elsewhere and is likely free. +define void @no_reassociate(i32 %a, i32 %b, i32 %c) { +; CHECK-LABEL: @no_reassociate( + %1 = add i32 %a, %c +; CHECK: add i32 %a, %c + call void @foo(i32 %1) + %2 = add i32 %a, %b +; CHECK: add i32 %a, %b + call void @foo(i32 %2) + %3 = add i32 %2, %c +; CHECK: add i32 %2, %c + call void @foo(i32 %3) + ret void +}