Index: lib/Transforms/Scalar/StraightLineStrengthReduce.cpp =================================================================== --- lib/Transforms/Scalar/StraightLineStrengthReduce.cpp +++ lib/Transforms/Scalar/StraightLineStrengthReduce.cpp @@ -15,19 +15,30 @@ // // There are many optimizations we can perform in the domain of SLSR. This file // for now contains only an initial step. Specifically, we look for strength -// reduction candidate in the form of +// reduction candidates in two forms: // -// (B + i) * S +// Form 1: (B + i) * S +// Form 2: &B[i * S] // -// where B and S are integer constants or variables, and i is a constant -// integer. If we found two such candidates +// where S is an integer variable, and i is a constant integer. If we found two +// candidates // -// S1: X = (B + i) * S S2: Y = (B + i') * S +// S1: X = (B + i) * S +// S2: Y = (B + i') * S +// +// or +// +// S1: X = &B[i * S] +// S2: Y = &B[i' * S] // // and S1 dominates S2, we call S1 a basis of S2, and can replace S2 with // // Y = X + (i' - i) * S // +// or +// +// Y = &X[(i' - i) * S] +// // where (i' - i) * S is folded to the extent possible. When S2 has multiple // bases, we pick the one that is closest to S2, or S2's "immediate" basis. // @@ -35,8 +46,6 @@ // // - Handle candidates in the form of B + i * S // -// - Handle candidates in the form of pointer arithmetics. e.g., B[i * S] -// // - Floating point arithmetics when fast math is enabled. // // - SLSR may decrease ILP at the architecture level. Targets that are very @@ -45,6 +54,9 @@ #include #include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/FoldingSet.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Module.h" @@ -58,14 +70,49 @@ namespace { class StraightLineStrengthReduce : public FunctionPass { - public: +public: + // Represents the base expression used in a candidate.
There are two types of + // base expressions so far: + // 1. For the candidate form "(B + i) * S", a base expression is simply the + // LLVM value of B. + // 2. For the candidate form "&B[i * S]", a base expression is the inner + // pointer B represented by a prefix of a GEP instruction. For example, + // the base expression of "getelementptr a, 1, 2, i * S" is "getelementptr + // a, 1, 2", which is uniquely identified by list [a, 1, 2]. + class BaseExpr : public FoldingSetNode { + private: + SmallVector Operands; + + public: + // for the candidate form "(B + i) * S" + BaseExpr(Value *Opnd) : Operands(1, Opnd) {} + + // for the candidate form "&B[i * S]" + template BaseExpr(ItTy B, ItTy E) : Operands(B, E) {} + + void Profile(FoldingSetNodeID &ID) const { + for (auto Operand : Operands) + ID.AddPointer(Operand); + } + }; + - // SLSR candidate. Such a candidate must be in the form of - // (Base + Index) * Stride + // SLSR candidate. Such a candidate must be in the form of either + // (Base + Index) * Stride (Mul) or &Base[Index * Stride] (GEP) struct Candidate : public ilist_node { - Candidate(Value *B = nullptr, ConstantInt *Idx = nullptr, - Value *S = nullptr, Instruction *I = nullptr) - : Base(B), Index(Idx), Stride(S), Ins(I), Basis(nullptr) {} - Value *Base; + enum Type { + Invalid, // reserved for the default constructor + Mul, // (B + i) * S + GEP, // &B[i * S] + }; + + Candidate() + : CandidateType(Invalid), Base(nullptr), Index(nullptr), + Stride(nullptr), Ins(nullptr), Basis(nullptr) {} + Candidate(Type CT, BaseExpr *B, ConstantInt *Idx, Value *S, Instruction *I) + : CandidateType(CT), Base(B), Index(Idx), Stride(S), Ins(I), + Basis(nullptr) {} + Type CandidateType; + BaseExpr *Base; ConstantInt *Index; Value *Stride; // The instruction this candidate corresponds to.
It helps us to rewrite a @@ -90,45 +137,88 @@ static char ID; - StraightLineStrengthReduce() : FunctionPass(ID), DT(nullptr) { + StraightLineStrengthReduce() + : FunctionPass(ID), DL(nullptr), DT(nullptr), TTI(nullptr) { initializeStraightLineStrengthReducePass(*PassRegistry::getPassRegistry()); } void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired(); AU.addRequired(); + AU.addRequired(); // We do not modify the shape of the CFG. AU.setPreservesCFG(); } + bool doInitialization(Module &M) override { + DataLayoutPass *DLP = getAnalysisIfAvailable(); + if (DLP == nullptr) + report_fatal_error("data layout missing"); + DL = &DLP->getDataLayout(); + return false; + } + + // A FoldingSet does not own its nodes, so the BaseExprs allocated while + // processing a function must be deleted explicitly; otherwise every run + // leaks them. The iterator is moved past a node before the node is freed, + // and clear() only resets the buckets without touching the freed nodes. + void releaseMemory() override { + for (auto It = UniqueBaseExpressions.begin(), + E = UniqueBaseExpressions.end(); + It != E;) { + BaseExpr *Dead = &*It; + ++It; + delete Dead; + } + UniqueBaseExpressions.clear(); + } + bool runOnFunction(Function &F) override; - private: +private: // Returns true if Basis is a basis for C, i.e., Basis dominates C and they // share the same base and stride. bool isBasisFor(const Candidate &Basis, const Candidate &C); // Checks whether I is in a candidate form. If so, adds all the matching forms // to Candidates, and tries to find the immediate basis for each of them. void allocateCandidateAndFindBasis(Instruction *I); + // Allocate candidates and find bases for Mul instructions. + void allocateCandidateAndFindBasisForMul(BinaryOperator *I); + // Allocate candidates and find bases for GetElementPtr instructions. + void allocateCandidateAndFindBasisForGEP(GetElementPtrInst *GEP); + void allocateCandidateAndFindBasis(Candidate::Type CT, Value *B, + ConstantInt *Idx, Value *S, + Instruction *I); // Given that I is in the form of "(B + Idx) * S", adds this form to // Candidates, and finds its immediate basis. - void allocateCandidateAndFindBasis(Value *B, ConstantInt *Idx, Value *S, + void allocateCandidateAndFindBasis(Candidate::Type CT, BaseExpr *B, + ConstantInt *Idx, Value *S, Instruction *I); // Rewrites candidate C with respect to Basis.
void rewriteCandidateWithBasis(const Candidate &C, const Candidate &Basis); + const DataLayout *DL; DominatorTree *DT; + TargetTransformInfo *TTI; ilist Candidates; // Temporarily holds all instructions that are unlinked (but not deleted) by // rewriteCandidateWithBasis. These instructions will be actually removed // after all rewriting finishes. DenseSet UnlinkedInstructions; + // To efficiently decide whether two BaseExprs are equivalent, + // we use a FoldingSet to keep unique BaseExprs. For example, + // the base expression of "getelementptr a, 1, 2, i" and that of + // "getelementptr a, 1, 2, i'" share the same folding set node. + FoldingSet UniqueBaseExpressions; }; } // anonymous namespace char StraightLineStrengthReduce::ID = 0; INITIALIZE_PASS_BEGIN(StraightLineStrengthReduce, "slsr", "Straight line strength reduction", false, false) +INITIALIZE_PASS_DEPENDENCY(DataLayoutPass) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_END(StraightLineStrengthReduce, "slsr", "Straight line strength reduction", false, false) @@ -141,9 +216,10 @@ return (Basis.Ins != C.Ins && // skip the same instruction // Basis must dominate C in order to rewrite C with respect to Basis. DT->dominates(Basis.Ins->getParent(), C.Ins->getParent()) && - // They share the same base and stride. + // They share the same base, stride, and candidate type. Basis.Base == C.Base && - Basis.Stride == C.Stride); + Basis.Stride == C.Stride && + Basis.CandidateType == C.CandidateType); } // TODO: We currently implement an algorithm whose time complexity is linear to @@ -153,11 +229,25 @@ // table is indexed by the base and the stride of a candidate. Therefore, // finding the immediate basis of a candidate boils down to one hash-table look // up.
-void StraightLineStrengthReduce::allocateCandidateAndFindBasis(Value *B, - ConstantInt *Idx, - Value *S, - Instruction *I) { - Candidate C(B, Idx, S, I); +void StraightLineStrengthReduce::allocateCandidateAndFindBasis( + Candidate::Type CT, BaseExpr *B, ConstantInt *Idx, Value *S, + Instruction *I) { + assert(Idx->getType() == S->getType()); + + if (GetElementPtrInst *GEP = dyn_cast(I)) { + // If &B[Idx * S] fits into an addressing mode for the accessed (pointee) + // type, do not turn it into non-free computation. + if (TTI && + TTI->isLegalAddressingMode( + GEP->getType()->getElementType(), /*BaseGV=*/nullptr, /*BaseOffset=*/0, + /*HasBaseReg=*/true, + /*Scale=*/DL->getTypeAllocSize(GEP->getType()->getElementType()) * + Idx->getSExtValue())) { + return; + } + } + + Candidate C(CT, B, Idx, S, I); // Try to compute the immediate basis of C. unsigned NumIterations = 0; // Limit the scan radius to avoid running forever. @@ -175,61 +265,165 @@ Candidates.push_back(C); } +void StraightLineStrengthReduce::allocateCandidateAndFindBasis( + Candidate::Type CT, Value *B, ConstantInt *Idx, Value *S, Instruction *I) { + BaseExpr *Base = new BaseExpr(B); + BaseExpr *UniqueBase = UniqueBaseExpressions.GetOrInsertNode(Base); + if (UniqueBase != Base) + delete Base; + allocateCandidateAndFindBasis(CT, UniqueBase, Idx, S, I); +} + void StraightLineStrengthReduce::allocateCandidateAndFindBasis(Instruction *I) { - Value *B = nullptr; - ConstantInt *Idx = nullptr; - // "(Base + Index) * Stride" must be a Mul instruction at the first hand. - if (I->getOpcode() == Instruction::Mul) { - if (IntegerType *ITy = dyn_cast(I->getType())) { - Value *LHS = I->getOperand(0), *RHS = I->getOperand(1); - for (unsigned Swapped = 0; Swapped < 2; ++Swapped) { - // Only handle the canonical operand ordering. - if (match(LHS, m_Add(m_Value(B), m_ConstantInt(Idx)))) { - // If LHS is in the form of "Base + Index", then I is in the form of - // "(Base + Index) * RHS".
- allocateCandidateAndFindBasis(B, Idx, RHS, I); - } else { - // Otherwise, at least try the form (LHS + 0) * RHS. - allocateCandidateAndFindBasis(LHS, ConstantInt::get(ITy, 0), RHS, I); - } - // Swap LHS and RHS so that we also cover the cases where LHS is the - // stride. - if (LHS == RHS) - break; - std::swap(LHS, RHS); - } + switch (I->getOpcode()) { + case Instruction::Mul: + allocateCandidateAndFindBasisForMul(cast(I)); + break; + case Instruction::GetElementPtr: + allocateCandidateAndFindBasisForGEP(cast(I)); + break; + } +} + +void StraightLineStrengthReduce::allocateCandidateAndFindBasisForMul( + BinaryOperator *I) { + // Try matching (B + i) * S. + // TODO: we could extend SLSR to float and vector types. + IntegerType *ITy = dyn_cast(I->getType()); + if (!ITy) + return; + + Value *LHS = I->getOperand(0), *RHS = I->getOperand(1); + for (unsigned Swapped = 0; Swapped < 2; ++Swapped) { + Value *B = nullptr; + ConstantInt *Idx = nullptr; + // Only handle the canonical operand ordering. + if (match(LHS, m_Add(m_Value(B), m_ConstantInt(Idx)))) { + // If LHS is in the form of "Base + Index", then I is in the form of + // "(Base + Index) * RHS". + allocateCandidateAndFindBasis(Candidate::Mul, B, Idx, RHS, I); + } else { + // Otherwise, at least try the form (LHS + 0) * RHS. + allocateCandidateAndFindBasis(Candidate::Mul, LHS, + ConstantInt::get(ITy, 0), RHS, I); } + // Swap LHS and RHS so that we also cover the cases where LHS is the stride. + if (LHS == RHS) + break; + std::swap(LHS, RHS); + } +} + +void StraightLineStrengthReduce::allocateCandidateAndFindBasisForGEP( + GetElementPtrInst *I) { + // Try matching B[...][i * S]. + // TODO: we exploit the SLSR opportunities only in the last index, because + // this case is most common. + assert(I->getNumOperands() > 1 && "A GEP should have at least one index"); + + // LastType is the type of the last dimension, i.e., typeof(B[...]). SLSR + // needs to make sure this type is an array type (i.e., SequentialType).
+ gep_type_iterator LastType = gep_type_begin(*I); + for (unsigned Index = 0; Index < I->getNumIndices() - 1; ++Index, ++LastType); + if (!isa(*LastType)) { + return; + } + + Value *LastIdx = *(I->value_op_end() - 1); + if (!isa(LastIdx->getType())) + return; + + BaseExpr *Base = new BaseExpr(I->value_op_begin(), I->value_op_end() - 1); + BaseExpr *UniqueBase = UniqueBaseExpressions.GetOrInsertNode(Base); + if (UniqueBase != Base) // Base already exists. + delete Base; + + for (unsigned Stripped = 0; Stripped < 2; ++Stripped) { + // Factorize LastIdx to the form of i * S. + // At least LastIdx = 1 * LastIdx. + allocateCandidateAndFindBasis( + Candidate::GEP, UniqueBase, + ConstantInt::get(cast(LastIdx->getType()), 1), LastIdx, I); + Value *LHS = nullptr; + ConstantInt *RHS = nullptr; + // TODO: handle shl. e.g., we could treat (S << 2) as (S * 4). + if (match(LastIdx, m_Mul(m_Value(LHS), m_ConstantInt(RHS)))) { + // This transformation is unsafe if i * S may overflow. + if (cast(LastIdx)->hasNoSignedWrap()) + allocateCandidateAndFindBasis(Candidate::GEP, UniqueBase, RHS, LHS, I); + } + // Strips the sext from LastIdx and try factoring again. + if (!match(LastIdx, m_SExt(m_Value(LastIdx)))) + break; } } void StraightLineStrengthReduce::rewriteCandidateWithBasis( const Candidate &C, const Candidate &Basis) { + assert(C.CandidateType == Basis.CandidateType && C.Base == Basis.Base && + C.Stride == Basis.Stride); + // An instruction can correspond to multiple candidates. Therefore, instead of // simply deleting an instruction when we rewrite it, we mark its parent as // nullptr (i.e. unlink it) so that we can skip the candidates whose // instruction is already rewritten.
if (!C.Ins->getParent()) return; - assert(C.Base == Basis.Base && C.Stride == Basis.Stride); - // Basis = (B + i) * S - // C = (B + i') * S - // ==> - // C = Basis + (i' - i) * S + IRBuilder<> Builder(C.Ins); ConstantInt *IndexOffset = ConstantInt::get( C.Ins->getContext(), C.Index->getValue() - Basis.Index->getValue()); - Value *Reduced; - // TODO: preserve nsw/nuw in some cases. + // Compute Bump = C - Basis = (i' - i) * S. + Value *Stride = C.Stride; + if (C.CandidateType == Candidate::GEP) { + // A GEP sign-extends (or truncates) its indices to pointer width, so for a + // GEP candidate C - Basis is (i' - i) * S evaluated in the pointer-sized + // integer type. Multiplying in S's possibly narrower type and extending the + // product afterwards can overflow even when i * S and i' * S do not. + Type *IntPtrTy = DL->getIntPtrType(C.Ins->getType()); + unsigned PtrBits = IntPtrTy->getIntegerBitWidth(); + Stride = Builder.CreateSExtOrTrunc(C.Stride, IntPtrTy); + IndexOffset = ConstantInt::get( + C.Ins->getContext(), C.Index->getValue().sextOrTrunc(PtrBits) - + Basis.Index->getValue().sextOrTrunc(PtrBits)); + } + Value *Bump = nullptr; if (IndexOffset->isOne()) { - // If (i' - i) is 1, fold C into Basis + S. - Reduced = Builder.CreateAdd(Basis.Ins, C.Stride); + // If (i' - i) is 1, Bump = S. + Bump = Stride; } else if (IndexOffset->isMinusOne()) { - // If (i' - i) is -1, fold C into Basis - S. - Reduced = Builder.CreateSub(Basis.Ins, C.Stride); + // If (i' - i) is -1, Bump = -S. + Bump = Builder.CreateNeg(Stride); } else { - Value *Bump = Builder.CreateMul(C.Stride, IndexOffset); - Reduced = Builder.CreateAdd(Basis.Ins, Bump); + // Otherwise, Bump = (i' - i) * S. + Bump = Builder.CreateMul(Stride, IndexOffset); } + + Value *Reduced = nullptr; + switch (C.CandidateType) { + case Candidate::Mul: + // Basis = (B + i) * S + // C = (B + i') * S + // ==> + // C = Basis + (i' - i) * S + Reduced = Builder.CreateAdd(Basis.Ins, Bump); + break; + case Candidate::GEP: + // Basis = &B[..][i * S] + // C = &B[..][i' * S] + // ==> + // C = &Basis[(i' - i) * S] + { + // Bump has already been computed in the pointer-sized integer type.
+ if (cast(C.Ins)->isInBounds()) + Reduced = Builder.CreateInBoundsGEP(Basis.Ins, Bump); + else + Reduced = Builder.CreateGEP(Basis.Ins, Bump); + } + break; + default: + llvm_unreachable("CandidateType is invalid"); + } Reduced->takeName(C.Ins); C.Ins->replaceAllUsesWith(Reduced); C.Ins->dropAllReferences(); @@ -243,6 +426,14 @@ if (skipOptnoneFunction(F)) return false; + if (TargetTransformInfoWrapperPass *TTIP = + getAnalysisIfAvailable()) { + TTI = &TTIP->getTTI(F); + } else { + TTI = nullptr; + } + + UniqueBaseExpressions.clear(); DT = &getAnalysis().getDomTree(); // Traverse the dominator tree in the depth-first order. This order makes sure // all bases of a candidate are in Candidates when we process it. Index: test/Transforms/StraightLineStrengthReduce/X86/lit.local.cfg =================================================================== --- /dev/null +++ test/Transforms/StraightLineStrengthReduce/X86/lit.local.cfg @@ -0,0 +1,2 @@ +if not 'X86' in config.root.targets: + config.unsupported = True Index: test/Transforms/StraightLineStrengthReduce/X86/no-slsr.ll =================================================================== --- /dev/null +++ test/Transforms/StraightLineStrengthReduce/X86/no-slsr.ll @@ -0,0 +1,30 @@ +; RUN: opt < %s -slsr -gvn -dce -S | FileCheck %s + +target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-linux-gnu" + +; Do not perform SLSR on &input[s] and &input[s * 2] which fit into addressing +; modes of X86.
+define i32 @slsr_gep(i32* %input, i64 %s) { +; CHECK-LABEL: @slsr_gep( + ; v0 = input[0]; + %p0 = getelementptr inbounds i32* %input, i64 0 + %v0 = load i32* %p0 + + ; v1 = input[s]; + %p1 = getelementptr inbounds i32* %input, i64 %s +; CHECK: %p1 = getelementptr inbounds i32* %input, i64 %s + %v1 = load i32* %p1 + + ; v2 = input[s * 2]; + %s2 = mul nsw i64 %s, 2 + %p2 = getelementptr inbounds i32* %input, i64 %s2 +; CHECK: %p2 = getelementptr inbounds i32* %input, i64 %s2 + %v2 = load i32* %p2 + + ; return v0 + v1 + v2; + %1 = add i32 %v0, %v1 + %2 = add i32 %1, %v2 + ret i32 %2 +} + Index: test/Transforms/StraightLineStrengthReduce/slsr-gep.ll =================================================================== --- /dev/null +++ test/Transforms/StraightLineStrengthReduce/slsr-gep.ll @@ -0,0 +1,70 @@ +; RUN: opt < %s -slsr -gvn -dce -S | FileCheck %s + +target datalayout = "e-i64:64-v16:16-v32:32-n16:32:64" + +define i32 @slsr_gep(i32* %input, i64 %s) { +; CHECK-LABEL: @slsr_gep( + ; v0 = input[0]; + %p0 = getelementptr inbounds i32* %input, i64 0 + %v0 = load i32* %p0 + + ; v1 = input[s]; + %p1 = getelementptr inbounds i32* %input, i64 %s +; CHECK: %p1 = getelementptr inbounds i32* %input, i64 %s + %v1 = load i32* %p1 + + ; v2 = input[s * 2]; + %s2 = mul nsw i64 %s, 2 + %p2 = getelementptr inbounds i32* %input, i64 %s2 +; CHECK: %p2 = getelementptr inbounds i32* %p1, i64 %s + %v2 = load i32* %p2 + + ; return v0 + v1 + v2; + %1 = add i32 %v0, %v1 + %2 = add i32 %1, %v2 + ret i32 %2 +} + +define i32 @slsr_gep_sext(i32* %input, i32 %s) { +; CHECK-LABEL: @slsr_gep_sext( + ; v0 = input[0]; + %p0 = getelementptr inbounds i32* %input, i64 0 + %v0 = load i32* %p0 + + ; v1 = input[(long)s]; + %t = sext i32 %s to i64 + %p1 = getelementptr inbounds i32* %input, i64 %t +; CHECK: %p1 = getelementptr inbounds i32* %input, i64 %t + %v1 = load i32* %p1 + + ; v2 = input[(long)(s * 2)]; + %s2 = mul nsw i32 %s, 2 + %t2 = sext i32 %s2 to i64 + %p2 = getelementptr inbounds i32*
%input, i64 %t2 +; CHECK: %p2 = getelementptr inbounds i32* %p1, i64 %t + %v2 = load i32* %p2 + + ; return v0 + v1 + v2; + %1 = add i32 %v0, %v1 + %2 = add i32 %1, %v2 + ret i32 %2 +} + +define i32* @slsr_gep_sext_wide_bump(i32* %input, i32 %s) { +; CHECK-LABEL: @slsr_gep_sext_wide_bump( + ; p1 = &input[(long)(s * 2)]; + %s1 = mul nsw i32 %s, 2 + %t1 = sext i32 %s1 to i64 + %p1 = getelementptr inbounds i32* %input, i64 %t1 + + ; p2 = &input[(long)(s * 5)]; + %s2 = mul nsw i32 %s, 5 + %t2 = sext i32 %s2 to i64 + %p2 = getelementptr inbounds i32* %input, i64 %t2 +; The bump (5 - 2) * s is computed in i64 because, in general, (i' - i) * s may +; overflow i32 even when i * s and i' * s do not (e.g., i = -4, i' = 4). +; CHECK: [[S64:%[0-9a-z.]+]] = sext i32 %s to i64 +; CHECK: [[BUMP:%[0-9a-z.]+]] = mul i64 [[S64]], 3 +; CHECK: %p2 = getelementptr inbounds i32* %p1, i64 [[BUMP]] + ret i32* %p2 +} Index: test/Transforms/StraightLineStrengthReduce/slsr-mul.ll =================================================================== --- test/Transforms/StraightLineStrengthReduce/slsr-mul.ll +++ test/Transforms/StraightLineStrengthReduce/slsr-mul.ll @@ -1,5 +1,7 @@ ; RUN: opt < %s -slsr -gvn -dce -S | FileCheck %s +target datalayout = "e-i64:64-v16:16-v32:32-n16:32:64" + declare i32 @foo(i32 %a) define i32 @slsr1(i32 %b, i32 %s) {