diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -774,21 +774,112 @@
   });
 }
 
+/// Helper class to indicate whether a vector index can be safely scalarized and
+/// if a freeze needs to be inserted.
+///
+/// A SafeWithFreeze result carries a pending freeze: either freeze() or
+/// discard() must be called before it is destroyed, which the destructor
+/// checks with an assertion.
+class ScalarizationResult {
+  enum class StatusTy { Unsafe, Safe, SafeWithFreeze };
+
+  StatusTy Status;
+  Value *ToFreeze;
+
+  ScalarizationResult(StatusTy Status, Value *ToFreeze = nullptr)
+      : Status(Status), ToFreeze(ToFreeze) {}
+
+public:
+  ScalarizationResult(const ScalarizationResult &Other) = default;
+  ~ScalarizationResult() {
+    assert(!ToFreeze && "freeze() not called with ToFreeze being set");
+  }
+
+  static ScalarizationResult unsafe() { return {StatusTy::Unsafe}; }
+  static ScalarizationResult safe() { return {StatusTy::Safe}; }
+  static ScalarizationResult safeWithFreeze(Value *ToFreeze) {
+    return {StatusTy::SafeWithFreeze, ToFreeze};
+  }
+
+  /// Returns true if the index can be scalarized without requiring a freeze.
+  bool isSafe() const { return Status == StatusTy::Safe; }
+  /// Returns true if the index cannot be scalarized.
+  bool isUnsafe() const { return Status == StatusTy::Unsafe; }
+  /// Returns true if the index can be scalarized, but requires inserting a
+  /// freeze.
+  bool isSafeWithFreeze() const { return Status == StatusTy::SafeWithFreeze; }
+
+  /// Reset the state to Unsafe and clear ToFreeze if set. Call this when a
+  /// SafeWithFreeze result is dropped without calling freeze().
+  void discard() {
+    ToFreeze = nullptr;
+    Status = StatusTy::Unsafe;
+  }
+
+  /// Freeze ToFreeze right before \p UserI and rewrite the uses of ToFreeze
+  /// in \p UserI to use the frozen value.
+  void freeze(IRBuilder<> &Builder, Instruction &UserI) {
+    assert(isSafeWithFreeze() &&
+           "should only be used when freezing is required");
+    assert(is_contained(ToFreeze->users(), &UserI) &&
+           "UserI must be a user of ToFreeze");
+    IRBuilder<>::InsertPointGuard Guard(Builder);
+    Builder.SetInsertPoint(&UserI);
+    Value *Frozen =
+        Builder.CreateFreeze(ToFreeze, ToFreeze->getName() + ".frozen");
+    for (Use &U : make_early_inc_range(UserI.operands()))
+      if (U.get() == ToFreeze)
+        U.set(Frozen);
+
+    ToFreeze = nullptr;
+  }
+};
+
 /// Check if it is legal to scalarize a memory access to \p VecTy at index \p
-/// Idx. \p Idx must access a valid vector element.
-static bool canScalarizeAccess(FixedVectorType *VecTy, Value *Idx,
-                               Instruction *CtxI, AssumptionCache &AC) {
-  if (auto *C = dyn_cast<ConstantInt>(Idx))
-    return C->getValue().ult(VecTy->getNumElements());
-
-  if (!isGuaranteedNotToBePoison(Idx, &AC))
-    return false;
+/// Idx. \p Idx must access a valid vector element. The result is
+/// SafeWithFreeze when \p Idx may be poison but is an `and`/`urem` with a
+/// constant that bounds it; the caller must then freeze() or discard() it.
+static ScalarizationResult canScalarizeAccess(FixedVectorType *VecTy,
+                                              Value *Idx, Instruction *CtxI,
+                                              AssumptionCache &AC) {
+  if (auto *C = dyn_cast<ConstantInt>(Idx)) {
+    if (C->getValue().ult(VecTy->getNumElements()))
+      return ScalarizationResult::safe();
+    return ScalarizationResult::unsafe();
+  }
 
-  APInt Zero(Idx->getType()->getScalarSizeInBits(), 0);
-  APInt MaxElts(Idx->getType()->getScalarSizeInBits(), VecTy->getNumElements());
+  unsigned IntWidth = Idx->getType()->getScalarSizeInBits();
+  APInt Zero(IntWidth, 0);
+  APInt MaxElts(IntWidth, VecTy->getNumElements());
   ConstantRange ValidIndices(Zero, MaxElts);
-  ConstantRange IdxRange = computeConstantRange(Idx, true, &AC, CtxI, 0);
-  return ValidIndices.contains(IdxRange);
+  ConstantRange IdxRange(IntWidth, true);
+
+  if (isGuaranteedNotToBePoison(Idx, &AC)) {
+    if (ValidIndices.contains(computeConstantRange(Idx, true, &AC, CtxI, 0)))
+      return ScalarizationResult::safe();
+    return ScalarizationResult::unsafe();
+  }
+
+  // If the index may be poison, check if we can insert a freeze before the
+  // range of the index is restricted. The user of the frozen value is Idx
+  // itself, so it has to be an instruction (not a constant expression).
+  if (!isa<Instruction>(Idx))
+    return ScalarizationResult::unsafe();
+
+  Value *IdxBase = nullptr;
+  ConstantInt *CI = nullptr;
+  if (match(Idx, m_And(m_Value(IdxBase), m_ConstantInt(CI)))) {
+    IdxRange = IdxRange.binaryAnd(CI->getValue());
+  } else if (match(Idx, m_URem(m_Value(IdxBase), m_ConstantInt(CI)))) {
+    IdxRange = IdxRange.urem(CI->getValue());
+  } else {
+    // No bounding pattern; IdxBase may be unset or stale from a partial match.
+    return ScalarizationResult::unsafe();
+  }
+
+  if (ValidIndices.contains(IdxRange))
+    return ScalarizationResult::safeWithFreeze(IdxBase);
+  return ScalarizationResult::unsafe();
 }
 
 /// The memory operation on a vector of \p ScalarType had alignment of
@@ -836,12 +927,21 @@
     // modified between, vector type matches store size, and index is inbounds.
     if (!Load->isSimple() || Load->getParent() != SI->getParent() ||
         !DL.typeSizeEqualsStoreSize(Load->getType()) ||
-        !canScalarizeAccess(VecTy, Idx, Load, AC) ||
-        SrcAddr != SI->getPointerOperand()->stripPointerCasts() ||
+        SrcAddr != SI->getPointerOperand()->stripPointerCasts())
+      return false;
+
+    auto ScalarizableIdx = canScalarizeAccess(VecTy, Idx, Load, AC);
+    if (ScalarizableIdx.isUnsafe() ||
         isMemModifiedBetween(Load->getIterator(), SI->getIterator(),
-                             MemoryLocation::get(SI), AA))
+                             MemoryLocation::get(SI), AA)) {
+      // Not transforming: drop any pending freeze so the ScalarizationResult
+      // destructor does not assert.
+      ScalarizableIdx.discard();
       return false;
+    }
 
