Index: lib/Transforms/Vectorize/LoopVectorize.cpp
===================================================================
--- lib/Transforms/Vectorize/LoopVectorize.cpp
+++ lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -580,9 +580,10 @@
   LoopVectorizationLegality(Loop *L, ScalarEvolution *SE, const DataLayout *DL,
                             DominatorTree *DT, TargetLibraryInfo *TLI,
-                            AliasAnalysis *AA, Function *F)
+                            AliasAnalysis *AA, Function *F,
+                            const TargetTransformInfo *TTI)
       : NumLoads(0), NumStores(0), NumPredStores(0), TheLoop(L), SE(SE), DL(DL),
-        DT(DT), TLI(TLI), AA(AA), TheFunction(F), Induction(nullptr),
+        DT(DT), TLI(TLI), AA(AA), TheFunction(F), TTI(TTI), Induction(nullptr),
         WidestIndTy(nullptr), HasFunNoNaNAttr(false), MaxSafeDepDistBytes(-1U) {
   }
@@ -768,6 +769,21 @@
   }
   SmallPtrSet<Value *, 8>::iterator strides_end() { return StrideSet.end(); }

+  /// Returns true if the target machine supports a masked store operation
+  /// for the given \p DataType and kind of access to \p Ptr.
+  bool isLegalMaskedStore(Type *DataType, Value *Ptr) {
+    return TTI->isLegalMaskedStore(DataType, isConsecutivePtr(Ptr));
+  }
+  /// Returns true if the target machine supports a masked load operation
+  /// for the given \p DataType and kind of access to \p Ptr.
+  bool isLegalMaskedLoad(Type *DataType, Value *Ptr) {
+    return TTI->isLegalMaskedLoad(DataType, isConsecutivePtr(Ptr));
+  }
+  /// Returns true if the vector representation of the instruction \p I
+  /// requires a mask.
+  bool isMaskRequired(const Instruction *I) {
+    return (MaskedOp.count(I) != 0);
+  }
 private:
   /// Check if a single basic block loop is vectorizable.
   /// At this point we know that this is a loop with a constant trip count
@@ -840,6 +856,8 @@
   AliasAnalysis *AA;
   /// Parent function
   Function *TheFunction;
+  /// Target Transform Info
+  const TargetTransformInfo *TTI;

   // --- vectorization state --- //
@@ -871,6 +889,10 @@
   ValueToValueMap Strides;
   SmallPtrSet<Value *, 8> StrideSet;
+
+  /// While vectorizing these instructions we have to generate a
+  /// call to the appropriate masked intrinsic.
+  SmallPtrSet<const Instruction *, 8> MaskedOp;
 };

 /// LoopVectorizationCostModel - estimates the expected speedups due to
@@ -1373,7 +1395,7 @@
     }

     // Check if it is legal to vectorize the loop.
-    LoopVectorizationLegality LVL(L, SE, DL, DT, TLI, AA, F);
+    LoopVectorizationLegality LVL(L, SE, DL, DT, TLI, AA, F, TTI);
     if (!LVL.canVectorize()) {
       DEBUG(dbgs() << "LV: Not vectorizing: Cannot prove legality.\n");
       emitMissedWarning(F, L, Hints);
@@ -1761,7 +1783,8 @@
   unsigned ScalarAllocatedSize = DL->getTypeAllocSize(ScalarDataTy);
   unsigned VectorElementSize = DL->getTypeStoreSize(DataTy)/VF;

-  if (SI && Legal->blockNeedsPredication(SI->getParent()))
+  if (SI && Legal->blockNeedsPredication(SI->getParent()) &&
+      !Legal->isMaskRequired(SI))
     return scalarizeInstruction(Instr, true);

   if (ScalarAllocatedSize != VectorElementSize)
@@ -1855,8 +1878,24 @@
       Value *VecPtr = Builder.CreateBitCast(PartPtr,
                                             DataTy->getPointerTo(AddressSpace));
-      StoreInst *NewSI =
-        Builder.CreateAlignedStore(StoredVal[Part], VecPtr, Alignment);
+
+      Instruction *NewSI;
+      if (Legal->isMaskRequired(SI)) {
+        Type *I8PtrTy =
+            Builder.getInt8PtrTy(PartPtr->getType()->getPointerAddressSpace());
+
+        Value *I8Ptr = Builder.CreateBitCast(PartPtr, I8PtrTy);
+
+        VectorParts Cond = createBlockInMask(SI->getParent());
+        SmallVector<Value *, 8> Ops;
+        Ops.push_back(I8Ptr);
+        Ops.push_back(StoredVal[Part]);
+        Ops.push_back(Builder.getInt32(Alignment));
+        Ops.push_back(Cond[Part]);
+        NewSI = Builder.CreateMaskedStore(Ops);
+      }
+      else
+        NewSI = Builder.CreateAlignedStore(StoredVal[Part], VecPtr, Alignment);
       propagateMetadata(NewSI, SI);
     }
     return;
@@ -1876,9 +1915,26 @@
       PartPtr = Builder.CreateGEP(PartPtr, Builder.getInt32(1 - VF));
     }

-    Value *VecPtr = Builder.CreateBitCast(PartPtr,
-                                          DataTy->getPointerTo(AddressSpace));
-    LoadInst *NewLI = Builder.CreateAlignedLoad(VecPtr, Alignment, "wide.load");
+    Instruction *NewLI;
+    if (Legal->isMaskRequired(LI)) {
+      Type *I8PtrTy =
+          Builder.getInt8PtrTy(PartPtr->getType()->getPointerAddressSpace());
+
+      Value *I8Ptr = Builder.CreateBitCast(PartPtr, I8PtrTy);
+
+      VectorParts SrcMask = createBlockInMask(LI->getParent());
+      SmallVector<Value *, 8> Ops;
+      Ops.push_back(I8Ptr);
+      Ops.push_back(UndefValue::get(DataTy));
+      Ops.push_back(Builder.getInt32(Alignment));
+      Ops.push_back(SrcMask[Part]);
+      NewLI = Builder.CreateMaskedLoad(Ops);
+    }
+    else {
+      Value *VecPtr = Builder.CreateBitCast(PartPtr,
+                                            DataTy->getPointerTo(AddressSpace));
+      NewLI = Builder.CreateAlignedLoad(VecPtr, Alignment, "wide.load");
+    }
     propagateMetadata(NewLI, LI);
     Entry[Part] = Reverse ? reverseVector(NewLI) : NewLI;
   }
@@ -5305,12 +5361,27 @@

 bool LoopVectorizationLegality::blockCanBePredicated(BasicBlock *BB,
                                            SmallPtrSetImpl<Value *> &SafePtrs) {
+
   for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) {
+    // Check that we don't have a constant expression that can trap as operand.
+    for (Instruction::op_iterator OI = it->op_begin(), OE = it->op_end();
+         OI != OE; ++OI) {
+      if (Constant *C = dyn_cast<Constant>(*OI))
+        if (C->canTrap())
+          return false;
+    }
     // We might be able to hoist the load.
     if (it->mayReadFromMemory()) {
       LoadInst *LI = dyn_cast<LoadInst>(it);
-      if (!LI || !SafePtrs.count(LI->getPointerOperand()))
+      if (!LI)
+        return false;
+      if (!SafePtrs.count(LI->getPointerOperand())) {
+        if (isLegalMaskedLoad(LI->getType(), LI->getPointerOperand())) {
+          MaskedOp.insert(LI);
+          continue;
+        }
         return false;
+      }
     }

     // We don't predicate stores at the moment.
@@ -5318,22 +5389,30 @@
       StoreInst *SI = dyn_cast<StoreInst>(it);
       // We only support predication of stores in basic blocks with one
       // predecessor.
-      if (!SI || ++NumPredStores > NumberOfStoresToPredicate ||
-          !SafePtrs.count(SI->getPointerOperand()) ||
-          !SI->getParent()->getSinglePredecessor())
+      if (!SI)
+        return false;
+
+      bool isSafePtr = (SafePtrs.count(SI->getPointerOperand()) != 0);
+      bool isSinglePredecessor = SI->getParent()->getSinglePredecessor();
+
+      if (++NumPredStores > NumberOfStoresToPredicate || !isSafePtr ||
+          !isSinglePredecessor) {
+        // Build a masked store if it is legal for the target, otherwise scalarize
+        // the block.
+        bool isLegalMaskedOp =
+            isLegalMaskedStore(SI->getValueOperand()->getType(),
+                               SI->getPointerOperand());
+        if (isLegalMaskedOp) {
+          --NumPredStores;
+          MaskedOp.insert(SI);
+          continue;
+        }
         return false;
+      }
     }
     if (it->mayThrow())
       return false;

-    // Check that we don't have a constant expression that can trap as operand.
-    for (Instruction::op_iterator OI = it->op_begin(), OE = it->op_end();
-         OI != OE; ++OI) {
-      if (Constant *C = dyn_cast<Constant>(*OI))
-        if (C->canTrap())
-          return false;
-    }
-
     // The instructions below can trap.
     switch (it->getOpcode()) {
     default: continue;
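
A note on semantics, for reviewers unfamiliar with the masked intrinsics this
patch emits: a masked store writes only the lanes whose mask bit is set and
leaves the remaining memory locations untouched (and unable to fault); a
masked load reads only the enabled lanes and fills disabled lanes from a
pass-through operand. A minimal scalar model of these semantics follows; the
function names and array-based signatures are hypothetical, purely for
illustration, and not LLVM API:

    // Illustrative model of lanewise masked load/store semantics.
    template <typename T, unsigned N>
    void maskedStoreModel(T (&Mem)[N], const T (&Val)[N], const bool (&Mask)[N]) {
      for (unsigned i = 0; i < N; ++i)
        if (Mask[i])        // disabled lanes neither write nor fault
          Mem[i] = Val[i];
    }

    template <typename T, unsigned N>
    void maskedLoadModel(T (&Dst)[N], const T (&Mem)[N], const bool (&Mask)[N],
                         const T (&PassThru)[N]) {
      for (unsigned i = 0; i < N; ++i)
        Dst[i] = Mask[i] ? Mem[i] : PassThru[i]; // disabled lanes get PassThru
    }

In the changes above, the lane mask comes from createBlockInMask() (the
predicate of the if-block being flattened), and the load's pass-through
operand is UndefValue::get(DataTy).
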
Index: test/Transforms/LoopVectorize/X86/masked_load_store.ll
===================================================================
--- test/Transforms/LoopVectorize/X86/masked_load_store.ll
+++ test/Transforms/LoopVectorize/X86/masked_load_store.ll
@@ -0,0 +1,420 @@
+; RUN: opt < %s -O3 -mcpu=corei7-avx -S | FileCheck %s -check-prefix=AVX1
+; RUN: opt < %s -O3 -mcpu=core-avx2 -S | FileCheck %s -check-prefix=AVX2
+; RUN: opt < %s -O3 -mcpu=knl -S | FileCheck %s -check-prefix=AVX512
+
+;AVX1-NOT: llvm.masked
+
+target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-pc_linux"
+
+; The source code:
+;
+;void foo1(int *A, int *B, int *trigger) {
+;
+;  for (int i=0; i<10000; i++) {
+;    if (trigger[i] < 100) {
+;      A[i] = B[i] + trigger[i];
+;    }
+;  }
+;}
+
+;AVX2-LABEL: @foo1
+;AVX2: icmp slt <8 x i32> %wide.load, <i32 100, i32 100, i32 100, i32 100, i32 100, i32 100, i32 100, i32 100>
+;AVX2: call <8 x i32> @llvm.masked.load.v8i32
+;AVX2: add nsw <8 x i32>
+;AVX2: call void @llvm.masked.store.v8i32
+;AVX2: ret void
+
+;AVX512-LABEL: @foo1
+;AVX512: icmp slt <16 x i32> %wide.load, <i32 100, i32 100, i32 100, i32 100, i32 100, i32 100, i32 100, i32 100, i32 100, i32 100, i32 100, i32 100, i32 100, i32 100, i32 100, i32 100>
+;AVX512: call <16 x i32> @llvm.masked.load.v16i32
+;AVX512: add nsw <16 x i32>
+;AVX512: call void @llvm.masked.store.v16i32
+;AVX512: ret void
+
+; Function Attrs: nounwind uwtable
+define void @foo1(i32* %A, i32* %B, i32* %trigger) {
+entry:
+  %A.addr = alloca i32*, align 8
+  %B.addr = alloca i32*, align 8
+  %trigger.addr = alloca i32*, align 8
+  %i = alloca i32, align 4
+  store i32* %A, i32** %A.addr, align 8
+  store i32* %B, i32** %B.addr, align 8
+  store i32* %trigger, i32** %trigger.addr, align 8
+  store i32 0, i32* %i, align 4
+  br label %for.cond
+
+for.cond:                                         ; preds = %for.inc, %entry
+  %0 = load i32* %i, align 4
+  %cmp = icmp slt i32 %0, 10000
+  br i1 %cmp, label %for.body, label %for.end
+
+for.body:                                         ; preds = %for.cond
+  %1 = load i32* %i, align 4
+  %idxprom = sext i32 %1 to i64
+  %2 = load i32** %trigger.addr, align 8
+  %arrayidx = getelementptr inbounds i32* %2, i64 %idxprom
+  %3 = load i32* %arrayidx, align 4
+  %cmp1 = icmp slt i32 %3, 100
+  br i1 %cmp1, label %if.then, label %if.end
+
+if.then:                                          ; preds = %for.body
+  %4 = load i32* %i, align 4
+  %idxprom2 = sext i32 %4 to i64
+  %5 = load i32** %B.addr, align 8
+  %arrayidx3 = getelementptr inbounds i32* %5, i64 %idxprom2
+  %6 = load i32* %arrayidx3, align 4
+  %7 = load i32* %i, align 4
+  %idxprom4 = sext i32 %7 to i64
+  %8 = load i32** %trigger.addr, align 8
+  %arrayidx5 = getelementptr inbounds i32* %8, i64 %idxprom4
+  %9 = load i32* %arrayidx5, align 4
+  %add = add nsw i32 %6, %9
+  %10 = load i32* %i, align 4
+  %idxprom6 = sext i32 %10 to i64
+  %11 = load i32** %A.addr, align 8
+  %arrayidx7 = getelementptr inbounds i32* %11, i64 %idxprom6
+  store i32 %add, i32* %arrayidx7, align 4
+  br label %if.end
+
+if.end:                                           ; preds = %if.then, %for.body
+  br label %for.inc
+
+for.inc:                                          ; preds = %if.end
+  %12 = load i32* %i, align 4
+  %inc = add nsw i32 %12, 1
+  store i32 %inc, i32* %i, align 4
+  br label %for.cond
+
+for.end:                                          ; preds = %for.cond
+  ret void
+}
+
+; The source code:
+;
+;void foo2(float *A, float *B, int *trigger) {
+;
+;  for (int i=0; i<10000; i++) {
+;    if (trigger[i] < 100) {
+;      A[i] = B[i] + trigger[i];
+;    }
+;  }
+;}
+
+;AVX2-LABEL: @foo2
+;AVX2: icmp slt <8 x i32> %wide.load, <i32 100, i32 100, i32 100, i32 100, i32 100, i32 100, i32 100, i32 100>
+;AVX2: call <8 x float> @llvm.masked.load.v8f32
+;AVX2: fadd <8 x float>
+;AVX2: call void @llvm.masked.store.v8f32
+;AVX2: ret void
+
+;AVX512-LABEL: @foo2
+;AVX512: icmp slt <16 x i32> %wide.load, <i32 100, i32 100, i32 100, i32 100, i32 100, i32 100, i32 100, i32 100, i32 100, i32 100, i32 100, i32 100, i32 100, i32 100, i32 100, i32 100>
+;AVX512: call <16 x float> @llvm.masked.load.v16f32
+;AVX512: fadd <16 x float>
+;AVX512: call void @llvm.masked.store.v16f32
+;AVX512: ret void
+
+; Function Attrs: nounwind uwtable
+define void @foo2(float* %A, float* %B, i32* %trigger) {
+entry:
+  %A.addr = alloca float*, align 8
+  %B.addr = alloca float*, align 8
+  %trigger.addr = alloca i32*, align 8
+  %i = alloca i32, align 4
+  store float* %A, float** %A.addr, align 8
+  store float* %B, float** %B.addr, align 8
+  store i32* %trigger, i32** %trigger.addr, align 8
+  store i32 0, i32* %i, align 4
+  br label %for.cond
+
+for.cond:                                         ; preds = %for.inc, %entry
+  %0 = load i32* %i, align 4
+  %cmp = icmp slt i32 %0, 10000
+  br i1 %cmp, label %for.body, label %for.end
+
+for.body:                                         ; preds = %for.cond
+  %1 = load i32* %i, align 4
+  %idxprom = sext i32 %1 to i64
+  %2 = load i32** %trigger.addr, align 8
+  %arrayidx = getelementptr inbounds i32* %2, i64 %idxprom
+  %3 = load i32* %arrayidx, align 4
+  %cmp1 = icmp slt i32 %3, 100
+  br i1 %cmp1, label %if.then, label %if.end
+
+if.then:                                          ; preds = %for.body
+  %4 = load i32* %i, align 4
+  %idxprom2 = sext i32 %4 to i64
+  %5 = load float** %B.addr, align 8
+  %arrayidx3 = getelementptr inbounds float* %5, i64 %idxprom2
+  %6 = load float* %arrayidx3, align 4
+  %7 = load i32* %i, align 4
+  %idxprom4 = sext i32 %7 to i64
+  %8 = load i32** %trigger.addr, align 8
+  %arrayidx5 = getelementptr inbounds i32* %8, i64 %idxprom4
+  %9 = load i32* %arrayidx5, align 4
+  %conv = sitofp i32 %9 to float
+  %add = fadd float %6, %conv
+  %10 = load i32* %i, align 4
+  %idxprom6 = sext i32 %10 to i64
+  %11 = load float** %A.addr, align 8
+  %arrayidx7 = getelementptr inbounds float* %11, i64 %idxprom6
+  store float %add, float* %arrayidx7, align 4
+  br label %if.end
+
+if.end:                                           ; preds = %if.then, %for.body
+  br label %for.inc
+
+for.inc:                                          ; preds = %if.end
+  %12 = load i32* %i, align 4
+  %inc = add nsw i32 %12, 1
+  store i32 %inc, i32* %i, align 4
+  br label %for.cond
+
+for.end:                                          ; preds = %for.cond
+  ret void
+}
+
+; The source code:
+;
+;void foo3(double *A, double *B, int *trigger) {
+;
+;  for (int i=0; i<10000; i++) {
+;    if (trigger[i] < 100) {
+;      A[i] = B[i] + trigger[i];
+;    }
+;  }
+;}
+
+;AVX2-LABEL: @foo3
+;AVX2: icmp slt <4 x i32> %wide.load, <i32 100, i32 100, i32 100, i32 100>
+;AVX2: call <4 x double> @llvm.masked.load.v4f64
+;AVX2: sitofp <4 x i32> %wide.load to <4 x double>
+;AVX2: fadd <4 x double>
+;AVX2: call void @llvm.masked.store.v4f64
+;AVX2: ret void
+
+;AVX512-LABEL: @foo3
+;AVX512: icmp slt <8 x i32> %wide.load, <i32 100, i32 100, i32 100, i32 100, i32 100, i32 100, i32 100, i32 100>
+;AVX512: call <8 x double> @llvm.masked.load.v8f64
+;AVX512: sitofp <8 x i32> %wide.load to <8 x double>
+;AVX512: fadd <8 x double>
+;AVX512: call void @llvm.masked.store.v8f64
+;AVX512: ret void
+
+; Function Attrs: nounwind uwtable
+define void @foo3(double* %A, double* %B, i32* %trigger) {
+entry:
+  %A.addr = alloca double*, align 8
+  %B.addr = alloca double*, align 8
+  %trigger.addr = alloca i32*, align 8
+  %i = alloca i32, align 4
+  store double* %A, double** %A.addr, align 8
+  store double* %B, double** %B.addr, align 8
+  store i32* %trigger, i32** %trigger.addr, align 8
+  store i32 0, i32* %i, align 4
+  br label %for.cond
+
+for.cond:                                         ; preds = %for.inc, %entry
+  %0 = load i32* %i, align 4
+  %cmp = icmp slt i32 %0, 10000
+  br i1 %cmp, label %for.body, label %for.end
+
+for.body:                                         ; preds = %for.cond
+  %1 = load i32* %i, align 4
+  %idxprom = sext i32 %1 to i64
+  %2 = load i32** %trigger.addr, align 8
+  %arrayidx = getelementptr inbounds i32* %2, i64 %idxprom
+  %3 = load i32* %arrayidx, align 4
+  %cmp1 = icmp slt i32 %3, 100
+  br i1 %cmp1, label %if.then, label %if.end
+
+if.then:                                          ; preds = %for.body
+  %4 = load i32* %i, align 4
+  %idxprom2 = sext i32 %4 to i64
+  %5 = load double** %B.addr, align 8
+  %arrayidx3 = getelementptr inbounds double* %5, i64 %idxprom2
+  %6 = load double* %arrayidx3, align 8
+  %7 = load i32* %i, align 4
+  %idxprom4 = sext i32 %7 to i64
+  %8 = load i32** %trigger.addr, align 8
+  %arrayidx5 = getelementptr inbounds i32* %8, i64 %idxprom4
+  %9 = load i32* %arrayidx5, align 4
+  %conv = sitofp i32 %9 to double
+  %add = fadd double %6, %conv
+  %10 = load i32* %i, align 4
+  %idxprom6 = sext i32 %10 to i64
+  %11 = load double** %A.addr, align 8
+  %arrayidx7 = getelementptr inbounds double* %11, i64 %idxprom6
+  store double %add, double* %arrayidx7, align 8
+  br label %if.end
+
+if.end:                                           ; preds = %if.then, %for.body
+  br label %for.inc
+
+for.inc:                                          ; preds = %if.end
+  %12 = load i32* %i, align 4
+  %inc = add nsw i32 %12, 1
+  store i32 %inc, i32* %i, align 4
+  br label %for.cond
+
+for.end:                                          ; preds = %for.cond
+  ret void
+}
+
+; The source code:
+;
+;void foo4(double *A, double *B, int *trigger) {
+;
+;  for (int i=0; i<10000; i++) {
+;    if (trigger[i] < 100) {
+;      A[i] = B[i*2] + trigger[i]; << non-consecutive access
+;    }
+;  }
+;}
+
+;AVX2-LABEL: @foo4
+;AVX2-NOT: llvm.masked
+;AVX2: ret void
+
+;AVX512-LABEL: @foo4
+;AVX512-NOT: llvm.masked
+;AVX512: ret void
+
+; Function Attrs: nounwind uwtable
+define void @foo4(double* %A, double* %B, i32* %trigger) {
+entry:
+  %A.addr = alloca double*, align 8
+  %B.addr = alloca double*, align 8
+  %trigger.addr = alloca i32*, align 8
+  %i = alloca i32, align 4
+  store double* %A, double** %A.addr, align 8
+  store double* %B, double** %B.addr, align 8
+  store i32* %trigger, i32** %trigger.addr, align 8
+  store i32 0, i32* %i, align 4
+  br label %for.cond
+
+for.cond:                                         ; preds = %for.inc, %entry
+  %0 = load i32* %i, align 4
+  %cmp = icmp slt i32 %0, 10000
+  br i1 %cmp, label %for.body, label %for.end
+
+for.body:                                         ; preds = %for.cond
+  %1 = load i32* %i, align 4
+  %idxprom = sext i32 %1 to i64
+  %2 = load i32** %trigger.addr, align 8
+  %arrayidx = getelementptr inbounds i32* %2, i64 %idxprom
+  %3 = load i32* %arrayidx, align 4
+  %cmp1 = icmp slt i32 %3, 100
+  br i1 %cmp1, label %if.then, label %if.end
+
+if.then:                                          ; preds = %for.body
+  %4 = load i32* %i, align 4
+  %mul = mul nsw i32 %4, 2
+  %idxprom2 = sext i32 %mul to i64
+  %5 = load double** %B.addr, align 8
+  %arrayidx3 = getelementptr inbounds double* %5, i64 %idxprom2
+  %6 = load double* %arrayidx3, align 8
+  %7 = load i32* %i, align 4
+  %idxprom4 = sext i32 %7 to i64
+  %8 = load i32** %trigger.addr, align 8
+  %arrayidx5 = getelementptr inbounds i32* %8, i64 %idxprom4
+  %9 = load i32* %arrayidx5, align 4
+  %conv = sitofp i32 %9 to double
+  %add = fadd double %6, %conv
+  %10 = load i32* %i, align 4
+  %idxprom6 = sext i32 %10 to i64
+  %11 = load double** %A.addr, align 8
+  %arrayidx7 = getelementptr inbounds double* %11, i64 %idxprom6
+  store double %add, double* %arrayidx7, align 8
+  br label %if.end
+
+if.end:                                           ; preds = %if.then, %for.body
+  br label %for.inc
+
+for.inc:                                          ; preds = %if.end
+  %12 = load i32* %i, align 4
+  %inc = add nsw i32 %12, 1
+  store i32 %inc, i32* %i, align 4
+  br label %for.cond
+
+for.end:                                          ; preds = %for.cond
+  ret void
+}
+
+@a = common global [1 x i32*] zeroinitializer, align 8
+@c = common global i32* null, align 8
+
+; The loop here should not be vectorized due to a trapping
+; constant expression.
+;AVX2-LABEL: @foo5
+;AVX2-NOT: llvm.masked
+;AVX2: store i32 sdiv
+;AVX2: ret void
+
+;AVX512-LABEL: @foo5
+;AVX512-NOT: llvm.masked
+;AVX512: store i32 sdiv
+;AVX512: ret void
+
+; Function Attrs: nounwind uwtable
+define void @foo5(i32* %A, i32* %B, i32* %trigger) {
+entry:
+  %A.addr = alloca i32*, align 8
+  %B.addr = alloca i32*, align 8
+  %trigger.addr = alloca i32*, align 8
+  %i = alloca i32, align 4
+  store i32* %A, i32** %A.addr, align 8
+  store i32* %B, i32** %B.addr, align 8
+  store i32* %trigger, i32** %trigger.addr, align 8
+  store i32 0, i32* %i, align 4
+  br label %for.cond
+
+for.cond:                                         ; preds = %for.inc, %entry
+  %0 = load i32* %i, align 4
+  %cmp = icmp slt i32 %0, 10000
+  br i1 %cmp, label %for.body, label %for.end
+
+for.body:                                         ; preds = %for.cond
+  %1 = load i32* %i, align 4
+  %idxprom = sext i32 %1 to i64
+  %2 = load i32** %trigger.addr, align 8
+  %arrayidx = getelementptr inbounds i32* %2, i64 %idxprom
+  %3 = load i32* %arrayidx, align 4
+  %cmp1 = icmp slt i32 %3, 100
+  br i1 %cmp1, label %if.then, label %if.end
+
+if.then:                                          ; preds = %for.body
+  %4 = load i32* %i, align 4
+  %idxprom2 = sext i32 %4 to i64
+  %5 = load i32** %B.addr, align 8
+  %arrayidx3 = getelementptr inbounds i32* %5, i64 %idxprom2
+  %6 = load i32* %arrayidx3, align 4
+  %7 = load i32* %i, align 4
+  %idxprom4 = sext i32 %7 to i64
+  %8 = load i32** %trigger.addr, align 8
+  %arrayidx5 = getelementptr inbounds i32* %8, i64 %idxprom4
+  %9 = load i32* %arrayidx5, align 4
+  %add = add nsw i32 %6, %9
+  %10 = load i32* %i, align 4
+  %idxprom6 = sext i32 %10 to i64
+  %11 = load i32** %A.addr, align 8
+  %arrayidx7 = getelementptr inbounds i32* %11, i64 %idxprom6
+  store i32 sdiv (i32 1, i32 zext (i1 icmp eq (i32** getelementptr inbounds ([1 x i32*]* @a, i64 0, i64 1), i32** @c) to i32)), i32* %arrayidx7, align 4
+  br label %if.end
+
+if.end:                                           ; preds = %if.then, %for.body
+  br label %for.inc
+
+for.inc:                                          ; preds = %if.end
+  %12 = load i32* %i, align 4
+  %inc = add nsw i32 %12, 1
+  store i32 %inc, i32* %i, align 4
+  br label %for.cond
+
+for.end:                                          ; preds = %for.cond
+  ret void
+}
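
The legality side of this patch only queries the target: a target opts in by
implementing the TargetTransformInfo hooks isLegalMaskedLoad and
isLegalMaskedStore (the X86 implementation is a separate change). The sketch
below shows the intended shape of such a hook under stated assumptions; the
class name, subtarget predicate, and element-width restriction are
illustrative, not the actual X86 code:

    // Hypothetical target-side sketch: allow masked ops only for unit-stride
    // accesses (isConsecutivePtr() returns +1 or -1, and 0 otherwise) on
    // AVX2-capable subtargets with 32- or 64-bit elements.
    bool X86TTI::isLegalMaskedLoad(Type *DataType, int Consecutive) const {
      if (Consecutive == 0) // non-consecutive access would need gather/scatter
        return false;
      unsigned EltBits = DataType->getScalarSizeInBits();
      return ST->hasAVX2() && (EltBits == 32 || EltBits == 64);
    }
    bool X86TTI::isLegalMaskedStore(Type *DataType, int Consecutive) const {
      return isLegalMaskedLoad(DataType, Consecutive);
    }

With hooks along these lines, foo1 through foo3 vectorize into masked
intrinsics under AVX2 and AVX512, foo4 is rejected because the B[i*2] access
is not consecutive, and foo5 is rejected by the canTrap() check that this
patch moves to the top of blockCanBePredicated.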