Index: llvm/trunk/include/llvm/Transforms/Utils/LoopUtils.h
===================================================================
--- llvm/trunk/include/llvm/Transforms/Utils/LoopUtils.h
+++ llvm/trunk/include/llvm/Transforms/Utils/LoopUtils.h
@@ -531,8 +531,10 @@
 
 /// Get the intersection (logical and) of all of the potential IR flags
 /// of each scalar operation (VL) that will be converted into a vector (I).
+/// If OpValue is non-null, only scalars in VL whose opcode matches OpValue's
+/// are intersected (OpValue's own flags are always included).
 /// Flag set: NSW, NUW, exact, and all of fast-math.
-void propagateIRFlags(Value *I, ArrayRef<Value *> VL);
+void propagateIRFlags(Value *I, ArrayRef<Value *> VL, Value *OpValue = nullptr);
 
 } // end namespace llvm
 
Index: llvm/trunk/lib/Transforms/Utils/LoopUtils.cpp
===================================================================
--- llvm/trunk/lib/Transforms/Utils/LoopUtils.cpp
+++ llvm/trunk/lib/Transforms/Utils/LoopUtils.cpp
@@ -1376,16 +1376,21 @@
   }
 }
 
-void llvm::propagateIRFlags(Value *I, ArrayRef<Value *> VL) {
-  if (auto *VecOp = dyn_cast<Instruction>(I)) {
-    if (auto *I0 = dyn_cast<Instruction>(VL[0])) {
-      // VecOVp is initialized to the 0th scalar, so start counting from index
-      // '1'.
-      VecOp->copyIRFlags(I0);
-      for (int i = 1, e = VL.size(); i < e; ++i) {
-        if (auto *Scalar = dyn_cast<Instruction>(VL[i]))
-          VecOp->andIRFlags(Scalar);
-      }
-    }
+void llvm::propagateIRFlags(Value *I, ArrayRef<Value *> VL, Value *OpValue) {
+  auto *VecOp = dyn_cast<Instruction>(I);
+  if (!VecOp)
+    return;
+  auto *Intersection = (OpValue == nullptr) ? dyn_cast<Instruction>(VL[0])
+                                            : dyn_cast<Instruction>(OpValue);
+  if (!Intersection)
+    return;
+  const unsigned Opcode = Intersection->getOpcode();
+  VecOp->copyIRFlags(Intersection);
+  for (auto *V : VL) {
+    auto *Instr = dyn_cast<Instruction>(V);
+    if (!Instr)
+      continue;
+    if (OpValue == nullptr || Opcode == Instr->getOpcode())
+      VecOp->andIRFlags(V);
   }
 }