Index: llvm/trunk/lib/IR/IRBuilder.cpp =================================================================== --- llvm/trunk/lib/IR/IRBuilder.cpp +++ llvm/trunk/lib/IR/IRBuilder.cpp @@ -356,6 +356,7 @@ PointerType *PtrTy = cast(Ptr->getType()); Type *DataTy = PtrTy->getElementType(); assert(DataTy->isVectorTy() && "Ptr should point to a vector"); + assert(Mask && "Mask should not be all-ones (null)"); if (!PassThru) PassThru = UndefValue::get(DataTy); Type *OverloadedTypes[] = { DataTy, PtrTy }; @@ -375,6 +376,7 @@ PointerType *PtrTy = cast(Ptr->getType()); Type *DataTy = PtrTy->getElementType(); assert(DataTy->isVectorTy() && "Ptr should point to a vector"); + assert(Mask && "Mask should not be all-ones (null)"); Type *OverloadedTypes[] = { DataTy, PtrTy }; Value *Ops[] = { Val, Ptr, getInt32(Align), Mask }; return CreateMaskedIntrinsic(Intrinsic::masked_store, Ops, OverloadedTypes); Index: llvm/trunk/lib/Transforms/Vectorize/LoopVectorize.cpp =================================================================== --- llvm/trunk/lib/Transforms/Vectorize/LoopVectorize.cpp +++ llvm/trunk/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -3045,13 +3045,14 @@ Builder.CreateGEP(nullptr, Ptr, Builder.getInt32(-Part * VF)); PartPtr = Builder.CreateGEP(nullptr, PartPtr, Builder.getInt32(1 - VF)); - Mask[Part] = reverseVector(Mask[Part]); + if (Mask[Part]) // The reverse of a null all-one mask is a null mask. + Mask[Part] = reverseVector(Mask[Part]); } Value *VecPtr = Builder.CreateBitCast(PartPtr, DataTy->getPointerTo(AddressSpace)); - if (Legal->isMaskRequired(SI)) + if (Legal->isMaskRequired(SI) && Mask[Part]) NewSI = Builder.CreateMaskedStore(StoredVal, VecPtr, Alignment, Mask[Part]); else @@ -3083,12 +3084,13 @@ // wide load needs to start at the last vector element. PartPtr = Builder.CreateGEP(nullptr, Ptr, Builder.getInt32(-Part * VF)); PartPtr = Builder.CreateGEP(nullptr, PartPtr, Builder.getInt32(1 - VF)); - Mask[Part] = reverseVector(Mask[Part]); + if (Mask[Part]) // The reverse of a null all-one mask is a null mask. + Mask[Part] = reverseVector(Mask[Part]); } Value *VecPtr = Builder.CreateBitCast(PartPtr, DataTy->getPointerTo(AddressSpace)); - if (Legal->isMaskRequired(LI)) + if (Legal->isMaskRequired(LI) && Mask[Part]) NewLI = Builder.CreateMaskedLoad(VecPtr, Alignment, Mask[Part], UndefValue::get(DataTy), "wide.masked.load"); @@ -3136,10 +3138,10 @@ Value *Cmp = nullptr; if (IfPredicateInstr) { Cmp = Cond[Part]; - if (Cmp->getType()->isVectorTy()) + if (!Cmp) // Block in mask is all-one. + Cmp = Builder.getTrue(); + else if (Cmp->getType()->isVectorTy()) Cmp = Builder.CreateExtractElement(Cmp, Builder.getInt32(Lane)); - Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Cmp, - ConstantInt::get(Cmp->getType(), 1)); } Instruction *Cloned = Instr->clone(); @@ -4518,24 +4520,22 @@ BranchInst *BI = dyn_cast(Src->getTerminator()); assert(BI && "Unexpected terminator found"); - if (BI->isConditional()) { + if (!BI->isConditional()) + return EdgeMaskCache[Edge] = SrcMask; - VectorParts EdgeMask(UF); - for (unsigned Part = 0; Part < UF; ++Part) { - auto *EdgeMaskPart = getOrCreateVectorValue(BI->getCondition(), Part); - if (BI->getSuccessor(0) != Dst) - EdgeMaskPart = Builder.CreateNot(EdgeMaskPart); + VectorParts EdgeMask(UF); + for (unsigned Part = 0; Part < UF; ++Part) { + auto *EdgeMaskPart = getOrCreateVectorValue(BI->getCondition(), Part); + if (BI->getSuccessor(0) != Dst) + EdgeMaskPart = Builder.CreateNot(EdgeMaskPart); + if (SrcMask[Part]) // Otherwise block in-mask is all-one, no need to AND. EdgeMaskPart = Builder.CreateAnd(EdgeMaskPart, SrcMask[Part]); - EdgeMask[Part] = EdgeMaskPart; - } - EdgeMaskCache[Edge] = EdgeMask; - return EdgeMask; + EdgeMask[Part] = EdgeMaskPart; } - EdgeMaskCache[Edge] = SrcMask; - return SrcMask; + return EdgeMaskCache[Edge] = EdgeMask; } InnerLoopVectorizer::VectorParts @@ -4547,31 +4547,32 @@ if (BCEntryIt != BlockMaskCache.end()) return BCEntryIt->second; + // All-one mask is modelled as no-mask following the convention for masked + // load/store/gather/scatter. Initialize BlockMask to no-mask. VectorParts BlockMask(UF); + for (unsigned Part = 0; Part < UF; ++Part) + BlockMask[Part] = nullptr; // Loop incoming mask is all-one. - if (OrigLoop->getHeader() == BB) { - Value *C = ConstantInt::get(IntegerType::getInt1Ty(BB->getContext()), 1); - for (unsigned Part = 0; Part < UF; ++Part) - BlockMask[Part] = getOrCreateVectorValue(C, Part); - BlockMaskCache[BB] = BlockMask; - return BlockMask; - } + if (OrigLoop->getHeader() == BB) + return BlockMaskCache[BB] = BlockMask; - // This is the block mask. We OR all incoming edges, and with zero. - Value *Zero = ConstantInt::get(IntegerType::getInt1Ty(BB->getContext()), 0); - for (unsigned Part = 0; Part < UF; ++Part) - BlockMask[Part] = getOrCreateVectorValue(Zero, Part); + // This is the block mask. We OR all incoming edges. + for (auto *Predecessor : predecessors(BB)) { + VectorParts EdgeMask = createEdgeMask(Predecessor, BB); + if (!EdgeMask[0]) // Mask of predecessor is all-one so mask of block is too. + return BlockMaskCache[BB] = EdgeMask; + + if (!BlockMask[0]) { // BlockMask has its initialized nullptr value. + BlockMask = EdgeMask; + continue; + } - // For each pred: - for (pred_iterator It = pred_begin(BB), E = pred_end(BB); It != E; ++It) { - VectorParts EM = createEdgeMask(*It, BB); for (unsigned Part = 0; Part < UF; ++Part) - BlockMask[Part] = Builder.CreateOr(BlockMask[Part], EM[Part]); + BlockMask[Part] = Builder.CreateOr(BlockMask[Part], EdgeMask[Part]); } - BlockMaskCache[BB] = BlockMask; - return BlockMask; + return BlockMaskCache[BB] = BlockMask; } void InnerLoopVectorizer::widenPHIInstruction(Instruction *PN, unsigned UF, Index: llvm/trunk/test/Transforms/LoopVectorize/if-pred-non-void.ll =================================================================== --- llvm/trunk/test/Transforms/LoopVectorize/if-pred-non-void.ll +++ llvm/trunk/test/Transforms/LoopVectorize/if-pred-non-void.ll @@ -18,8 +18,7 @@ ; CHECK-LABEL: test ; CHECK: vector.body: ; CHECK: %[[SDEE:[a-zA-Z0-9]+]] = extractelement <2 x i1> %{{.*}}, i32 0 -; CHECK: %[[SDCC:[a-zA-Z0-9]+]] = icmp eq i1 %[[SDEE]], true -; CHECK: br i1 %[[SDCC]], label %[[CSD:[a-zA-Z0-9.]+]], label %[[ESD:[a-zA-Z0-9.]+]] +; CHECK: br i1 %[[SDEE]], label %[[CSD:[a-zA-Z0-9.]+]], label %[[ESD:[a-zA-Z0-9.]+]] ; CHECK: [[CSD]]: ; CHECK: %[[SDA0:[a-zA-Z0-9]+]] = extractelement <2 x i32> %{{.*}}, i32 0 ; CHECK: %[[SDA1:[a-zA-Z0-9]+]] = extractelement <2 x i32> %{{.*}}, i32 0 @@ -29,8 +28,7 @@ ; CHECK: [[ESD]]: ; CHECK: %[[SDR:[a-zA-Z0-9]+]] = phi <2 x i32> [ undef, %vector.body ], [ %[[SD1]], %[[CSD]] ] ; CHECK: %[[SDEEH:[a-zA-Z0-9]+]] = extractelement <2 x i1> %{{.*}}, i32 1 -; CHECK: %[[SDCCH:[a-zA-Z0-9]+]] = icmp eq i1 %[[SDEEH]], true -; CHECK: br i1 %[[SDCCH]], label %[[CSDH:[a-zA-Z0-9.]+]], label %[[ESDH:[a-zA-Z0-9.]+]] +; CHECK: br i1 %[[SDEEH]], label %[[CSDH:[a-zA-Z0-9.]+]], label %[[ESDH:[a-zA-Z0-9.]+]] ; CHECK: [[CSDH]]: ; CHECK: %[[SDA0H:[a-zA-Z0-9]+]] = extractelement <2 x i32> %{{.*}}, i32 1 ; CHECK: %[[SDA1H:[a-zA-Z0-9]+]] = extractelement <2 x i32> %{{.*}}, i32 1 @@ -41,8 +39,7 @@ ; CHECK: %{{.*}} = phi <2 x i32> [ %[[SDR]], %[[ESD]] ], [ %[[SD1H]], %[[CSDH]] ] ; CHECK: %[[UDEE:[a-zA-Z0-9]+]] = extractelement <2 x i1> %{{.*}}, i32 0 -; CHECK: %[[UDCC:[a-zA-Z0-9]+]] = icmp eq i1 %[[UDEE]], true -; CHECK: br i1 %[[UDCC]], label %[[CUD:[a-zA-Z0-9.]+]], label %[[EUD:[a-zA-Z0-9.]+]] +; CHECK: br i1 %[[UDEE]], label %[[CUD:[a-zA-Z0-9.]+]], label %[[EUD:[a-zA-Z0-9.]+]] ; CHECK: [[CUD]]: ; CHECK: %[[UDA0:[a-zA-Z0-9]+]] = extractelement <2 x i32> %{{.*}}, i32 0 ; CHECK: %[[UDA1:[a-zA-Z0-9]+]] = extractelement <2 x i32> %{{.*}}, i32 0 @@ -53,8 +50,7 @@ ; CHECK: %{{.*}} = phi <2 x i32> [ undef, %{{.*}} ], [ %[[UD1]], %[[CUD]] ] ; CHECK: %[[SREE:[a-zA-Z0-9]+]] = extractelement <2 x i1> %{{.*}}, i32 0 -; CHECK: %[[SRCC:[a-zA-Z0-9]+]] = icmp eq i1 %[[SREE]], true -; CHECK: br i1 %[[SRCC]], label %[[CSR:[a-zA-Z0-9.]+]], label %[[ESR:[a-zA-Z0-9.]+]] +; CHECK: br i1 %[[SREE]], label %[[CSR:[a-zA-Z0-9.]+]], label %[[ESR:[a-zA-Z0-9.]+]] ; CHECK: [[CSR]]: ; CHECK: %[[SRA0:[a-zA-Z0-9]+]] = extractelement <2 x i32> %{{.*}}, i32 0 ; CHECK: %[[SRA1:[a-zA-Z0-9]+]] = extractelement <2 x i32> %{{.*}}, i32 0 @@ -65,8 +61,7 @@ ; CHECK: %{{.*}} = phi <2 x i32> [ undef, %{{.*}} ], [ %[[SR1]], %[[CSR]] ] ; CHECK: %[[UREE:[a-zA-Z0-9]+]] = extractelement <2 x i1> %{{.*}}, i32 0 -; CHECK: %[[URCC:[a-zA-Z0-9]+]] = icmp eq i1 %[[UREE]], true -; CHECK: br i1 %[[URCC]], label %[[CUR:[a-zA-Z0-9.]+]], label %[[EUR:[a-zA-Z0-9.]+]] +; CHECK: br i1 %[[UREE]], label %[[CUR:[a-zA-Z0-9.]+]], label %[[EUR:[a-zA-Z0-9.]+]] ; CHECK: [[CUR]]: ; CHECK: %[[URA0:[a-zA-Z0-9]+]] = extractelement <2 x i32> %{{.*}}, i32 0 ; CHECK: %[[URA1:[a-zA-Z0-9]+]] = extractelement <2 x i32> %{{.*}}, i32 0 @@ -164,16 +159,11 @@ ; CHECK: vector.body: ; CHECK: %[[CMP1:.+]] = icmp slt <2 x i32> %[[VAL:.+]], ; CHECK: %[[CMP2:.+]] = icmp sge <2 x i32> %[[VAL]], -; CHECK: %[[XOR:.+]] = xor <2 x i1> %[[CMP1]], -; CHECK: %[[AND1:.+]] = and <2 x i1> %[[XOR]], -; CHECK: %[[OR1:.+]] = or <2 x i1> zeroinitializer, %[[AND1]] -; CHECK: %[[AND2:.+]] = and <2 x i1> %[[CMP2]], %[[OR1]] -; CHECK: %[[OR2:.+]] = or <2 x i1> zeroinitializer, %[[AND2]] -; CHECK: %[[AND3:.+]] = and <2 x i1> %[[CMP1]], -; CHECK: %[[OR3:.+]] = or <2 x i1> %[[OR2]], %[[AND3]] -; CHECK: %[[EXTRACT:.+]] = extractelement <2 x i1> %[[OR3]], i32 0 -; CHECK: %[[MASK:.+]] = icmp eq i1 %[[EXTRACT]], true -; CHECK: br i1 %[[MASK]], label %[[THEN:[a-zA-Z0-9.]+]], label %[[FI:[a-zA-Z0-9.]+]] +; CHECK: %[[NOT:.+]] = xor <2 x i1> %[[CMP1]], +; CHECK: %[[AND:.+]] = and <2 x i1> %[[CMP2]], %[[NOT]] +; CHECK: %[[OR:.+]] = or <2 x i1> %[[AND]], %[[CMP1]] +; CHECK: %[[EXTRACT:.+]] = extractelement <2 x i1> %[[OR]], i32 0 +; CHECK: br i1 %[[EXTRACT]], label %[[THEN:[a-zA-Z0-9.]+]], label %[[FI:[a-zA-Z0-9.]+]] ; CHECK: [[THEN]]: ; CHECK: %[[PD:[a-zA-Z0-9]+]] = sdiv i32 %{{.*}}, %{{.*}} ; CHECK: br label %[[FI]] Index: llvm/trunk/test/Transforms/LoopVectorize/if-pred-stores.ll =================================================================== --- llvm/trunk/test/Transforms/LoopVectorize/if-pred-stores.ll +++ llvm/trunk/test/Transforms/LoopVectorize/if-pred-stores.ll @@ -13,11 +13,8 @@ ; VEC: %[[v0:.+]] = add i64 %index, 0 ; VEC: %[[v2:.+]] = getelementptr inbounds i32, i32* %f, i64 %[[v0]] ; VEC: %[[v8:.+]] = icmp sgt <2 x i32> %{{.*}}, -; VEC: %[[v10:.+]] = and <2 x i1> %[[v8]], -; VEC: %[[o1:.+]] = or <2 x i1> zeroinitializer, %[[v10]] -; VEC: %[[v11:.+]] = extractelement <2 x i1> %[[o1]], i32 0 -; VEC: %[[v12:.+]] = icmp eq i1 %[[v11]], true -; VEC: br i1 %[[v12]], label %[[cond:.+]], label %[[else:.+]] +; VEC: %[[v11:.+]] = extractelement <2 x i1> %[[v8]], i32 0 +; VEC: br i1 %[[v11]], label %[[cond:.+]], label %[[else:.+]] ; ; VEC: [[cond]]: ; VEC: %[[v13:.+]] = extractelement <2 x i32> %wide.load, i32 0 @@ -26,9 +23,8 @@ ; VEC: br label %[[else:.+]] ; ; VEC: [[else]]: -; VEC: %[[v15:.+]] = extractelement <2 x i1> %[[o1]], i32 1 -; VEC: %[[v16:.+]] = icmp eq i1 %[[v15]], true -; VEC: br i1 %[[v16]], label %[[cond2:.+]], label %[[else2:.+]] +; VEC: %[[v15:.+]] = extractelement <2 x i1> %[[v8]], i32 1 +; VEC: br i1 %[[v15]], label %[[cond2:.+]], label %[[else2:.+]] ; ; VEC: [[cond2]]: ; VEC: %[[v17:.+]] = extractelement <2 x i32> %wide.load, i32 1 @@ -50,10 +46,7 @@ ; UNROLL: %[[v3:[a-zA-Z0-9]+]] = load i32, i32* %[[v1]], align 4 ; UNROLL: %[[v4:[a-zA-Z0-9]+]] = icmp sgt i32 %[[v2]], 100 ; UNROLL: %[[v5:[a-zA-Z0-9]+]] = icmp sgt i32 %[[v3]], 100 -; UNROLL: %[[o1:[a-zA-Z0-9]+]] = or i1 false, %[[v4]] -; UNROLL: %[[o2:[a-zA-Z0-9]+]] = or i1 false, %[[v5]] -; UNROLL: %[[v8:[a-zA-Z0-9]+]] = icmp eq i1 %[[o1]], true -; UNROLL: br i1 %[[v8]], label %[[cond:[a-zA-Z0-9.]+]], label %[[else:[a-zA-Z0-9.]+]] +; UNROLL: br i1 %[[v4]], label %[[cond:[a-zA-Z0-9.]+]], label %[[else:[a-zA-Z0-9.]+]] ; ; UNROLL: [[cond]]: ; UNROLL: %[[v6:[a-zA-Z0-9]+]] = add nsw i32 %[[v2]], 20 @@ -61,8 +54,7 @@ ; UNROLL: br label %[[else]] ; ; UNROLL: [[else]]: -; UNROLL: %[[v9:[a-zA-Z0-9]+]] = icmp eq i1 %[[o2]], true -; UNROLL: br i1 %[[v9]], label %[[cond2:[a-zA-Z0-9.]+]], label %[[else2:[a-zA-Z0-9.]+]] +; UNROLL: br i1 %[[v5]], label %[[cond2:[a-zA-Z0-9.]+]], label %[[else2:[a-zA-Z0-9.]+]] ; ; UNROLL: [[cond2]]: ; UNROLL: %[[v7:[a-zA-Z0-9]+]] = add nsw i32 %[[v3]], 20