Index: include/llvm/Analysis/TargetTransformInfo.h =================================================================== --- include/llvm/Analysis/TargetTransformInfo.h +++ include/llvm/Analysis/TargetTransformInfo.h @@ -274,8 +274,8 @@ /// AVX2 allows masks for consecutive load and store for i32 and i64 elements. /// AVX-512 architecture will also allow masks for non-consecutive memory /// accesses. - virtual bool isLegalPredicatedStore(Type *DataType, int Consecutive) const; - virtual bool isLegalPredicatedLoad (Type *DataType, int Consecutive) const; + virtual bool isLegalMaskedStore(Type *DataType, int Consecutive) const; + virtual bool isLegalMaskedLoad (Type *DataType, int Consecutive) const; /// \brief Return the cost of the scaling factor used in the addressing /// mode represented by AM for this target, for a load/store Index: lib/Analysis/TargetTransformInfo.cpp =================================================================== --- lib/Analysis/TargetTransformInfo.cpp +++ lib/Analysis/TargetTransformInfo.cpp @@ -101,12 +101,12 @@ return PrevTTI->isLegalICmpImmediate(Imm); } -bool TargetTransformInfo::isLegalPredicatedLoad(Type *DataType, +bool TargetTransformInfo::isLegalMaskedLoad(Type *DataType, int Consecutive) const { return false; } -bool TargetTransformInfo::isLegalPredicatedStore(Type *DataType, +bool TargetTransformInfo::isLegalMaskedStore(Type *DataType, int Consecutive) const { return false; } Index: lib/Target/X86/X86TargetTransformInfo.cpp =================================================================== --- lib/Target/X86/X86TargetTransformInfo.cpp +++ lib/Target/X86/X86TargetTransformInfo.cpp @@ -111,8 +111,8 @@ Type *Ty) const override; unsigned getIntImmCost(Intrinsic::ID IID, unsigned Idx, const APInt &Imm, Type *Ty) const override; - bool isLegalPredicatedLoad (Type *DataType, int Consecutive) const override; - bool isLegalPredicatedStore(Type *DataType, int Consecutive) const override; + bool isLegalMaskedLoad (Type *DataType, int 
Consecutive) const override; + bool isLegalMaskedStore(Type *DataType, int Consecutive) const override; /// @} }; @@ -1159,7 +1159,7 @@ return X86TTI::getIntImmCost(Imm, Ty); } -bool X86TTI::isLegalPredicatedLoad(Type *DataType, int Consecutive) const { +bool X86TTI::isLegalMaskedLoad(Type *DataType, int Consecutive) const { int ScalarWidth = DataType->getScalarSizeInBits(); // Todo: AVX512 allows gather/scatter, works with strided and random as well @@ -1170,7 +1170,7 @@ return false; } -bool X86TTI::isLegalPredicatedStore(Type *DataType, int Consecutive) const { - return isLegalPredicatedLoad(DataType, Consecutive); +bool X86TTI::isLegalMaskedStore(Type *DataType, int Consecutive) const { + return isLegalMaskedLoad(DataType, Consecutive); } Index: lib/Transforms/Vectorize/LoopVectorize.cpp =================================================================== --- lib/Transforms/Vectorize/LoopVectorize.cpp +++ lib/Transforms/Vectorize/LoopVectorize.cpp @@ -580,9 +580,10 @@ LoopVectorizationLegality(Loop *L, ScalarEvolution *SE, const DataLayout *DL, DominatorTree *DT, TargetLibraryInfo *TLI, - AliasAnalysis *AA, Function *F) + AliasAnalysis *AA, Function *F, + const TargetTransformInfo *TTI) : NumLoads(0), NumStores(0), NumPredStores(0), TheLoop(L), SE(SE), DL(DL), - DT(DT), TLI(TLI), AA(AA), TheFunction(F), Induction(nullptr), + DT(DT), TLI(TLI), AA(AA), TheFunction(F), TTI(TTI), Induction(nullptr), WidestIndTy(nullptr), HasFunNoNaNAttr(false), MaxSafeDepDistBytes(-1U) { } @@ -768,6 +769,22 @@ } SmallPtrSet::iterator strides_end() { return StrideSet.end(); } + /// Returns true if the target machine supports masked store operation + /// for the given \p DataType and kind of access to \p Ptr. + bool isLegalMaskedStore(Type *DataType, Value *Ptr) { + return TTI->isLegalMaskedStore(DataType, isConsecutivePtr(Ptr)); + } + /// Returns true if the target machine supports masked load operation + /// for the given \p DataType and kind of access to \p Ptr. 
+ bool isLegalMaskedLoad(Type *DataType, Value *Ptr) { + return TTI->isLegalMaskedLoad(DataType, isConsecutivePtr(Ptr)); + } + /// Returns true if vector representation of the instruction \p I + /// requires mask. + bool toBuildMaskedVectorInst(const Instruction* I) { + return (MaskedOp.count(I) != 0); + } + void SI(); private: /// Check if a single basic block loop is vectorizable. /// At this point we know that this is a loop with a constant trip count @@ -840,6 +857,8 @@ AliasAnalysis *AA; /// Parent function Function *TheFunction; + /// Target Transform Info + const TargetTransformInfo *TTI; // --- vectorization state --- // @@ -871,6 +890,10 @@ ValueToValueMap Strides; SmallPtrSet StrideSet; + + /// While vectorizing these instructions we have to generate a + /// call to the appropriate masked intrinsic + SmallPtrSet MaskedOp; }; /// LoopVectorizationCostModel - estimates the expected speedups due to @@ -1373,7 +1396,7 @@ } // Check if it is legal to vectorize the loop. - LoopVectorizationLegality LVL(L, SE, DL, DT, TLI, AA, F); + LoopVectorizationLegality LVL(L, SE, DL, DT, TLI, AA, F, TTI); if (!LVL.canVectorize()) { DEBUG(dbgs() << "LV: Not vectorizing: Cannot prove legality.\n"); emitMissedWarning(F, L, Hints); @@ -1761,7 +1784,8 @@ unsigned ScalarAllocatedSize = DL->getTypeAllocSize(ScalarDataTy); unsigned VectorElementSize = DL->getTypeStoreSize(DataTy)/VF; - if (SI && Legal->blockNeedsPredication(SI->getParent())) + if (SI && Legal->blockNeedsPredication(SI->getParent()) && + !Legal->toBuildMaskedVectorInst(SI)) return scalarizeInstruction(Instr, true); if (ScalarAllocatedSize != VectorElementSize) @@ -1855,8 +1879,24 @@ Value *VecPtr = Builder.CreateBitCast(PartPtr, DataTy->getPointerTo(AddressSpace)); - StoreInst *NewSI = - Builder.CreateAlignedStore(StoredVal[Part], VecPtr, Alignment); + + Instruction *NewSI; + if (Legal->toBuildMaskedVectorInst(SI)) { + Type *I8PtrTy = + Builder.getInt8PtrTy(PartPtr->getType()->getPointerAddressSpace()); + + Value 
*I8Ptr = Builder.CreateBitCast(PartPtr, I8PtrTy); + + VectorParts Cond = createBlockInMask(SI->getParent()); + SmallVector Ops; + Ops.push_back(I8Ptr); + Ops.push_back(StoredVal[Part]); + Ops.push_back(Builder.getInt32(Alignment)); + Ops.push_back(Cond[Part]); + NewSI = Builder.CreateMaskedStore(Ops); + } + else + NewSI = Builder.CreateAlignedStore(StoredVal[Part], VecPtr, Alignment); propagateMetadata(NewSI, SI); } return; @@ -1871,14 +1911,31 @@ if (Reverse) { // If the address is consecutive but reversed, then the - // wide store needs to start at the last vector element. + // wide load needs to start at the last vector element. PartPtr = Builder.CreateGEP(Ptr, Builder.getInt32(-Part * VF)); PartPtr = Builder.CreateGEP(PartPtr, Builder.getInt32(1 - VF)); } - Value *VecPtr = Builder.CreateBitCast(PartPtr, - DataTy->getPointerTo(AddressSpace)); - LoadInst *NewLI = Builder.CreateAlignedLoad(VecPtr, Alignment, "wide.load"); + Instruction* NewLI; + if (Legal->toBuildMaskedVectorInst(LI)) { + Type *I8PtrTy = + Builder.getInt8PtrTy(PartPtr->getType()->getPointerAddressSpace()); + + Value *I8Ptr = Builder.CreateBitCast(PartPtr, I8PtrTy); + + VectorParts SrcMask = createBlockInMask(LI->getParent()); + SmallVector Ops; + Ops.push_back(I8Ptr); + Ops.push_back(UndefValue::get(DataTy)); + Ops.push_back(Builder.getInt32(Alignment)); + Ops.push_back(SrcMask[Part]); + NewLI = Builder.CreateMaskedLoad(Ops); + } + else { + Value *VecPtr = Builder.CreateBitCast(PartPtr, + DataTy->getPointerTo(AddressSpace)); + NewLI = Builder.CreateAlignedLoad(VecPtr, Alignment, "wide.load"); + } propagateMetadata(NewLI, LI); Entry[Part] = Reverse ? reverseVector(NewLI) : NewLI; } @@ -5311,8 +5368,15 @@ // We might be able to hoist the load. 
if (it->mayReadFromMemory()) { LoadInst *LI = dyn_cast(it); - if (!LI || !SafePtrs.count(LI->getPointerOperand())) + if (!LI) return false; + if (!SafePtrs.count(LI->getPointerOperand())) { + if (isLegalMaskedLoad(LI->getType(), LI->getPointerOperand())) { + MaskedOp.insert(LI); + continue; + } + return false; + } } // We don't predicate stores at the moment. @@ -5320,10 +5384,26 @@ StoreInst *SI = dyn_cast(it); // We only support predication of stores in basic blocks with one // predecessor. - if (!SI || ++NumPredStores > NumberOfStoresToPredicate || - !SafePtrs.count(SI->getPointerOperand()) || - !SI->getParent()->getSinglePredecessor()) + if (!SI) + return false; + + bool isSafePtr = (SafePtrs.count(SI->getPointerOperand()) != 0); + bool isSinglePredecessor = SI->getParent()->getSinglePredecessor(); + + if (++NumPredStores > NumberOfStoresToPredicate || !isSafePtr || + !isSinglePredecessor) { + // Build a masked store if it is legal for the target, otherwise scalarize + // the block. 
+ bool isLegalMaskedOp = + isLegalMaskedStore(SI->getValueOperand()->getType(), + SI->getPointerOperand()); + if (isLegalMaskedOp) { + --NumPredStores; + MaskedOp.insert(SI); + continue; + } return false; + } } if (it->mayThrow()) return false; @@ -5387,7 +5467,7 @@ MaxVectorSize = 1; } - assert(MaxVectorSize <= 32 && "Did not expect to pack so many elements" + assert(MaxVectorSize <= 64 && "Did not expect to pack so many elements" " into one vector!"); unsigned VF = MaxVectorSize; Index: test/Transforms/LoopVectorize/X86/masked_load_store.ll =================================================================== --- test/Transforms/LoopVectorize/X86/masked_load_store.ll +++ test/Transforms/LoopVectorize/X86/masked_load_store.ll @@ -0,0 +1,347 @@ +; RUN: opt < %s -O3 -mcpu=corei7-avx -S | FileCheck %s -check-prefix=AVX1 +; RUN: opt < %s -O3 -mcpu=core-avx2 -S | FileCheck %s -check-prefix=AVX2 +; RUN: opt < %s -O3 -mcpu=knl -S | FileCheck %s -check-prefix=AVX512 + +;AVX1-NOT: llvm.masked + +target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-pc_linux" + +; The source code: +; +;void foo1(int *A, int *B, int *trigger) { +; +; for (int i=0; i<10000; i++) { +; if (trigger[i] < 100) { +; A[i] = B[i] + trigger[i]; +; } +; } +;} + +;AVX2-LABEL: @foo1 +;AVX2: icmp slt <8 x i32> %wide.load, @llvm.masked.load.v8i32 +;AVX2: add nsw <8 x i32> +;AVX2: call void @llvm.masked.store.v8i32 +;AVX2: ret void + +;AVX512-LABEL: @foo1 +;AVX512: icmp slt <16 x i32> %wide.load, @llvm.masked.load.v16i32 +;AVX512: add nsw <16 x i32> +;AVX512: call void @llvm.masked.store.v16i32 +;AVX512: ret void + +; Function Attrs: nounwind uwtable +define void @foo1(i32* %A, i32* %B, i32* %trigger) { +entry: + %A.addr = alloca i32*, align 8 + %B.addr = alloca i32*, align 8 + %trigger.addr = alloca i32*, align 8 + %i = alloca i32, align 4 + store i32* %A, i32** %A.addr, align 8 + store i32* %B, i32** %B.addr, align 8 + store i32* %trigger, i32** %trigger.addr, align 8 + store 
i32 0, i32* %i, align 4 + br label %for.cond + +for.cond: ; preds = %for.inc, %entry + %0 = load i32* %i, align 4 + %cmp = icmp slt i32 %0, 10000 + br i1 %cmp, label %for.body, label %for.end + +for.body: ; preds = %for.cond + %1 = load i32* %i, align 4 + %idxprom = sext i32 %1 to i64 + %2 = load i32** %trigger.addr, align 8 + %arrayidx = getelementptr inbounds i32* %2, i64 %idxprom + %3 = load i32* %arrayidx, align 4 + %cmp1 = icmp slt i32 %3, 100 + br i1 %cmp1, label %if.then, label %if.end + +if.then: ; preds = %for.body + %4 = load i32* %i, align 4 + %idxprom2 = sext i32 %4 to i64 + %5 = load i32** %B.addr, align 8 + %arrayidx3 = getelementptr inbounds i32* %5, i64 %idxprom2 + %6 = load i32* %arrayidx3, align 4 + %7 = load i32* %i, align 4 + %idxprom4 = sext i32 %7 to i64 + %8 = load i32** %trigger.addr, align 8 + %arrayidx5 = getelementptr inbounds i32* %8, i64 %idxprom4 + %9 = load i32* %arrayidx5, align 4 + %add = add nsw i32 %6, %9 + %10 = load i32* %i, align 4 + %idxprom6 = sext i32 %10 to i64 + %11 = load i32** %A.addr, align 8 + %arrayidx7 = getelementptr inbounds i32* %11, i64 %idxprom6 + store i32 %add, i32* %arrayidx7, align 4 + br label %if.end + +if.end: ; preds = %if.then, %for.body + br label %for.inc + +for.inc: ; preds = %if.end + %12 = load i32* %i, align 4 + %inc = add nsw i32 %12, 1 + store i32 %inc, i32* %i, align 4 + br label %for.cond + +for.end: ; preds = %for.cond + ret void +} + +; The source code: +; +;void foo2(float *A, float *B, int *trigger) { +; +; for (int i=0; i<10000; i++) { +; if (trigger[i] < 100) { +; A[i] = B[i] + trigger[i]; +; } +; } +;} + +;AVX2-LABEL: @foo2 +;AVX2: icmp slt <8 x i32> %wide.load, @llvm.masked.load.v8f32 +;AVX2: fadd <8 x float> +;AVX2: call void @llvm.masked.store.v8f32 +;AVX2: ret void + +;AVX512-LABEL: @foo2 +;AVX512: icmp slt <16 x i32> %wide.load, @llvm.masked.load.v16f32 +;AVX512: fadd <16 x float> +;AVX512: call void @llvm.masked.store.v16f32 +;AVX512: ret void + +; Function Attrs: nounwind uwtable 
+define void @foo2(float* %A, float* %B, i32* %trigger) { +entry: + %A.addr = alloca float*, align 8 + %B.addr = alloca float*, align 8 + %trigger.addr = alloca i32*, align 8 + %i = alloca i32, align 4 + store float* %A, float** %A.addr, align 8 + store float* %B, float** %B.addr, align 8 + store i32* %trigger, i32** %trigger.addr, align 8 + store i32 0, i32* %i, align 4 + br label %for.cond + +for.cond: ; preds = %for.inc, %entry + %0 = load i32* %i, align 4 + %cmp = icmp slt i32 %0, 10000 + br i1 %cmp, label %for.body, label %for.end + +for.body: ; preds = %for.cond + %1 = load i32* %i, align 4 + %idxprom = sext i32 %1 to i64 + %2 = load i32** %trigger.addr, align 8 + %arrayidx = getelementptr inbounds i32* %2, i64 %idxprom + %3 = load i32* %arrayidx, align 4 + %cmp1 = icmp slt i32 %3, 100 + br i1 %cmp1, label %if.then, label %if.end + +if.then: ; preds = %for.body + %4 = load i32* %i, align 4 + %idxprom2 = sext i32 %4 to i64 + %5 = load float** %B.addr, align 8 + %arrayidx3 = getelementptr inbounds float* %5, i64 %idxprom2 + %6 = load float* %arrayidx3, align 4 + %7 = load i32* %i, align 4 + %idxprom4 = sext i32 %7 to i64 + %8 = load i32** %trigger.addr, align 8 + %arrayidx5 = getelementptr inbounds i32* %8, i64 %idxprom4 + %9 = load i32* %arrayidx5, align 4 + %conv = sitofp i32 %9 to float + %add = fadd float %6, %conv + %10 = load i32* %i, align 4 + %idxprom6 = sext i32 %10 to i64 + %11 = load float** %A.addr, align 8 + %arrayidx7 = getelementptr inbounds float* %11, i64 %idxprom6 + store float %add, float* %arrayidx7, align 4 + br label %if.end + +if.end: ; preds = %if.then, %for.body + br label %for.inc + +for.inc: ; preds = %if.end + %12 = load i32* %i, align 4 + %inc = add nsw i32 %12, 1 + store i32 %inc, i32* %i, align 4 + br label %for.cond + +for.end: ; preds = %for.cond + ret void +} + +; The source code: +; +;void foo3(double *A, double *B, int *trigger) { +; +; for (int i=0; i<10000; i++) { +; if (trigger[i] < 100) { +; A[i] = B[i] + trigger[i]; +; } 
+; } +;} + +;AVX2-LABEL: @foo3 +;AVX2: icmp slt <4 x i32> %wide.load, @llvm.masked.load.v4f64 +;AVX2: sitofp <4 x i32> %wide.load to <4 x double> +;AVX2: fadd <4 x double> +;AVX2: call void @llvm.masked.store.v4f64 +;AVX2: ret void + +;AVX512-LABEL: @foo3 +;AVX512: icmp slt <8 x i32> %wide.load, @llvm.masked.load.v8f64 +;AVX512: sitofp <8 x i32> %wide.load to <8 x double> +;AVX512: fadd <8 x double> +;AVX512: call void @llvm.masked.store.v8f64 +;AVX512: ret void + + +; Function Attrs: nounwind uwtable +define void @foo3(double* %A, double* %B, i32* %trigger) #0 { +entry: + %A.addr = alloca double*, align 8 + %B.addr = alloca double*, align 8 + %trigger.addr = alloca i32*, align 8 + %i = alloca i32, align 4 + store double* %A, double** %A.addr, align 8 + store double* %B, double** %B.addr, align 8 + store i32* %trigger, i32** %trigger.addr, align 8 + store i32 0, i32* %i, align 4 + br label %for.cond + +for.cond: ; preds = %for.inc, %entry + %0 = load i32* %i, align 4 + %cmp = icmp slt i32 %0, 10000 + br i1 %cmp, label %for.body, label %for.end + +for.body: ; preds = %for.cond + %1 = load i32* %i, align 4 + %idxprom = sext i32 %1 to i64 + %2 = load i32** %trigger.addr, align 8 + %arrayidx = getelementptr inbounds i32* %2, i64 %idxprom + %3 = load i32* %arrayidx, align 4 + %cmp1 = icmp slt i32 %3, 100 + br i1 %cmp1, label %if.then, label %if.end + +if.then: ; preds = %for.body + %4 = load i32* %i, align 4 + %idxprom2 = sext i32 %4 to i64 + %5 = load double** %B.addr, align 8 + %arrayidx3 = getelementptr inbounds double* %5, i64 %idxprom2 + %6 = load double* %arrayidx3, align 8 + %7 = load i32* %i, align 4 + %idxprom4 = sext i32 %7 to i64 + %8 = load i32** %trigger.addr, align 8 + %arrayidx5 = getelementptr inbounds i32* %8, i64 %idxprom4 + %9 = load i32* %arrayidx5, align 4 + %conv = sitofp i32 %9 to double + %add = fadd double %6, %conv + %10 = load i32* %i, align 4 + %idxprom6 = sext i32 %10 to i64 + %11 = load double** %A.addr, align 8 + %arrayidx7 = getelementptr 
inbounds double* %11, i64 %idxprom6 + store double %add, double* %arrayidx7, align 8 + br label %if.end + +if.end: ; preds = %if.then, %for.body + br label %for.inc + +for.inc: ; preds = %if.end + %12 = load i32* %i, align 4 + %inc = add nsw i32 %12, 1 + store i32 %inc, i32* %i, align 4 + br label %for.cond + +for.end: ; preds = %for.cond + ret void +} + +; The source code: +; +;void foo4(double *A, double *B, int *trigger) { +; +; for (int i=0; i<10000; i++) { +; if (trigger[i] < 100) { +; A[i] = B[i*2] + trigger[i]; << non-consecutive access +; } +; } +;} + +;AVX2-LABEL: @foo4 +;AVX2-NOT: llvm.masked +;AVX2: ret void + +;AVX512-LABEL: @foo4 +;AVX512-NOT: llvm.masked +;AVX512: ret void + +; Function Attrs: nounwind uwtable +define void @foo4(double* %A, double* %B, i32* %trigger) { +entry: + %A.addr = alloca double*, align 8 + %B.addr = alloca double*, align 8 + %trigger.addr = alloca i32*, align 8 + %i = alloca i32, align 4 + store double* %A, double** %A.addr, align 8 + store double* %B, double** %B.addr, align 8 + store i32* %trigger, i32** %trigger.addr, align 8 + store i32 0, i32* %i, align 4 + br label %for.cond + +for.cond: ; preds = %for.inc, %entry + %0 = load i32* %i, align 4 + %cmp = icmp slt i32 %0, 10000 + br i1 %cmp, label %for.body, label %for.end + +for.body: ; preds = %for.cond + %1 = load i32* %i, align 4 + %idxprom = sext i32 %1 to i64 + %2 = load i32** %trigger.addr, align 8 + %arrayidx = getelementptr inbounds i32* %2, i64 %idxprom + %3 = load i32* %arrayidx, align 4 + %cmp1 = icmp slt i32 %3, 100 + br i1 %cmp1, label %if.then, label %if.end + +if.then: ; preds = %for.body + %4 = load i32* %i, align 4 + %mul = mul nsw i32 %4, 2 + %idxprom2 = sext i32 %mul to i64 + %5 = load double** %B.addr, align 8 + %arrayidx3 = getelementptr inbounds double* %5, i64 %idxprom2 + %6 = load double* %arrayidx3, align 8 + %7 = load i32* %i, align 4 + %idxprom4 = sext i32 %7 to i64 + %8 = load i32** %trigger.addr, align 8 + %arrayidx5 = getelementptr inbounds i32*
%8, i64 %idxprom4 + %9 = load i32* %arrayidx5, align 4 + %conv = sitofp i32 %9 to double + %add = fadd double %6, %conv + %10 = load i32* %i, align 4 + %idxprom6 = sext i32 %10 to i64 + %11 = load double** %A.addr, align 8 + %arrayidx7 = getelementptr inbounds double* %11, i64 %idxprom6 + store double %add, double* %arrayidx7, align 8 + br label %if.end + +if.end: ; preds = %if.then, %for.body + br label %for.inc + +for.inc: ; preds = %if.end + %12 = load i32* %i, align 4 + %inc = add nsw i32 %12, 1 + store i32 %inc, i32* %i, align 4 + br label %for.cond + +for.end: ; preds = %for.cond + ret void +} + +