Index: lib/Transforms/Scalar/NaryReassociate.cpp
===================================================================
--- lib/Transforms/Scalar/NaryReassociate.cpp
+++ lib/Transforms/Scalar/NaryReassociate.cpp
@@ -74,21 +74,18 @@
 // 1) We only consider n-ary adds for now. This should be extended and
 //    generalized.
 //
-// 2) Besides arithmetic operations, similar reassociation can be applied to
-//    GEPs. For example, if
-//      X = &arr[a]
-//    dominates
-//      Y = &arr[a + b]
-//    we may rewrite Y into X + b.
-//
 //===----------------------------------------------------------------------===//
 
+#include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
 #include "llvm/Transforms/Scalar.h"
 #include "llvm/Transforms/Utils/Local.h"
 using namespace llvm;
@@ -115,6 +112,7 @@
     AU.addPreserved<DominatorTreeWrapperPass>();
     AU.addPreserved<ScalarEvolution>();
     AU.addPreserved<TargetLibraryInfoWrapperPass>();
+    AU.addRequired<AssumptionCacheTracker>();
     AU.addRequired<DominatorTreeWrapperPass>();
     AU.addRequired<ScalarEvolution>();
     AU.addRequired<TargetLibraryInfoWrapperPass>();
@@ -163,12 +161,18 @@
   // GEP's pointer size, i.e., whether Index needs to be sign-extended in order
   // to be an index of GEP.
   bool requiresSignExtension(Value *Index, GetElementPtrInst *GEP);
+  // Returns whether V is known to be non-negative at context \c Ctxt.
+  bool isKnownNonNegative(Value *V, Instruction *Ctxt);
+  // Returns whether AO may sign overflow at context \c Ctxt. The result is
+  // conservative -- it answers true when not sure.
+  bool maySignOverflow(AddOperator *AO, Instruction *Ctxt);
 
+  AssumptionCache *AC;
+  const DataLayout *DL;
   DominatorTree *DT;
   ScalarEvolution *SE;
   TargetLibraryInfo *TLI;
   TargetTransformInfo *TTI;
-  const DataLayout *DL;
 
   // A lookup table quickly telling which instructions compute the given SCEV.
   // Note that there can be multiple instructions at different locations
   // computing to the same SCEV, so we map a SCEV to an instruction list. For
@@ -185,6 +189,7 @@
 char NaryReassociate::ID = 0;
 INITIALIZE_PASS_BEGIN(NaryReassociate, "nary-reassociate", "Nary reassociation",
                       false, false)
+INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(ScalarEvolution)
 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
@@ -200,6 +205,7 @@
   if (skipOptnoneFunction(F))
     return false;
 
+  AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
   DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
   SE = &getAnalysis<ScalarEvolution>();
   TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
@@ -346,18 +352,44 @@
   return cast<IntegerType>(Index->getType())->getBitWidth() < PointerSizeInBits;
 }
 
+bool NaryReassociate::isKnownNonNegative(Value *V, Instruction *Ctxt) {
+  bool NonNegative, Negative;
+  // TODO: ComputeSignBit is expensive. Consider caching the results.
+  ComputeSignBit(V, NonNegative, Negative, *DL, 0, AC, Ctxt, DT);
+  return NonNegative;
+}
+
+bool NaryReassociate::maySignOverflow(AddOperator *AO, Instruction *Ctxt) {
+  if (AO->hasNoSignedWrap())
+    return false;
+
+  Value *LHS = AO->getOperand(0), *RHS = AO->getOperand(1);
+  // If LHS or RHS has the same sign as the sum, AO doesn't sign overflow.
+  // TODO: handle the negative case as well.
+  if (isKnownNonNegative(AO, Ctxt) &&
+      (isKnownNonNegative(LHS, Ctxt) || isKnownNonNegative(RHS, Ctxt)))
+    return false;
+
+  return true;
+}
+
 GetElementPtrInst *
 NaryReassociate::tryReassociateGEPAtIndex(GetElementPtrInst *GEP, unsigned I,
                                           Type *IndexedType) {
   Value *IndexToSplit = GEP->getOperand(I + 1);
-  if (SExtInst *SExt = dyn_cast<SExtInst>(IndexToSplit))
+  if (SExtInst *SExt = dyn_cast<SExtInst>(IndexToSplit)) {
     IndexToSplit = SExt->getOperand(0);
+  } else if (ZExtInst *ZExt = dyn_cast<ZExtInst>(IndexToSplit)) {
+    // A zext can be treated as an sext if the source is non-negative.
+    if (isKnownNonNegative(ZExt->getOperand(0), GEP))
+      IndexToSplit = ZExt->getOperand(0);
+  }
 
   if (AddOperator *AO = dyn_cast<AddOperator>(IndexToSplit)) {
     // If the I-th index needs sext and the underlying add is not equipped with
     // nsw, we cannot split the add because
     //   sext(LHS + RHS) != sext(LHS) + sext(RHS).
-    if (requiresSignExtension(IndexToSplit, GEP) && !AO->hasNoSignedWrap())
+    if (requiresSignExtension(IndexToSplit, GEP) && maySignOverflow(AO, GEP))
      return nullptr;
     Value *LHS = AO->getOperand(0), *RHS = AO->getOperand(1);
     // IndexToSplit = LHS + RHS.
@@ -373,10 +405,9 @@
   return nullptr;
 }
 
-GetElementPtrInst *
-NaryReassociate::tryReassociateGEPAtIndex(GetElementPtrInst *GEP, unsigned I,
-                                          Value *LHS, Value *RHS,
-                                          Type *IndexedType) {
+GetElementPtrInst *NaryReassociate::tryReassociateGEPAtIndex(
+    GetElementPtrInst *GEP, unsigned I, Value *LHS, Value *RHS,
+    Type *IndexedType) {
   // Look for GEP's closest dominator that has the same SCEV as GEP except that
   // the I-th index is replaced with LHS.
   SmallVector<const SCEV *, 4> IndexExprs;
@@ -384,6 +415,16 @@
     IndexExprs.push_back(SE->getSCEV(*Index));
   // Replace the I-th index with LHS.
   IndexExprs[I] = SE->getSCEV(LHS);
+  if (isKnownNonNegative(LHS, GEP) &&
+      DL->getTypeSizeInBits(LHS->getType()) <
+          DL->getTypeSizeInBits(GEP->getOperand(I)->getType())) {
+    // Zero-extend LHS if it is non-negative. InstCombine canonicalizes sext to
+    // zext if the source operand is proved non-negative. We should do the same
+    // consistently, so that CandidateExpr is more likely to match an expression
+    // computed before. See @reassociate_gep_assume for an example.
+    IndexExprs[I] =
+        SE->getZeroExtendExpr(IndexExprs[I], GEP->getOperand(I)->getType());
+  }
   const SCEV *CandidateExpr = SE->getGEPExpr(
       GEP->getSourceElementType(), SE->getSCEV(GEP->getPointerOperand()),
       IndexExprs, GEP->isInBounds());
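
The TODO in isKnownNonNegative notes that ComputeSignBit is expensive. One possible shape for the caching it suggests is sketched below; this is not part of the patch, and the cache layout and invalidation policy are hypothetical. Note that ComputeSignBit's answer depends on the context instruction (through assumptions and dominance), so the context block would have to be part of the key, and the cache would have to be cleared whenever the pass mutates the IR:

// Hypothetical sketch, not in this patch: memoize sign-bit queries per
// (Value, context block). Requires #include "llvm/ADT/DenseMap.h".
DenseMap<std::pair<Value *, BasicBlock *>, bool> NonNegativeCache;

bool NaryReassociate::isKnownNonNegative(Value *V, Instruction *Ctxt) {
  auto Key = std::make_pair(V, Ctxt->getParent());
  auto It = NonNegativeCache.find(Key);
  if (It != NonNegativeCache.end())
    return It->second;
  bool NonNegative, Negative;
  ComputeSignBit(V, NonNegative, Negative, *DL, 0, AC, Ctxt, DT);
  // Entries go stale once the IR changes; a real implementation would
  // clear the cache at every mutation point.
  return NonNegativeCache[Key] = NonNegative;
}
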
Index: test/Transforms/NaryReassociate/NVPTX/nary-gep.ll
===================================================================
--- test/Transforms/NaryReassociate/NVPTX/nary-gep.ll
+++ test/Transforms/NaryReassociate/NVPTX/nary-gep.ll
@@ -36,27 +36,61 @@
 ; t3 = t2 + sext(i); // sext(i) should be GVN'ed.
 ; foo(t3);
 define void @reassociate_gep_nsw(float* %a, i32 %i, i32 %j) {
 ; CHECK-LABEL: @reassociate_gep_nsw(
   %idxprom.j = sext i32 %j to i64
   %1 = getelementptr float, float* %a, i64 %idxprom.j
 ; CHECK: [[t1:[^ ]+]] = getelementptr float, float* %a, i64 %idxprom.j
   call void @foo(float* %1)
 ; CHECK: call void @foo(float* [[t1]])
   %2 = add nsw i32 %i, %j
   %idxprom.2 = sext i32 %2 to i64
   %3 = getelementptr float, float* %a, i64 %idxprom.2
 ; CHECK: [[sexti:[^ ]+]] = sext i32 %i to i64
 ; CHECK: [[t2:[^ ]+]] = getelementptr float, float* [[t1]], i64 [[sexti]]
   call void @foo(float* %3)
 ; CHECK: call void @foo(float* [[t2]])
   %4 = add nsw i32 %2, %i
   %idxprom.4 = sext i32 %4 to i64
   %5 = getelementptr float, float* %a, i64 %idxprom.4
 ; CHECK: [[t3:[^ ]+]] = getelementptr float, float* [[t2]], i64 [[sexti]]
   call void @foo(float* %5)
 ; CHECK: call void @foo(float* [[t3]])
 
   ret void
 }
 
+; assume(j >= 0);
+; foo(&a[zext(j)]);
+; assume(i + j >= 0);
+; foo(&a[zext(i + j)]);
+; =>
+; t1 = &a[zext(j)];
+; foo(t1);
+; t2 = t1 + sext(i);
+; foo(t2);
+define void @reassociate_gep_assume(float* %a, i32 %i, i32 %j) {
+; CHECK-LABEL: @reassociate_gep_assume(
+  ; assume(j >= 0)
+  %cmp = icmp sgt i32 %j, -1
+  call void @llvm.assume(i1 %cmp)
+  %1 = add i32 %i, %j
+  %cmp2 = icmp sgt i32 %1, -1
+  call void @llvm.assume(i1 %cmp2)
+
+  %idxprom.j = zext i32 %j to i64
+  %2 = getelementptr float, float* %a, i64 %idxprom.j
+; CHECK: [[t1:[^ ]+]] = getelementptr float, float* %a, i64 %idxprom.j
+  call void @foo(float* %2)
+; CHECK: call void @foo(float* [[t1]])
+
+  %idxprom.1 = zext i32 %1 to i64
+  %3 = getelementptr float, float* %a, i64 %idxprom.1
+; CHECK: [[sexti:[^ ]+]] = sext i32 %i to i64
+; CHECK: [[t2:[^ ]+]] = getelementptr float, float* [[t1]], i64 [[sexti]]
+  call void @foo(float* %3)
+; CHECK: call void @foo(float* [[t2]])
+
+  ret void
+}
@@ -88,3 +122,5 @@
 ; CHECK: call void @foo(float* [[t2]])
   ret void
 }
+
+declare void @llvm.assume(i1)
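
A note on the "handle the negative case" TODO in maySignOverflow: signed overflow of LHS + RHS requires both operands to share a sign and the sum to have the opposite sign, so knowing that the sum and either operand are both negative rules overflow out, exactly mirroring the non-negative check above. A hypothetical sketch of that extension (isKnownNegative is an invented helper, not in this patch; it reuses the Negative bit that isKnownNonNegative already computes and discards):

// Hypothetical helper mirroring isKnownNonNegative; not part of this patch.
bool NaryReassociate::isKnownNegative(Value *V, Instruction *Ctxt) {
  bool NonNegative, Negative;
  ComputeSignBit(V, NonNegative, Negative, *DL, 0, AC, Ctxt, DT);
  return Negative;
}

// Additional early-out in maySignOverflow: overflow of two negative
// operands always wraps to a non-negative sum, so a known-negative sum
// with a known-negative operand cannot have overflowed.
if (isKnownNegative(AO, Ctxt) &&
    (isKnownNegative(LHS, Ctxt) || isKnownNegative(RHS, Ctxt)))
  return false;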