diff --git a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
--- a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
+++ b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
@@ -698,8 +698,25 @@
     ParseMemoryInst(Instruction *Inst, const TargetTransformInfo &TTI)
       : Inst(Inst) {
       if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) {
+        IntrID = II->getIntrinsicID();
         if (TTI.getTgtMemIntrinsic(II, Info))
-          IntrID = II->getIntrinsicID();
+          return;
+        if (isHandledNonTargetIntrinsic(IntrID)) {
+          switch (IntrID) {
+          case Intrinsic::masked_load:
+            Info.PtrVal = Inst->getOperand(0);
+            Info.ReadMem = true;
+            Info.WriteMem = false;
+            Info.IsVolatile = false;
+            break;
+          case Intrinsic::masked_store:
+            Info.PtrVal = Inst->getOperand(1);
+            Info.ReadMem = false;
+            Info.WriteMem = true;
+            Info.IsVolatile = false;
+            break;
+          }
+        }
       }
     }
 
@@ -797,6 +814,22 @@
     Instruction *Inst;
   };
 
+  // This function exists to prevent accidentally passing a non-target
+  // intrinsic ID to TargetTransformInfo.
+  static bool isHandledNonTargetIntrinsic(Intrinsic::ID ID) {
+    switch (ID) {
+    case Intrinsic::masked_load:
+    case Intrinsic::masked_store:
+      return true;
+    }
+    return false;
+  }
+  static bool isHandledNonTargetIntrinsic(Value *V) {
+    if (auto *II = dyn_cast<IntrinsicInst>(V))
+      return isHandledNonTargetIntrinsic(II->getIntrinsicID());
+    return false;
+  }
+
   bool processNode(DomTreeNode *Node);
 
   bool handleBranchCondition(Instruction *CondInst, const BranchInst *BI,
@@ -811,8 +844,21 @@
     if (auto *SI = dyn_cast<StoreInst>(Inst))
       return SI->getValueOperand();
     assert(isa<IntrinsicInst>(Inst) && "Instruction not supported");
-    return TTI.getOrCreateResultFromMemIntrinsic(cast<IntrinsicInst>(Inst),
-                                                 ExpectedType);
+    auto *II = cast<IntrinsicInst>(Inst);
+    if (isHandledNonTargetIntrinsic(II->getIntrinsicID()))
+      return getOrCreateResultNonTargetMemIntrinsic(II, ExpectedType);
+    return TTI.getOrCreateResultFromMemIntrinsic(II, ExpectedType);
+  }
+
+  Value *getOrCreateResultNonTargetMemIntrinsic(IntrinsicInst *II,
+                                                Type *ExpectedType) const {
+    switch (II->getIntrinsicID()) {
+    case Intrinsic::masked_load:
+      return II;
+    case Intrinsic::masked_store:
+      return II->getOperand(0);
+    }
+    return nullptr;
   }
 
   /// Return true if the instruction is known to only operate on memory
@@ -822,6 +868,87 @@
   bool isSameMemGeneration(unsigned EarlierGeneration, unsigned LaterGeneration,
                            Instruction *EarlierInst, Instruction *LaterInst);
 
+  bool isNonTargetIntrinsicMatch(IntrinsicInst *Earlier, IntrinsicInst *Later) {
+    auto IsSubmask = [](Value *Mask0, Value *Mask1) {
+      // Is Mask0 a submask of Mask1?
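+      // (Mask0 is a submask of Mask1 if every lane enabled in Mask0 is also
+      // enabled in Mask1; e.g. <1,0,1,0> is a submask of <1,1,1,0>.)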
+      if (Mask0 == Mask1)
+        return true;
+      assert(!isa<UndefValue>(Mask0) && !isa<UndefValue>(Mask1));
+      auto *Vec0 = dyn_cast<ConstantVector>(Mask0);
+      auto *Vec1 = dyn_cast<ConstantVector>(Mask1);
+      if (!Vec0 || !Vec1 || Vec0->getType() != Vec1->getType())
+        return false;
+      for (int i = 0, e = Vec0->getNumOperands(); i != e; ++i) {
+        bool M0 = cast<ConstantInt>(Vec0->getOperand(i))->getZExtValue();
+        bool M1 = cast<ConstantInt>(Vec1->getOperand(i))->getZExtValue();
+        if (M0 && !M1)
+          return false;
+      }
+      return true;
+    };
+    auto PtrOp = [](IntrinsicInst *II) {
+      if (II->getIntrinsicID() == Intrinsic::masked_load)
+        return II->getOperand(0);
+      if (II->getIntrinsicID() == Intrinsic::masked_store)
+        return II->getOperand(1);
+      llvm_unreachable("Unexpected IntrinsicInst");
+    };
+    auto MaskOp = [](IntrinsicInst *II) {
+      if (II->getIntrinsicID() == Intrinsic::masked_load)
+        return II->getOperand(2);
+      if (II->getIntrinsicID() == Intrinsic::masked_store)
+        return II->getOperand(3);
+      llvm_unreachable("Unexpected IntrinsicInst");
+    };
+    auto ThruOp = [](IntrinsicInst *II) {
+      if (II->getIntrinsicID() == Intrinsic::masked_load)
+        return II->getOperand(3);
+      llvm_unreachable("Unexpected IntrinsicInst");
+    };
+
+    if (PtrOp(Earlier) != PtrOp(Later))
+      return false;
+
+    Intrinsic::ID IDE = Earlier->getIntrinsicID();
+    Intrinsic::ID IDL = Later->getIntrinsicID();
+    // We could really use specific intrinsic classes for masked loads
+    // and stores in IntrinsicInst.h.
+    if (IDE == Intrinsic::masked_load && IDL == Intrinsic::masked_load) {
+      // Trying to replace the later masked load with the earlier one.
+      // Check that the pointers are the same, and
+      // - masks and pass-throughs are the same, or
+      // - replacee's pass-through is "undef" and replacer's mask is a
+      //   superset of the replacee's mask.
+      if (MaskOp(Earlier) == MaskOp(Later) && ThruOp(Earlier) == ThruOp(Later))
+        return true;
+      if (!isa<UndefValue>(ThruOp(Later)))
+        return false;
+      return IsSubmask(MaskOp(Later), MaskOp(Earlier));
+    }
+    if (IDE == Intrinsic::masked_store && IDL == Intrinsic::masked_load) {
+      // Trying to replace a load of a stored value with the store's value.
+      // Check that the pointers are the same, and
+      // - load's mask is a subset of store's mask, and
+      // - load's pass-through is "undef".
+      if (!IsSubmask(MaskOp(Later), MaskOp(Earlier)))
+        return false;
+      return isa<UndefValue>(ThruOp(Later));
+    }
+    if (IDE == Intrinsic::masked_load && IDL == Intrinsic::masked_store) {
+      // Trying to remove a store of the loaded value.
+      // Check that the pointers are the same, and
+      // - store's mask is a subset of the load's mask.
+      return IsSubmask(MaskOp(Later), MaskOp(Earlier));
+    }
+    if (IDE == Intrinsic::masked_store && IDL == Intrinsic::masked_store) {
+      // Trying to remove a dead store (Earlier).
+      // Check that the pointers are the same, and
+      // - the earlier store's mask is a subset of the later store's mask.
+      return IsSubmask(MaskOp(Earlier), MaskOp(Later));
+    }
+    return false;
+  }
+
   void removeMSSA(Instruction &Inst) {
     if (!MSSA)
       return;
@@ -984,6 +1111,17 @@
     Instruction *Earlier = InValFirst ? InVal.DefInst : MemInst.get();
     Instruction *Later = InValFirst ? MemInst.get() : InVal.DefInst;
 
+    // Deal with non-target memory intrinsics.
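+    // Currently these are masked loads and masked stores; they are only
+    // matched against each other, never against target-specific intrinsics.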
+    bool EarlierNTI = isHandledNonTargetIntrinsic(Earlier);
+    bool LaterNTI = isHandledNonTargetIntrinsic(Later);
+    if (EarlierNTI != LaterNTI)
+      return nullptr;
+    if (EarlierNTI && LaterNTI) {
+      if (!isNonTargetIntrinsicMatch(cast<IntrinsicInst>(Earlier),
+                                     cast<IntrinsicInst>(Later)))
+        return nullptr;
+    }
+
     if (!isOperatingOnInvariantMemAt(MemInst.get(), InVal.Generation) &&
         !isSameMemGeneration(InVal.Generation, CurrentGeneration, InVal.DefInst,
                              MemInst.get()))
diff --git a/llvm/test/Transforms/EarlyCSE/masked-intrinsics.ll b/llvm/test/Transforms/EarlyCSE/masked-intrinsics.ll
--- a/llvm/test/Transforms/EarlyCSE/masked-intrinsics.ll
+++ b/llvm/test/Transforms/EarlyCSE/masked-intrinsics.ll
@@ -5,8 +5,7 @@
 ; CHECK-LABEL: @f0(
 ; CHECK-NEXT:    [[V0:%.*]] = icmp eq <128 x i8> [[A1:%.*]], [[A2:%.*]]
 ; CHECK-NEXT:    call void @llvm.masked.store.v128i8.p0v128i8(<128 x i8> [[A1]], <128 x i8>* [[A0:%.*]], i32 4, <128 x i1> [[V0]])
-; CHECK-NEXT:    [[V1:%.*]] = call <128 x i8> @llvm.masked.load.v128i8.p0v128i8(<128 x i8>* [[A0]], i32 4, <128 x i1> [[V0]], <128 x i8> undef)
-; CHECK-NEXT:    ret <128 x i8> [[V1]]
+; CHECK-NEXT:    ret <128 x i8> [[A1]]
 ;
   %v0 = icmp eq <128 x i8> %a1, %a2
   call void @llvm.masked.store.v128i8.p0v128i8(<128 x i8> %a1, <128 x i8>* %a0, i32 4, <128 x i1> %v0)
@@ -18,7 +17,6 @@
 ; CHECK-LABEL: @f1(
 ; CHECK-NEXT:    [[V0:%.*]] = icmp eq <128 x i8> [[A1:%.*]], [[A2:%.*]]
 ; CHECK-NEXT:    [[V1:%.*]] = call <128 x i8> @llvm.masked.load.v128i8.p0v128i8(<128 x i8>* [[A0:%.*]], i32 4, <128 x i1> [[V0]], <128 x i8> undef)
-; CHECK-NEXT:    call void @llvm.masked.store.v128i8.p0v128i8(<128 x i8> [[V1]], <128 x i8>* [[A0]], i32 4, <128 x i1> [[V0]])
 ; CHECK-NEXT:    ret <128 x i8> [[V1]]
 ;
   %v0 = icmp eq <128 x i8> %a1, %a2