diff --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp @@ -590,12 +590,12 @@ // Note that we should not do this for pointer<->integer casts, // because that would result in type punning. if (LI.hasOneUse()) { - // Don't transform when the type is x86_amx, it makes the pass that lower - // x86_amx type happy. if (auto *BC = dyn_cast(LI.user_back())) { - assert(!LI.getType()->isX86_AMXTy() && - "load from x86_amx* should not happen!"); - if (BC->getType()->isX86_AMXTy()) + // Do not bitcast the pointer if the cast is not lossless. + auto *DestPtrTy = + BC->getType()->getPointerTo(LI.getPointerAddressSpace()); + auto *SrcPtrTy = LI.getPointerOperandType(); + if (!SrcPtrTy->canLosslesslyBitCastTo(DestPtrTy)) return nullptr; } @@ -1121,12 +1121,10 @@ // Fold away bit casts of the stored value by storing the original type. if (auto *BC = dyn_cast(V)) { - assert(!BC->getType()->isX86_AMXTy() && - "store to x86_amx* should not happen!"); V = BC->getOperand(0); - // Don't transform when the type is x86_amx, it makes the pass that lower - // x86_amx type happy. - if (V->getType()->isX86_AMXTy()) + auto *DestPtrTy = V->getType()->getPointerTo(SI.getPointerAddressSpace()); + auto *SrcPtrTy = SI.getPointerOperandType(); + if (!SrcPtrTy->canLosslesslyBitCastTo(DestPtrTy)) return false; if (!SI.isAtomic() || isSupportedAtomicType(V->getType())) { combineStoreToNewValue(IC, SI, V);