diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
@@ -520,6 +520,20 @@
   const SmallVectorImpl<Instruction *> &Casts = ID.getCastInsts();
   if (!Casts.empty())
     InductionCastsToIgnore.insert(*Casts.begin());
+
+  // Also record in-loop sign/zero extensions of the induction phi, such as a
+  // `sext i32 %iv to i64` used as a GEP index. Only sext and zext qualify:
+  // every CastInst user of Phi has Phi's type as its source type, so a
+  // source-type check would also admit truncs and int<->fp conversions,
+  // whose results are not just the induction value widened.
+  // NOTE(review): unlike ID.getCastInsts(), these casts are not part of the
+  // descriptor's cast chain; confirm codegen gives their users a vector value.
+  for (User *U : Phi->users()) {
+    auto *Ext = dyn_cast<CastInst>(U);
+    if (Ext && (isa<SExtInst>(Ext) || isa<ZExtInst>(Ext)) &&
+        TheLoop->contains(Ext))
+      InductionCastsToIgnore.insert(Ext);
+  }
 
   Type *PhiTy = Phi->getType();
   const DataLayout &DL = Phi->getModule()->getDataLayout();
diff --git a/llvm/unittests/Transforms/Vectorize/CMakeLists.txt b/llvm/unittests/Transforms/Vectorize/CMakeLists.txt
--- a/llvm/unittests/Transforms/Vectorize/CMakeLists.txt
+++ b/llvm/unittests/Transforms/Vectorize/CMakeLists.txt
@@ -12,4 +12,5 @@
   VPlanTest.cpp
   VPlanHCFGTest.cpp
   VPlanSlpTest.cpp
+  LoopVectorizationLegalityTest.cpp
   )
