diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -2421,14 +2421,6 @@
 ///  `ptrtoint(gep <vscale x 1 x i8>, <vscale x 1 x i8>* null, i32 1>`
 /// under the right conditions determined by DataLayout.
 struct VScaleVal_match {
-private:
-  template <typename Base, typename Offset>
-  inline BinaryOp_match<Base, Offset, Instruction::GetElementPtr>
-  m_OffsetGep(const Base &B, const Offset &O) {
-    return BinaryOp_match<Base, Offset, Instruction::GetElementPtr>(B, O);
-  }
-
-public:
   const DataLayout &DL;
   VScaleVal_match(const DataLayout &DL) : DL(DL) {}
 
@@ -2436,12 +2428,18 @@
     if (m_Intrinsic<Intrinsic::vscale>().match(V))
       return true;
 
-    if (m_PtrToInt(m_OffsetGep(m_Zero(), m_SpecificInt(1))).match(V)) {
-      auto *GEP = cast<GEPOperator>(cast<Operator>(V)->getOperand(0));
-      auto *DerefTy = GEP->getSourceElementType();
-      if (isa<ScalableVectorType>(DerefTy) &&
-          DL.getTypeAllocSizeInBits(DerefTy).getKnownMinSize() == 8)
-        return true;
+    // Inspect the GEP directly rather than through BinaryOp_match: a GEP is
+    // not a binary operator and may carry any number of indices.
+    Value *Ptr;
+    if (m_PtrToInt(m_Value(Ptr)).match(V)) {
+      if (auto *GEP = dyn_cast<GEPOperator>(Ptr)) {
+        auto *DerefTy = GEP->getSourceElementType();
+        if (GEP->getNumIndices() == 1 && isa<ScalableVectorType>(DerefTy) &&
+            m_Zero().match(GEP->getPointerOperand()) &&
+            m_SpecificInt(1).match(GEP->idx_begin()->get()) &&
+            DL.getTypeAllocSizeInBits(DerefTy).getKnownMinSize() == 8)
+          return true;
+      }
     }
 
     return false;
diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp
--- a/llvm/unittests/IR/PatternMatch.cpp
+++ b/llvm/unittests/IR/PatternMatch.cpp
@@ -1636,6 +1636,26 @@
   EXPECT_FALSE(match(IRB.getInt64(99), m_InsertValue<0>(m_Value(), m_Value())));
 }
 
+TEST_F(PatternMatchTest, VScale) {
+  DataLayout DL = M->getDataLayout();
+
+  Type *VecTy = ScalableVectorType::get(IRB.getInt8Ty(), 1);
+  Type *VecPtrTy = VecTy->getPointerTo();
+  Value *NullPtrVec = Constant::getNullValue(VecPtrTy);
+  Value *GEP = IRB.CreateGEP(VecTy, NullPtrVec, IRB.getInt64(1));
+  Value *PtrToInt = IRB.CreatePtrToInt(GEP, DL.getIntPtrType(GEP->getType()));
+  EXPECT_TRUE(match(PtrToInt, m_VScale(DL)));
+
+  // This case used to cause an assertion failure when matching m_VScale.
+  Type *VecTy2 = ScalableVectorType::get(IRB.getInt8Ty(), 2);
+  Value *NullPtrVec2 = Constant::getNullValue(VecTy2->getPointerTo());
+  Value *BitCast = IRB.CreateBitCast(NullPtrVec2, VecPtrTy);
+  Value *GEP2 = IRB.CreateGEP(VecTy, BitCast, IRB.getInt64(1));
+  Value *PtrToInt2 =
+      IRB.CreatePtrToInt(GEP2, DL.getIntPtrType(GEP2->getType()));
+  EXPECT_FALSE(match(PtrToInt2, m_VScale(DL)));
+}
+
 template <typename T> struct MutableConstTest : PatternMatchTest { };
 
 typedef ::testing::Types<std::tuple<Value*, Instruction*>,