Index: llvm/trunk/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
===================================================================
--- llvm/trunk/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
+++ llvm/trunk/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
@@ -97,8 +97,16 @@
 
 namespace {
 
+/// ChainID is an arbitrary token that is allowed to be different only for the
+/// accesses that are guaranteed to be considered non-consecutive by
+/// Vectorizer::isConsecutiveAccess. It's used for grouping instructions
+/// together and reducing the number of instructions the main search operates on
+/// at a time, i.e. this is to reduce compile time and nothing else as the main
+/// search has O(n^2) time complexity. The underlying type of ChainID should not
+/// be relied upon.
+using ChainID = const Value *;
 using InstrList = SmallVector<Instruction *, 8>;
-using InstrListMap = MapVector<Value *, InstrList>;
+using InstrListMap = MapVector<ChainID, InstrList>;
 
 class Vectorizer {
   Function &F;
@@ -136,9 +144,15 @@
     return DL.getABITypeAlignment(SI->getValueOperand()->getType());
   }
 
+  static const unsigned MaxDepth = 3;
+
   bool isConsecutiveAccess(Value *A, Value *B);
-  bool areConsecutivePointers(Value *PtrA, Value *PtrB, APInt Size);
-  bool lookThroughComplexAddresses(Value *PtrA, Value *PtrB, APInt PtrDelta);
+  bool areConsecutivePointers(Value *PtrA, Value *PtrB, APInt PtrDelta,
+                              unsigned Depth = 0) const;
+  bool lookThroughComplexAddresses(Value *PtrA, Value *PtrB, APInt PtrDelta,
+                                   unsigned Depth) const;
+  bool lookThroughSelects(Value *PtrA, Value *PtrB, APInt PtrDelta,
+                          unsigned Depth) const;
 
   /// After vectorization, reorder the instructions that I depends on
   /// (the instructions defining its operands), to ensure they dominate I.
@@ -304,7 +318,8 @@
   return areConsecutivePointers(PtrA, PtrB, Size);
 }
 
-bool Vectorizer::areConsecutivePointers(Value *PtrA, Value *PtrB, APInt Size) {
+bool Vectorizer::areConsecutivePointers(Value *PtrA, Value *PtrB,
+                                        APInt PtrDelta, unsigned Depth) const {
   unsigned PtrBitWidth = DL.getPointerTypeSizeInBits(PtrA->getType());
   APInt OffsetA(PtrBitWidth, 0);
   APInt OffsetB(PtrBitWidth, 0);
@@ -316,11 +331,11 @@
   // Check if they are based on the same pointer. That makes the offsets
   // sufficient.
   if (PtrA == PtrB)
-    return OffsetDelta == Size;
+    return OffsetDelta == PtrDelta;
 
   // Compute the necessary base pointer delta to have the necessary final delta
-  // equal to the size.
-  APInt BaseDelta = Size - OffsetDelta;
+  // equal to the pointer delta requested.
+  APInt BaseDelta = PtrDelta - OffsetDelta;
 
   // Compute the distance with SCEV between the base pointers.
   const SCEV *PtrSCEVA = SE.getSCEV(PtrA);
@@ -341,15 +356,16 @@
   // Sometimes even this doesn't work, because SCEV can't always see through
   // patterns that look like (gep (ext (add (shl X, C1), C2))). Try checking
   // things the hard way.
-  return lookThroughComplexAddresses(PtrA, PtrB, BaseDelta);
+  return lookThroughComplexAddresses(PtrA, PtrB, BaseDelta, Depth);
 }
 
 bool Vectorizer::lookThroughComplexAddresses(Value *PtrA, Value *PtrB,
-                                             APInt PtrDelta) {
+                                             APInt PtrDelta,
+                                             unsigned Depth) const {
   auto *GEPA = dyn_cast<GetElementPtrInst>(PtrA);
   auto *GEPB = dyn_cast<GetElementPtrInst>(PtrB);
   if (!GEPA || !GEPB)
-    return false;
+    return lookThroughSelects(PtrA, PtrB, PtrDelta, Depth);
 
   // Look through GEPs after checking they're the same except for the last
   // index.
@@ -434,6 +450,23 @@
   return X == OffsetSCEVB;
 }
 
+bool Vectorizer::lookThroughSelects(Value *PtrA, Value *PtrB, APInt PtrDelta,
+                                    unsigned Depth) const {
+  if (Depth++ == MaxDepth)
+    return false;
+
+  if (auto *SelectA = dyn_cast<SelectInst>(PtrA)) {
+    if (auto *SelectB = dyn_cast<SelectInst>(PtrB)) {
+      return SelectA->getCondition() == SelectB->getCondition() &&
+             areConsecutivePointers(SelectA->getTrueValue(),
+                                    SelectB->getTrueValue(), PtrDelta, Depth) &&
+             areConsecutivePointers(SelectA->getFalseValue(),
+                                    SelectB->getFalseValue(), PtrDelta, Depth);
+    }
+  }
+  return false;
+}
+
 void Vectorizer::reorder(Instruction *I) {
   OrderedBasicBlock OBB(I->getParent());
   SmallPtrSet<Instruction *, 16> InstructionsToMove;
@@ -656,6 +689,20 @@
   return Chain.slice(0, ChainIdx);
 }
 
+static ChainID getChainID(const Value *Ptr, const DataLayout &DL) {
+  const Value *ObjPtr = GetUnderlyingObject(Ptr, DL);
+  if (const auto *Sel = dyn_cast<SelectInst>(ObjPtr)) {
+    // The selects themselves are distinct instructions even if they share the
+    // same condition and evaluate to consecutive pointers for the true and
+    // false values of the condition. Therefore using the selects themselves
+    // for grouping instructions would put consecutive accesses into different
+    // lists and they would not even be checked for being consecutive, and so
+    // would not be vectorized.
+    return Sel->getCondition();
+  }
+  return ObjPtr;
+}
+
 std::pair<InstrListMap, InstrListMap>
 Vectorizer::collectInstructions(BasicBlock *BB) {
   InstrListMap LoadRefs;
@@ -710,8 +757,8 @@
         continue;
 
       // Save the load locations.
-      Value *ObjPtr = GetUnderlyingObject(Ptr, DL);
-      LoadRefs[ObjPtr].push_back(LI);
+      const ChainID ID = getChainID(Ptr, DL);
+      LoadRefs[ID].push_back(LI);
     } else if (StoreInst *SI = dyn_cast<StoreInst>(&I)) {
       if (!SI->isSimple())
        continue;
@@ -756,8 +803,8 @@
         continue;
 
       // Save store location.
-      Value *ObjPtr = GetUnderlyingObject(Ptr, DL);
-      StoreRefs[ObjPtr].push_back(SI);
+      const ChainID ID = getChainID(Ptr, DL);
+      StoreRefs[ID].push_back(SI);
     }
   }
 
@@ -767,7 +814,7 @@
 bool Vectorizer::vectorizeChains(InstrListMap &Map) {
   bool Changed = false;
 
-  for (const std::pair<Value *, InstrList> &Chain : Map) {
+  for (const std::pair<ChainID, InstrList> &Chain : Map) {
     unsigned Size = Chain.second.size();
     if (Size < 2)
       continue;
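
For readers skimming the patch, the following is a minimal standalone sketch of the rule the new lookThroughSelects hook adds: two select-based pointers may be treated as consecutive only if the selects share the same condition and both arms are pairwise consecutive, with the recursion capped at MaxDepth so the O(n^2) search stays cheap. The Ptr type and areConsecutive helper below are assumptions for illustration only, not LLVM's Value/SelectInst or Vectorizer::areConsecutivePointers.

// Toy model only: plain pointers are reduced to a byte offset from a common
// base, whereas the real pass reasons through constant offsets and SCEV.
#include <cstdint>

struct Ptr {
  const Ptr *TrueArm;   // both arms non-null => this models a select
  const Ptr *FalseArm;
  int CondId;           // selects built from the same i1 value share this id
  std::int64_t Offset;  // byte offset from a common base (plain-pointer case)
  bool isSelect() const { return TrueArm && FalseArm; }
};

static const unsigned MaxDepth = 3;

// Two selects may denote consecutive pointers only if they share a condition
// and both the true and the false arms are pairwise consecutive; the depth
// counter cuts the recursion off, mirroring the patch's MaxDepth check.
bool areConsecutive(const Ptr &A, const Ptr &B, std::int64_t Delta,
                    unsigned Depth = 0) {
  if (!A.isSelect() && !B.isSelect())
    return B.Offset - A.Offset == Delta;
  if (Depth++ == MaxDepth)
    return false;
  if (A.isSelect() && B.isSelect() && A.CondId == B.CondId)
    return areConsecutive(*A.TrueArm, *B.TrueArm, Delta, Depth) &&
           areConsecutive(*A.FalseArm, *B.FalseArm, Delta, Depth);
  return false;
}
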
Index: llvm/trunk/test/Transforms/LoadStoreVectorizer/AMDGPU/selects.ll
===================================================================
--- llvm/trunk/test/Transforms/LoadStoreVectorizer/AMDGPU/selects.ll
+++ llvm/trunk/test/Transforms/LoadStoreVectorizer/AMDGPU/selects.ll
@@ -0,0 +1,95 @@
+; RUN: opt -mtriple=amdgcn-amd-amdhsa -load-store-vectorizer -dce -S -o - %s | FileCheck %s
+
+target datalayout = "e-p:32:32-p1:64:64-p2:64:64-p3:32:32-p4:64:64-p5:32:32-p24:64:64-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64"
+
+define void @base_case(i1 %cnd, i32 addrspace(1)* %a, i32 addrspace(1)* %b, <3 x i32> addrspace(1)* %out) {
+; CHECK-LABEL: @base_case
+; CHECK: load <3 x i32>
+entry:
+  %gep1 = getelementptr inbounds i32, i32 addrspace(1)* %a, i64 1
+  %gep2 = getelementptr inbounds i32, i32 addrspace(1)* %a, i64 2
+  %gep4 = getelementptr inbounds i32, i32 addrspace(1)* %b, i64 1
+  %gep5 = getelementptr inbounds i32, i32 addrspace(1)* %b, i64 2
+  %selected = select i1 %cnd, i32 addrspace(1)* %a, i32 addrspace(1)* %b
+  %selected14 = select i1 %cnd, i32 addrspace(1)* %gep1, i32 addrspace(1)* %gep4
+  %selected25 = select i1 %cnd, i32 addrspace(1)* %gep2, i32 addrspace(1)* %gep5
+  %val0 = load i32, i32 addrspace(1)* %selected, align 4
+  %val1 = load i32, i32 addrspace(1)* %selected14, align 4
+  %val2 = load i32, i32 addrspace(1)* %selected25, align 4
+  %t0 = insertelement <3 x i32> undef, i32 %val0, i32 0
+  %t1 = insertelement <3 x i32> %t0, i32 %val1, i32 1
+  %t2 = insertelement <3 x i32> %t1, i32 %val2, i32 2
+  store <3 x i32> %t2, <3 x i32> addrspace(1)* %out
+  ret void
+}
+
+define void @scev_targeting_complex_case(i1 %cnd, i32 addrspace(1)* %a, i32 addrspace(1)* %b, i32 %base, <2 x i32> addrspace(1)* %out) {
+; CHECK-LABEL: @scev_targeting_complex_case
+; CHECK: load <2 x i32>
+entry:
+  %base.x4 = shl i32 %base, 2
+  %base.x4.p1 = add i32 %base.x4, 1
+  %base.x4.p2 = add i32 %base.x4, 2
+  %base.x4.p3 = add i32 %base.x4, 3
+  %zext.x4 = zext i32 %base.x4 to i64
+  %zext.x4.p1 = zext i32 %base.x4.p1 to i64
+  %zext.x4.p2 = zext i32 %base.x4.p2 to i64
+  %zext.x4.p3 = zext i32 %base.x4.p3 to i64
+  %base.x16 = mul i64 %zext.x4, 4
+  %base.x16.p4 = shl i64 %zext.x4.p1, 2
+  %base.x16.p8 = shl i64 %zext.x4.p2, 2
+  %base.x16.p12 = mul i64 %zext.x4.p3, 4
+  %a.pi8 = bitcast i32 addrspace(1)* %a to i8 addrspace(1)*
+  %b.pi8 = bitcast i32 addrspace(1)* %b to i8 addrspace(1)*
+  %gep.a.base.x16 = getelementptr inbounds i8, i8 addrspace(1)* %a.pi8, i64 %base.x16
+  %gep.b.base.x16.p4 = getelementptr inbounds i8, i8 addrspace(1)* %b.pi8, i64 %base.x16.p4
+  %gep.a.base.x16.p8 = getelementptr inbounds i8, i8 addrspace(1)* %a.pi8, i64 %base.x16.p8
+  %gep.b.base.x16.p12 = getelementptr inbounds i8, i8 addrspace(1)* %b.pi8, i64 %base.x16.p12
+  %a.base.x16 = bitcast i8 addrspace(1)* %gep.a.base.x16 to i32 addrspace(1)*
+  %b.base.x16.p4 = bitcast i8 addrspace(1)* %gep.b.base.x16.p4 to i32 addrspace(1)*
+  %selected.base.x16.p0.or.4 = select i1 %cnd, i32 addrspace(1)* %a.base.x16, i32 addrspace(1)* %b.base.x16.p4
+  %gep.selected.base.x16.p8.or.12 = select i1 %cnd, i8 addrspace(1)* %gep.a.base.x16.p8, i8 addrspace(1)* %gep.b.base.x16.p12
+  %selected.base.x16.p8.or.12 = bitcast i8 addrspace(1)* %gep.selected.base.x16.p8.or.12 to i32 addrspace(1)*
+  %selected.base.x16.p40.or.44 = getelementptr inbounds i32, i32 addrspace(1)* %selected.base.x16.p0.or.4, i64 10
+  %selected.base.x16.p44.or.48 = getelementptr inbounds i32, i32 addrspace(1)* %selected.base.x16.p8.or.12, i64 9
+  %val0 = load i32, i32 addrspace(1)* %selected.base.x16.p40.or.44, align 4
+  %val1 = load i32, i32 addrspace(1)* %selected.base.x16.p44.or.48, align 4
+  %t0 = insertelement <2 x i32> undef, i32 %val0, i32 0
+  %t1 = insertelement <2 x i32> %t0, i32 %val1, i32 1
+  store <2 x i32> %t1, <2 x i32> addrspace(1)* %out
+  ret void
+}
+
+define void @nested_selects(i1 %cnd0, i1 %cnd1, i32 addrspace(1)* %a, i32 addrspace(1)* %b, i32 %base, <2 x i32> addrspace(1)* %out) {
+; CHECK-LABEL: @nested_selects
+; CHECK: load <2 x i32>
+entry:
+  %base.p1 = add nsw i32 %base, 1
+  %base.p2 = add i32 %base, 2
+  %base.p3 = add nsw i32 %base, 3
+  %base.x4 = mul i32 %base, 4
+  %base.x4.p5 = add i32 %base.x4, 5
+  %base.x4.p6 = add i32 %base.x4, 6
+  %sext = sext i32 %base to i64
+  %sext.p1 = sext i32 %base.p1 to i64
+  %sext.p2 = sext i32 %base.p2 to i64
+  %sext.p3 = sext i32 %base.p3 to i64
+  %sext.x4.p5 = sext i32 %base.x4.p5 to i64
+  %sext.x4.p6 = sext i32 %base.x4.p6 to i64
+  %gep.a.base = getelementptr inbounds i32, i32 addrspace(1)* %a, i64 %sext
+  %gep.a.base.p1 = getelementptr inbounds i32, i32 addrspace(1)* %a, i64 %sext.p1
+  %gep.a.base.p2 = getelementptr inbounds i32, i32 addrspace(1)* %a, i64 %sext.p2
+  %gep.a.base.p3 = getelementptr inbounds i32, i32 addrspace(1)* %a, i64 %sext.p3
+  %gep.b.base.x4.p5 = getelementptr inbounds i32, i32 addrspace(1)* %a, i64 %sext.x4.p5
+  %gep.b.base.x4.p6 = getelementptr inbounds i32, i32 addrspace(1)* %a, i64 %sext.x4.p6
+  %selected.1.L = select i1 %cnd1, i32 addrspace(1)* %gep.a.base.p2, i32 addrspace(1)* %gep.b.base.x4.p5
+  %selected.1.R = select i1 %cnd1, i32 addrspace(1)* %gep.a.base.p3, i32 addrspace(1)* %gep.b.base.x4.p6
+  %selected.0.L = select i1 %cnd0, i32 addrspace(1)* %gep.a.base, i32 addrspace(1)* %selected.1.L
+  %selected.0.R = select i1 %cnd0, i32 addrspace(1)* %gep.a.base.p1, i32 addrspace(1)* %selected.1.R
+  %val0 = load i32, i32 addrspace(1)* %selected.0.L, align 4
+  %val1 = load i32, i32 addrspace(1)* %selected.0.R, align 4
+  %t0 = insertelement <2 x i32> undef, i32 %val0, i32 0
+  %t1 = insertelement <2 x i32> %t0, i32 %val1, i32 1
+  store <2 x i32> %t1, <2 x i32> addrspace(1)* %out
+  ret void
+}
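
As a usage example of the toy sketch above (same assumed Ptr type and areConsecutive helper; this is illustration, not part of the committed test), the shape of @nested_selects, where the outer selects over %cnd0 choose between a plain GEP and an inner select over %cnd1, still resolves within the MaxDepth = 3 budget:

// Continuation of the toy model; mirrors @nested_selects, not LLVM code.
#include <cassert>

int main() {
  // Leaves mirror the GEPs in @nested_selects: each pair is 4 bytes apart.
  Ptr A0{nullptr, nullptr, -1, 0},   A1{nullptr, nullptr, -1, 4};
  Ptr A2{nullptr, nullptr, -1, 8},   A3{nullptr, nullptr, -1, 12};
  Ptr B5{nullptr, nullptr, -1, 100}, B6{nullptr, nullptr, -1, 104};
  // Inner selects over %cnd1 (condition id 1), outer selects over %cnd0 (0).
  Ptr InnerL{&A2, &B5, 1, 0}, InnerR{&A3, &B6, 1, 0};
  Ptr OuterL{&A0, &InnerL, 0, 0}, OuterR{&A1, &InnerR, 0, 0};
  // On every path the two loaded addresses are one i32 apart, so the pair is
  // consecutive and the two loads can be merged into a single wide load.
  assert(areConsecutive(OuterL, OuterR, 4));
  return 0;
}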