diff --git a/llvm/unittests/Transforms/Vectorize/LoopVectorizationLegalityTest.cpp b/llvm/unittests/Transforms/Vectorize/LoopVectorizationLegalityTest.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/unittests/Transforms/Vectorize/LoopVectorizationLegalityTest.cpp
@@ -0,0 +1,180 @@
+//===- LoopVectorizationLegalityTest.cpp - Loop Legality unit tests -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Vectorize/LoopVectorizationLegality.h"
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/BasicAliasAnalysis.h"
+#include "llvm/Analysis/DemandedBits.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/ProfileSummaryInfo.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/VectorUtils.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+/// Fixture that parses an IR module and builds, by hand, every analysis that
+/// LoopVectorizationLegality needs for the loop following the entry block.
+///
+/// Ctx and M are declared before the analyses, which follow in the order
+/// getLoopVectorizationLegality() builds them. Members are destroyed in
+/// reverse, so each analysis goes before the ones it was built from and the
+/// IR goes last.
+class LoopVectorizationLegalityTest : public testing::Test {
+protected:
+  TargetLibraryInfoImpl TLII;
+  TargetLibraryInfo TLI;
+  DataLayout DL;
+
+  std::unique_ptr<LLVMContext> Ctx;
+  std::unique_ptr<Module> M;
+  std::unique_ptr<DominatorTree> DT;
+  std::unique_ptr<LoopInfo> LI;
+  std::unique_ptr<AssumptionCache> AC;
+  std::unique_ptr<ScalarEvolution> SE;
+  std::unique_ptr<BasicAAResult> BasicAA;
+  std::unique_ptr<AAResults> AARes;
+  std::unique_ptr<PredicatedScalarEvolution> PSE;
+  std::unique_ptr<LoopAccessInfo> LAI;
+  std::unique_ptr<std::function<const LoopAccessInfo &(Loop &)>> GetLAA;
+  std::unique_ptr<DemandedBits> DB;
+  std::unique_ptr<TargetTransformInfo> TTI;
+  std::unique_ptr<OptimizationRemarkEmitter> ORE;
+  std::unique_ptr<LoopVectorizationRequirements> R;
+  std::unique_ptr<LoopVectorizeHints> LVH;
+  std::unique_ptr<BlockFrequencyInfo> BFI;
+  std::unique_ptr<ProfileSummaryInfo> PSI;
+
+  LoopVectorizationLegalityTest()
+      : TLII(), TLI(TLII),
+        DL("e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-"
+           "f64:64:64-v64:64:64-v128:128:128-a0:0:64-s0:64:64-f80:128:128-n8:"
+           "16:32:64-S128"),
+        Ctx(new LLVMContext) {}
+
+  /// Parses \p ModuleString into M and gives it the fixture's DataLayout, the
+  /// one BasicAA is built with. On a parse error, prints the diagnostic and
+  /// returns nullptr.
+  Module *parseModule(const char *ModuleString) {
+    SMDiagnostic Err;
+    M = parseAssemblyString(ModuleString, Err, *Ctx);
+    if (!M) {
+      Err.print("LoopVectorizationLegalityTest", errs());
+      return nullptr;
+    }
+    M->setDataLayout(DL);
+    return M.get();
+  }
+
+  /// Builds the analyses for the loop headed by the single successor of
+  /// \p F's entry block and returns a LoopVectorizationLegality for it. The
+  /// analyses are owned by the fixture and must outlive the returned object.
+  LoopVectorizationLegality getLoopVectorizationLegality(Function *F) {
+    DT.reset(new DominatorTree(*F));
+    LI.reset(new LoopInfo(*DT));
+
+    BasicBlock *LoopHeader = F->getEntryBlock().getSingleSuccessor();
+    Loop *L = LI->getLoopFor(LoopHeader);
+    assert(L && "expected the entry block to branch to a loop header");
+
+    AC.reset(new AssumptionCache(*F));
+    SE.reset(new ScalarEvolution(*F, TLI, *AC, *DT, *LI));
+    BasicAA.reset(new BasicAAResult(DL, *F, TLI, *AC, &*DT));
+    AARes.reset(new AAResults(TLI));
+    AARes->addAAResult(*BasicAA);
+
+    PSE.reset(new PredicatedScalarEvolution(*SE, *L));
+    LAI.reset(new LoopAccessInfo(L, &*SE, &TLI, &*AARes, &*DT, &*LI));
+    GetLAA.reset(new std::function<const LoopAccessInfo &(Loop &)>(
+        [this](Loop &) -> const LoopAccessInfo & { return *LAI; }));
+    DB.reset(new DemandedBits(*F, *AC, *DT));
+
+    TTI.reset(new TargetTransformInfo(M->getDataLayout()));
+    ORE.reset(new OptimizationRemarkEmitter(F));
+    R.reset(new LoopVectorizationRequirements());
+    LVH.reset(new LoopVectorizeHints(L, false, *ORE));
+    BFI.reset(new BlockFrequencyInfo());
+    PSI.reset(new ProfileSummaryInfo(*M));
+
+    LoopVectorizationLegality Legal(L, *PSE, &*DT, &*TTI, &TLI, &*AARes, F,
+                                    &*GetLAA, &*LI, &*ORE, &*R, &*LVH, &*DB,
+                                    &*AC, &*BFI, &*PSI);
+    return Legal;
+  }
+};
+
+/// Returns the instruction named \p Name in \p F, or nullptr if there is none.
+static Instruction *findInstructionByName(Function &F, StringRef Name) {
+  for (BasicBlock &BB : F)
+    for (Instruction &I : BB)
+      if (I.getName() == Name)
+        return &I;
+  return nullptr;
+}
+
+// In-loop sign extensions of the induction phi are recorded as induction
+// casts, both in the header and in a predicated block (the index of the
+// conditional store to @b). Other casts of the phi, such as a trunc, are not.
+TEST_F(LoopVectorizationLegalityTest, IgnoreInductionCastTest) {
+  const char *ModuleString =
+      "@a = internal global [100 x double] zeroinitializer, align 16\n"
+      "@b = internal global [100 x double] zeroinitializer, align 16\n"
+      "define void @bar() {\n"
+      "entry:\n"
+      "  br label %for.body\n"
+      "for.body:\n"
+      "  %i.01 = phi i32 [ 0, %entry ], [ %inc, %for.inc ]\n"
+      "  %idxprom = sext i32 %i.01 to i64\n"
+      "  %trunc = trunc i32 %i.01 to i16\n"
+      "  %arrayidx = getelementptr inbounds [100 x double], [100 x double]* "
+      "@a, i64 0, i64 %idxprom\n"
+      "  %0 = load double, double* %arrayidx, align 8\n"
+      "  %tobool = fcmp une double %0, 0.000000e+00\n"
+      "  br i1 %tobool, label %if.then, label %for.inc\n"
+      "if.then:\n"
+      "  %add = fadd double 2.000000e+00, 1.000000e+00\n"
+      "  %idxprom1 = sext i32 %i.01 to i64\n"
+      "  %arrayidx2 = getelementptr inbounds [100 x double], [100 x double]* "
+      "@b, i64 0, i64 %idxprom1\n"
+      "  store double %add, double* %arrayidx2, align 8\n"
+      "  br label %for.inc\n"
+      "for.inc:\n"
+      "  %inc = add nsw i32 %i.01, 1\n"
+      "  %cmp = icmp slt i32 %inc, 100\n"
+      "  br i1 %cmp, label %for.body, label %for.end\n"
+      "for.end:\n"
+      "  ret void\n"
+      "}";
+
+  Module *Mod = parseModule(ModuleString);
+  ASSERT_NE(Mod, nullptr);
+  Function *F = Mod->getFunction("bar");
+  ASSERT_NE(F, nullptr);
+
+  LoopVectorizationLegality Legal = getLoopVectorizationLegality(F);
+  ASSERT_TRUE(Legal.canVectorize(false));
+
+  Instruction *Phi = findInstructionByName(*F, "i.01");
+  Instruction *HeaderExt = findInstructionByName(*F, "idxprom");
+  Instruction *PredicatedExt = findInstructionByName(*F, "idxprom1");
+  Instruction *Trunc = findInstructionByName(*F, "trunc");
+  ASSERT_TRUE(Phi && HeaderExt && PredicatedExt && Trunc);
+
+  EXPECT_TRUE(Legal.isInductionVariable(Phi));
+  EXPECT_TRUE(Legal.isInductionVariable(HeaderExt));
+  EXPECT_TRUE(Legal.isInductionVariable(PredicatedExt));
+  // A trunc of the induction phi is a cast, but not an extension.
+  EXPECT_FALSE(Legal.isInductionVariable(Trunc));
+}