diff --git a/llvm/include/llvm/IR/Instruction.h b/llvm/include/llvm/IR/Instruction.h
--- a/llvm/include/llvm/IR/Instruction.h
+++ b/llvm/include/llvm/IR/Instruction.h
@@ -640,6 +640,12 @@
   /// Return true if this instruction has a volatile memory access.
   bool isVolatile() const LLVM_READONLY;
 
+  /// Return the type of the value this instruction loads from or stores to
+  /// memory, or nullptr if it is not a recognized memory access. Plain loads,
+  /// stores, atomics and the masked/VP load and store intrinsics are handled;
+  /// other calls (e.g. memcpy, memset, target intrinsics) return nullptr.
+  Type *getAccessType() const LLVM_READONLY;
+
   /// Return true if this instruction may throw an exception.
   ///
   /// If IncludePhaseOneUnwind is set, this will also include cases where
diff --git a/llvm/lib/IR/Instruction.cpp b/llvm/lib/IR/Instruction.cpp
--- a/llvm/lib/IR/Instruction.cpp
+++ b/llvm/lib/IR/Instruction.cpp
@@ -743,6 +743,48 @@
   }
 }
 
+Type *Instruction::getAccessType() const {
+  switch (getOpcode()) {
+  case Instruction::Store:
+    // Stores access the type of the stored value, not of the pointer operand.
+    return cast<StoreInst>(this)->getValueOperand()->getType();
+  case Instruction::Load:
+  case Instruction::AtomicRMW:
+    return getType();
+  case Instruction::AtomicCmpXchg:
+    // cmpxchg yields a {T, i1} pair; the memory access itself is of type T.
+    return cast<AtomicCmpXchgInst>(this)->getNewValOperand()->getType();
+  case Instruction::Call:
+  case Instruction::Invoke:
+    if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(this)) {
+      switch (II->getIntrinsicID()) {
+      case Intrinsic::masked_load:
+      case Intrinsic::masked_gather:
+      case Intrinsic::masked_expandload:
+      case Intrinsic::vp_load:
+      case Intrinsic::vp_gather:
+      case Intrinsic::experimental_vp_strided_load:
+        return II->getType();
+      case Intrinsic::masked_store:
+      case Intrinsic::masked_scatter:
+      case Intrinsic::masked_compressstore:
+      case Intrinsic::vp_store:
+      case Intrinsic::vp_scatter:
+      case Intrinsic::experimental_vp_strided_store:
+        // Operand 0 is the value being stored for all of these intrinsics.
+        return II->getArgOperand(0)->getType();
+      default:
+        break;
+      }
+    }
+    break;
+  default:
+    break;
+  }
+
+  return nullptr;
+}
+
 static bool canUnwindPastLandingPad(const LandingPadInst *LP,
                                     bool IncludePhaseOneUnwind) {
   // Because phase one unwinding skips cleanup landingpads, we effectively
diff --git a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
--- a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
+++ b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
@@ -819,17 +819,7 @@
 
   Type *getValueType() const {
     // TODO: handle target-specific intrinsics.
-    if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) {
-      switch (II->getIntrinsicID()) {
-      case Intrinsic::masked_load:
-        return II->getType();
-      case Intrinsic::masked_store:
-        return II->getArgOperand(0)->getType();
-      default:
-        return nullptr;
-      }
-    }
-    return getLoadStoreType(Inst);
+    return Inst->getAccessType();
   }
 
   bool mayReadFromMemory() const {
diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
--- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -896,21 +896,29 @@
 /// Return the type of the memory being accessed.
 static MemAccessTy getAccessType(const TargetTransformInfo &TTI,
                                  Instruction *Inst, Value *OperandVal) {
-  MemAccessTy AccessTy(Inst->getType(), MemAccessTy::UnknownAddressSpace);
+  // Start from the type actually loaded or stored. Instructions that
+  // Instruction::getAccessType() does not model (e.g. memcpy, memset, prefetch
+  // and target memory intrinsics) fall back to their result type; some of
+  // these overwrite MemTy in the intrinsic switch below.
+  Type *MemTy = Inst->getAccessType();
+  if (!MemTy)
+    MemTy = Inst->getType();
+  MemAccessTy AccessTy(MemTy, MemAccessTy::UnknownAddressSpace);
   if (const StoreInst *SI = dyn_cast<StoreInst>(Inst)) {
-    AccessTy.MemTy = SI->getOperand(0)->getType();
     AccessTy.AddrSpace = SI->getPointerAddressSpace();
   } else if (const LoadInst *LI = dyn_cast<LoadInst>(Inst)) {
     AccessTy.AddrSpace = LI->getPointerAddressSpace();
   } else if (const AtomicRMWInst *RMW = dyn_cast<AtomicRMWInst>(Inst)) {
     AccessTy.AddrSpace = RMW->getPointerAddressSpace();
-  } else if (const AtomicCmpXchgInst *CmpX = dyn_cast<AtomicCmpXchgInst>(Inst)) {
+  } else if (const AtomicCmpXchgInst *CmpX =
+                 dyn_cast<AtomicCmpXchgInst>(Inst)) {
     AccessTy.AddrSpace = CmpX->getPointerAddressSpace();
   } else if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) {
     switch (II->getIntrinsicID()) {
     case Intrinsic::prefetch:
     case Intrinsic::memset:
-      AccessTy.AddrSpace = II->getArgOperand(0)->getType()->getPointerAddressSpace();
+      AccessTy.AddrSpace =
+          II->getArgOperand(0)->getType()->getPointerAddressSpace();
       AccessTy.MemTy = OperandVal->getType();
       break;
     case Intrinsic::memmove:
@@ -923,7 +931,6 @@
           II->getArgOperand(0)->getType()->getPointerAddressSpace();
       break;
     case Intrinsic::masked_store:
-      AccessTy.MemTy = II->getOperand(0)->getType();
       AccessTy.AddrSpace =
           II->getArgOperand(1)->getType()->getPointerAddressSpace();
       break;