diff --git a/llvm/include/llvm/IR/Instruction.h b/llvm/include/llvm/IR/Instruction.h
--- a/llvm/include/llvm/IR/Instruction.h
+++ b/llvm/include/llvm/IR/Instruction.h
@@ -647,6 +647,9 @@
   /// Return true if this instruction has a volatile memory access.
   bool isVolatile() const LLVM_READONLY;
 
+  /// Return the type this instruction accesses in memory, if any.
+  Type *getAccessType() const LLVM_READONLY;
+
   /// Return true if this instruction may throw an exception.
   ///
   /// If IncludePhaseOneUnwind is set, this will also include cases where
diff --git a/llvm/lib/IR/Instruction.cpp b/llvm/lib/IR/Instruction.cpp
--- a/llvm/lib/IR/Instruction.cpp
+++ b/llvm/lib/IR/Instruction.cpp
@@ -743,6 +743,42 @@
   }
 }
 
+Type *Instruction::getAccessType() const {
+  switch (getOpcode()) {
+  case Instruction::Store:
+    return cast<StoreInst>(this)->getValueOperand()->getType();
+  case Instruction::Load:
+  case Instruction::AtomicRMW:
+    return getType();
+  case Instruction::AtomicCmpXchg:
+    return cast<AtomicCmpXchgInst>(this)->getNewValOperand()->getType();
+  case Instruction::Call:
+  case Instruction::Invoke:
+    if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(this)) {
+      switch (II->getIntrinsicID()) {
+      case Intrinsic::masked_load:
+      case Intrinsic::masked_gather:
+      case Intrinsic::masked_expandload:
+      case Intrinsic::vp_load:
+      case Intrinsic::vp_gather:
+      case Intrinsic::experimental_vp_strided_load:
+        return II->getType();
+      case Intrinsic::masked_store:
+      case Intrinsic::masked_scatter:
+      case Intrinsic::masked_compressstore:
+      case Intrinsic::vp_store:
+      case Intrinsic::vp_scatter:
+      case Intrinsic::experimental_vp_strided_store:
+        return II->getOperand(0)->getType();
+      default:
+        break;
+      }
+    }
+  }
+
+  return nullptr;
+}
+
 static bool canUnwindPastLandingPad(const LandingPadInst *LP,
                                     bool IncludePhaseOneUnwind) {
   // Because phase one unwinding skips cleanup landingpads, we effectively
diff --git a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
--- a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
+++ b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
@@ -819,17 +819,7 @@
 
     Type *getValueType() const {
      // TODO: handle target-specific intrinsics.
-      if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) {
-        switch (II->getIntrinsicID()) {
-        case Intrinsic::masked_load:
-          return II->getType();
-        case Intrinsic::masked_store:
-          return II->getArgOperand(0)->getType();
-        default:
-          return nullptr;
-        }
-      }
-      return getLoadStoreType(Inst);
+      return Inst->getAccessType();
    }
 
     bool mayReadFromMemory() const {
diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
--- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -896,9 +896,14 @@
 /// Return the type of the memory being accessed.
 static MemAccessTy getAccessType(const TargetTransformInfo &TTI,
                                  Instruction *Inst, Value *OperandVal) {
-  MemAccessTy AccessTy(Inst->getType(), MemAccessTy::UnknownAddressSpace);
+  MemAccessTy AccessTy = MemAccessTy::getUnknown(Inst->getContext());
+
+  // First get the type of memory being accessed.
+  if (Type *Ty = Inst->getAccessType())
+    AccessTy.MemTy = Ty;
+
+  // Then get the pointer address space.
   if (const StoreInst *SI = dyn_cast<StoreInst>(Inst)) {
-    AccessTy.MemTy = SI->getOperand(0)->getType();
     AccessTy.AddrSpace = SI->getPointerAddressSpace();
   } else if (const LoadInst *LI = dyn_cast<LoadInst>(Inst)) {
     AccessTy.AddrSpace = LI->getPointerAddressSpace();
@@ -923,7 +928,6 @@
           II->getArgOperand(0)->getType()->getPointerAddressSpace();
       break;
     case Intrinsic::masked_store:
-      AccessTy.MemTy = II->getOperand(0)->getType();
       AccessTy.AddrSpace =
           II->getArgOperand(1)->getType()->getPointerAddressSpace();
       break;
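
Usage sketch (illustrative, not part of the patch): with this change, callers
can query the accessed type uniformly instead of special-casing loads, stores,
atomics, and the masked/VP intrinsics, as EarlyCSE and LSR now do above. A
minimal example of such a caller; the helper name getAccessSizeInBits is
hypothetical:

  // Hypothetical helper: return the size in bits of the memory access
  // performed by I, or std::nullopt if I is not a recognized memory access
  // (Instruction::getAccessType() returns nullptr in that case).
  #include "llvm/IR/DataLayout.h"
  #include "llvm/IR/Instruction.h"
  #include <optional>

  static std::optional<llvm::TypeSize>
  getAccessSizeInBits(const llvm::Instruction &I, const llvm::DataLayout &DL) {
    if (llvm::Type *AccessTy = I.getAccessType())
      return DL.getTypeSizeInBits(AccessTy); // TypeSize covers scalable vectors.
    return std::nullopt;
  }

Note that getAccessType() returns nullptr for accesses that have no
first-class IR type (e.g. memcpy/memset operate on raw bytes and fall through
to the default case), so callers must handle the null result.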