+    if (ScalarizableIdx.isSafeWithFreeze())
+      ScalarizableIdx.freeze(Builder, *cast<Instruction>(Idx));
     Value *GEP = Builder.CreateInBoundsGEP(
         SI->getValueOperand()->getType(), SI->getPointerOperand(),
         {ConstantInt::get(Idx->getType(), 0), Idx});
@@ -912,8 +1012,12 @@
     else if (LastCheckedInst->comesBefore(UI))
       LastCheckedInst = UI;
 
-    if (!canScalarizeAccess(FixedVT, UI->getOperand(1), &I, AC))
+    auto ScalarIdx = canScalarizeAccess(FixedVT, UI->getOperand(1), &I, AC);
+    if (!ScalarIdx.isSafe()) {
+      // TODO: Freeze index if it is safe to do so.
+      ScalarIdx.discard();
       return false;
+    }
 
     auto *Index = dyn_cast<ConstantInt>(UI->getOperand(1));
     OriginalCost +=
diff --git a/llvm/test/Transforms/VectorCombine/load-insert-store.ll b/llvm/test/Transforms/VectorCombine/load-insert-store.ll
--- a/llvm/test/Transforms/VectorCombine/load-insert-store.ll
+++ b/llvm/test/Transforms/VectorCombine/load-insert-store.ll
@@ -310,10 +310,10 @@
 define void @insert_store_nonconst_index_known_valid_by_and_but_may_be_poison(<16 x i8>* %q, i8 zeroext %s, i32 %idx) {
 ; CHECK-LABEL: @insert_store_nonconst_index_known_valid_by_and_but_may_be_poison(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = load <16 x i8>, <16 x i8>* [[Q:%.*]], align 16
-; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = and i32 [[IDX:%.*]], 7
-; CHECK-NEXT:    [[VECINS:%.*]] = insertelement <16 x i8> [[TMP0]], i8 [[S:%.*]], i32 [[IDX_CLAMPED]]
-; CHECK-NEXT:    store <16 x i8> [[VECINS]], <16 x i8>* [[Q]], align 16
+; CHECK-NEXT:    [[TMP0:%.*]] = freeze i32 [[IDX:%.*]]
+; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = and i32 [[TMP0]], 7
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds <16 x i8>, <16 x i8>* [[Q:%.*]], i32 0, i32 [[IDX_CLAMPED]]
+; CHECK-NEXT:    store i8 [[S:%.*]], i8* [[TMP1]], align 1
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -413,10 +413,10 @@
 define void @insert_store_nonconst_index_known_valid_by_urem_but_may_be_poison(<16 x i8>* %q, i8 zeroext %s, i32 %idx) {
 ; CHECK-LABEL: @insert_store_nonconst_index_known_valid_by_urem_but_may_be_poison(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = load <16 x i8>, <16 x i8>* [[Q:%.*]], align 16
-; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = urem i32 [[IDX:%.*]], 16
-; CHECK-NEXT:    [[VECINS:%.*]] = insertelement <16 x i8> [[TMP0]], i8 [[S:%.*]], i32 [[IDX_CLAMPED]]
-; CHECK-NEXT:    store <16 x i8> [[VECINS]], <16 x i8>* [[Q]], align 16
+; CHECK-NEXT:    [[TMP0:%.*]] = freeze i32 [[IDX:%.*]]
+; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = urem i32 [[TMP0]], 16
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds <16 x i8>, <16 x i8>* [[Q:%.*]], i32 0, i32 [[IDX_CLAMPED]]
+; CHECK-NEXT:    store i8 [[S:%.*]], i8* [[TMP1]], align 1
 ; CHECK-NEXT:    ret void
 ;
 entry: