diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp --- a/clang/lib/CodeGen/CGBuiltin.cpp +++ b/clang/lib/CodeGen/CGBuiltin.cpp @@ -11746,14 +11746,14 @@ static Value *EmitX86MaskedLoad(CodeGenFunction &CGF, ArrayRef Ops, Align Alignment) { // Cast the pointer to right type. - Value *Ptr = CGF.Builder.CreateBitCast(Ops[0], - llvm::PointerType::getUnqual(Ops[1]->getType())); + llvm::Type *Ty = Ops[1]->getType(); + Value *Ptr = + CGF.Builder.CreateBitCast(Ops[0], llvm::PointerType::getUnqual(Ty)); Value *MaskVec = getMaskVecValue( - CGF, Ops[2], - cast(Ops[1]->getType())->getNumElements()); + CGF, Ops[2], cast(Ty)->getNumElements()); - return CGF.Builder.CreateMaskedLoad(Ptr, Alignment, MaskVec, Ops[1]); + return CGF.Builder.CreateMaskedLoad(Ty, Ptr, Alignment, MaskVec, Ops[1]); } static Value *EmitX86ExpandLoad(CodeGenFunction &CGF, diff --git a/llvm/include/llvm/IR/IRBuilder.h b/llvm/include/llvm/IR/IRBuilder.h --- a/llvm/include/llvm/IR/IRBuilder.h +++ b/llvm/include/llvm/IR/IRBuilder.h @@ -752,7 +752,7 @@ CallInst *CreateInvariantStart(Value *Ptr, ConstantInt *Size = nullptr); /// Create a call to Masked Load intrinsic - CallInst *CreateMaskedLoad(Value *Ptr, Align Alignment, Value *Mask, + CallInst *CreateMaskedLoad(Type *Ty, Value *Ptr, Align Alignment, Value *Mask, Value *PassThru = nullptr, const Twine &Name = ""); /// Create a call to Masked Store intrinsic @@ -760,7 +760,7 @@ Value *Mask); /// Create a call to Masked Gather intrinsic - CallInst *CreateMaskedGather(Value *Ptrs, Align Alignment, + CallInst *CreateMaskedGather(Type *Ty, Value *Ptrs, Align Alignment, Value *Mask = nullptr, Value *PassThru = nullptr, const Twine &Name = ""); diff --git a/llvm/lib/IR/AutoUpgrade.cpp b/llvm/lib/IR/AutoUpgrade.cpp --- a/llvm/lib/IR/AutoUpgrade.cpp +++ b/llvm/lib/IR/AutoUpgrade.cpp @@ -1421,10 +1421,9 @@ return Builder.CreateAlignedLoad(ValTy, Ptr, Alignment); // Convert the mask from an integer type to a vector of i1. - unsigned NumElts = - cast(Passthru->getType())->getNumElements(); + unsigned NumElts = cast(ValTy)->getNumElements(); Mask = getX86MaskVec(Builder, Mask, NumElts); - return Builder.CreateMaskedLoad(Ptr, Alignment, Mask, Passthru); + return Builder.CreateMaskedLoad(ValTy, Ptr, Alignment, Mask, Passthru); } static Value *upgradeAbs(IRBuilder<> &Builder, CallInst &CI) { diff --git a/llvm/lib/IR/IRBuilder.cpp b/llvm/lib/IR/IRBuilder.cpp --- a/llvm/lib/IR/IRBuilder.cpp +++ b/llvm/lib/IR/IRBuilder.cpp @@ -493,6 +493,7 @@ } /// Create a call to a Masked Load intrinsic. +/// \p Ty - vector type to load /// \p Ptr - base pointer for the load /// \p Alignment - alignment of the source location /// \p Mask - vector of booleans which indicates what vector lanes should @@ -500,16 +501,16 @@ /// \p PassThru - pass-through value that is used to fill the masked-off lanes /// of the result /// \p Name - name of the result variable -CallInst *IRBuilderBase::CreateMaskedLoad(Value *Ptr, Align Alignment, +CallInst *IRBuilderBase::CreateMaskedLoad(Type *Ty, Value *Ptr, Align Alignment, Value *Mask, Value *PassThru, const Twine &Name) { auto *PtrTy = cast(Ptr->getType()); - Type *DataTy = PtrTy->getElementType(); - assert(DataTy->isVectorTy() && "Ptr should point to a vector"); + assert(Ty->isVectorTy() && "Type should be vector"); + assert(PtrTy->isOpaqueOrPointeeTypeMatches(Ty) && "Wrong element type"); assert(Mask && "Mask should not be all-ones (null)"); if (!PassThru) - PassThru = UndefValue::get(DataTy); - Type *OverloadedTypes[] = { DataTy, PtrTy }; + PassThru = UndefValue::get(Ty); + Type *OverloadedTypes[] = { Ty, PtrTy }; Value *Ops[] = {Ptr, getInt32(Alignment.value()), Mask, PassThru}; return CreateMaskedIntrinsic(Intrinsic::masked_load, Ops, OverloadedTypes, Name); @@ -546,6 +547,7 @@ } /// Create a call to a Masked Gather intrinsic. +/// \p Ty - vector type to gather /// \p Ptrs - vector of pointers for loading /// \p Align - alignment for one element /// \p Mask - vector of booleans which indicates what vector lanes should @@ -553,22 +555,27 @@ /// \p PassThru - pass-through value that is used to fill the masked-off lanes /// of the result /// \p Name - name of the result variable -CallInst *IRBuilderBase::CreateMaskedGather(Value *Ptrs, Align Alignment, - Value *Mask, Value *PassThru, +CallInst *IRBuilderBase::CreateMaskedGather(Type *Ty, Value *Ptrs, + Align Alignment, Value *Mask, + Value *PassThru, const Twine &Name) { + auto *VecTy = cast(Ty); + ElementCount NumElts = VecTy->getElementCount(); auto *PtrsTy = cast(Ptrs->getType()); - auto *PtrTy = cast(PtrsTy->getElementType()); - ElementCount NumElts = PtrsTy->getElementCount(); - auto *DataTy = VectorType::get(PtrTy->getElementType(), NumElts); + assert(cast(PtrsTy->getElementType()) + ->isOpaqueOrPointeeTypeMatches( + cast(Ty)->getElementType()) && + "Element type mismatch"); + assert(NumElts == PtrsTy->getElementCount() && "Element count mismatch"); if (!Mask) Mask = Constant::getAllOnesValue( VectorType::get(Type::getInt1Ty(Context), NumElts)); if (!PassThru) - PassThru = UndefValue::get(DataTy); + PassThru = UndefValue::get(Ty); - Type *OverloadedTypes[] = {DataTy, PtrsTy}; + Type *OverloadedTypes[] = {Ty, PtrsTy}; Value *Ops[] = {Ptrs, getInt32(Alignment.value()), Mask, PassThru}; // We specify only one type when we create this intrinsic. Types of other diff --git a/llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp b/llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp --- a/llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp +++ b/llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp @@ -475,7 +475,7 @@ return PassThru; if (Mask == ConstantInt::getTrue(Mask->getType())) return Builder.CreateAlignedLoad(ValTy, Ptr, Align(Alignment)); - return Builder.CreateMaskedLoad(Ptr, Align(Alignment), Mask, PassThru); + return Builder.CreateMaskedLoad(ValTy, Ptr, Align(Alignment), Mask, PassThru); } auto AlignVectors::createAlignedStore(IRBuilder<> &Builder, Value *Val, diff --git a/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp b/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp --- a/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp +++ b/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp @@ -72,8 +72,8 @@ Value *PtrCast = IC.Builder.CreateBitCast(Ptr, VecPtrTy, "castvec"); // The pass-through vector for an x86 masked load is a zero vector. - CallInst *NewMaskedLoad = - IC.Builder.CreateMaskedLoad(PtrCast, Align(1), BoolMask, ZeroVec); + CallInst *NewMaskedLoad = IC.Builder.CreateMaskedLoad( + II.getType(), PtrCast, Align(1), BoolMask, ZeroVec); return IC.replaceInstUsesWith(II, NewMaskedLoad); } diff --git a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp --- a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp +++ b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp @@ -3117,7 +3117,7 @@ if (PropagateShadow) { std::tie(ShadowPtr, OriginPtr) = getShadowOriginPtr(Addr, IRB, ShadowTy, Alignment, /*isStore*/ false); - setShadow(&I, IRB.CreateMaskedLoad(ShadowPtr, Alignment, Mask, + setShadow(&I, IRB.CreateMaskedLoad(ShadowTy, ShadowPtr, Alignment, Mask, getShadow(PassThru), "_msmaskedld")); } else { setShadow(&I, getCleanShadow(&I)); diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -2778,7 +2778,7 @@ : ShuffledMask; } NewLoad = - Builder.CreateMaskedLoad(AddrParts[Part], Group->getAlign(), + Builder.CreateMaskedLoad(VecTy, AddrParts[Part], Group->getAlign(), GroupMask, PoisonVec, "wide.masked.vec"); } else @@ -2990,15 +2990,15 @@ if (CreateGatherScatter) { Value *MaskPart = isMaskRequired ? BlockInMaskParts[Part] : nullptr; Value *VectorGep = State.get(Addr, Part); - NewLI = Builder.CreateMaskedGather(VectorGep, Alignment, MaskPart, + NewLI = Builder.CreateMaskedGather(DataTy, VectorGep, Alignment, MaskPart, nullptr, "wide.masked.gather"); addMetadata(NewLI, LI); } else { auto *VecPtr = CreateVecPtr(Part, State.get(Addr, VPIteration(0, 0))); if (isMaskRequired) NewLI = Builder.CreateMaskedLoad( - VecPtr, Alignment, BlockInMaskParts[Part], PoisonValue::get(DataTy), - "wide.masked.load"); + DataTy, VecPtr, Alignment, BlockInMaskParts[Part], + PoisonValue::get(DataTy), "wide.masked.load"); else NewLI = Builder.CreateAlignedLoad(DataTy, VecPtr, Alignment, "wide.load"); diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -5403,7 +5403,7 @@ for (Value *V : E->Scalars) CommonAlignment = commonAlignment(CommonAlignment, cast(V)->getAlign()); - NewLI = Builder.CreateMaskedGather(VecPtr, CommonAlignment); + NewLI = Builder.CreateMaskedGather(VecTy, VecPtr, CommonAlignment); } Value *V = propagateMetadata(NewLI, E->Scalars); diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -1516,10 +1516,11 @@ let results = (outs LLVM_Type:$res); let builders = [LLVM_OneResultOpBuilder]; string llvmBuilder = [{ + llvm::Type *Ty = $data->getType()->getPointerElementType(); $res = $pass_thru.empty() ? builder.CreateMaskedLoad( - $data, llvm::Align($alignment), $mask) : + Ty, $data, llvm::Align($alignment), $mask) : builder.CreateMaskedLoad( - $data, llvm::Align($alignment), $mask, $pass_thru[0]); + Ty, $data, llvm::Align($alignment), $mask, $pass_thru[0]); }]; let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; @@ -1545,10 +1546,14 @@ let results = (outs LLVM_Type:$res); let builders = [LLVM_OneResultOpBuilder]; string llvmBuilder = [{ + llvm::VectorType *PtrVecTy = cast($ptrs->getType()); + llvm::Type *Ty = llvm::VectorType::get( + PtrVecTy->getElementType()->getPointerElementType(), + PtrVecTy->getElementCount()); $res = $pass_thru.empty() ? builder.CreateMaskedGather( - $ptrs, llvm::Align($alignment), $mask) : + Ty, $ptrs, llvm::Align($alignment), $mask) : builder.CreateMaskedGather( - $ptrs, llvm::Align($alignment), $mask, $pass_thru[0]); + Ty, $ptrs, llvm::Align($alignment), $mask, $pass_thru[0]); }]; let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";