diff --git a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
--- a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
+++ b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
@@ -689,8 +689,33 @@
     ParseMemoryInst(Instruction *Inst, const TargetTransformInfo &TTI)
       : Inst(Inst) {
       if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) {
+        IntrID = II->getIntrinsicID();
         if (TTI.getTgtMemIntrinsic(II, Info))
-          IntrID = II->getIntrinsicID();
+          return;
+        if (isHandledNonTargetIntrinsic(IntrID)) {
+          switch (IntrID) {
+          case Intrinsic::masked_load:
+            Info.PtrVal = Inst->getOperand(0);
+            Info.MatchingId = Intrinsic::masked_load;
+            Info.ReadMem = true;
+            Info.WriteMem = false;
+            Info.IsVolatile = false;
+            break;
+          case Intrinsic::masked_store:
+            Info.PtrVal = Inst->getOperand(1);
+            // Use the ID of masked load as the "matching id". This will
+            // prevent matching non-masked loads/stores with masked ones
+            // (which could be done), but at the moment, the code here
+            // does not support matching intrinsics with non-intrinsics,
+            // so keep the MatchingIds specific to masked instructions
+            // for now (TODO).
+            Info.MatchingId = Intrinsic::masked_load;
+            Info.ReadMem = false;
+            Info.WriteMem = true;
+            Info.IsVolatile = false;
+            break;
+          }
+        }
       }
     }
 
@@ -747,11 +772,6 @@
       return false;
     }
 
-    bool isMatchingMemLoc(const ParseMemoryInst &Inst) const {
-      return (getPointerOperand() == Inst.getPointerOperand() &&
-              getMatchingId() == Inst.getMatchingId());
-    }
-
     bool isValid() const { return getPointerOperand() != nullptr; }
 
     // For regular (non-intrinsic) loads/stores, this is set to -1. For
@@ -788,6 +808,22 @@
     Instruction *Inst;
   };
 
+  // This function is to prevent accidentally passing a non-target
+  // intrinsic ID to TargetTransformInfo.
+  static bool isHandledNonTargetIntrinsic(Intrinsic::ID ID) {
+    switch (ID) {
+    case Intrinsic::masked_load:
+    case Intrinsic::masked_store:
+      return true;
+    }
+    return false;
+  }
+  static bool isHandledNonTargetIntrinsic(const Value *V) {
+    if (auto *II = dyn_cast<IntrinsicInst>(V))
+      return isHandledNonTargetIntrinsic(II->getIntrinsicID());
+    return false;
+  }
+
   bool processNode(DomTreeNode *Node);
 
   bool handleBranchCondition(Instruction *CondInst, const BranchInst *BI,
@@ -796,14 +832,30 @@
   Value *getMatchingValue(LoadValue &InVal, ParseMemoryInst &MemInst,
                           unsigned CurrentGeneration);
 
+  bool overridingStores(const ParseMemoryInst &Earlier,
+                        const ParseMemoryInst &Later);
+
   Value *getOrCreateResult(Value *Inst, Type *ExpectedType) const {
     if (auto *LI = dyn_cast<LoadInst>(Inst))
      return LI;
     if (auto *SI = dyn_cast<StoreInst>(Inst))
       return SI->getValueOperand();
     assert(isa<IntrinsicInst>(Inst) && "Instruction not supported");
-    return TTI.getOrCreateResultFromMemIntrinsic(cast<IntrinsicInst>(Inst),
-                                                 ExpectedType);
+    auto *II = cast<IntrinsicInst>(Inst);
+    if (isHandledNonTargetIntrinsic(II->getIntrinsicID()))
+      return getOrCreateResultNonTargetMemIntrinsic(II, ExpectedType);
+    return TTI.getOrCreateResultFromMemIntrinsic(II, ExpectedType);
+  }
+
+  Value *getOrCreateResultNonTargetMemIntrinsic(IntrinsicInst *II,
+                                                Type *ExpectedType) const {
+    switch (II->getIntrinsicID()) {
+    case Intrinsic::masked_load:
+      return II;
+    case Intrinsic::masked_store:
+      return II->getOperand(0);
+    }
+    return nullptr;
   }
 
   /// Return true if the instruction is known to only operate on memory
@@ -813,6 +865,101 @@
   bool isSameMemGeneration(unsigned EarlierGeneration, unsigned LaterGeneration,
                           Instruction *EarlierInst, Instruction *LaterInst);
 
+  bool isNonTargetIntrinsicMatch(const IntrinsicInst *Earlier,
+                                 const IntrinsicInst *Later) {
+    auto IsSubmask = [](const Value *Mask0, const Value *Mask1) {
+      // Is Mask0 a submask of Mask1?
+      if (Mask0 == Mask1)
+        return true;
+      if (isa<UndefValue>(Mask0) || isa<UndefValue>(Mask1))
+        return false;
+      auto *Vec0 = dyn_cast<ConstantVector>(Mask0);
+      auto *Vec1 = dyn_cast<ConstantVector>(Mask1);
+      if (!Vec0 || !Vec1)
+        return false;
+      assert(Vec0->getType() == Vec1->getType() &&
+             "Masks should have the same type");
+      for (int i = 0, e = Vec0->getNumOperands(); i != e; ++i) {
+        Constant *Elem0 = Vec0->getOperand(i);
+        Constant *Elem1 = Vec1->getOperand(i);
+        auto *Int0 = dyn_cast<ConstantInt>(Elem0);
+        if (Int0 && Int0->isZero())
+          continue;
+        auto *Int1 = dyn_cast<ConstantInt>(Elem1);
+        if (Int1 && !Int1->isZero())
+          continue;
+        if (isa<UndefValue>(Elem0) || isa<UndefValue>(Elem1))
+          return false;
+        if (Elem0 == Elem1)
+          continue;
+        return false;
+      }
+      return true;
+    };
+    auto PtrOp = [](const IntrinsicInst *II) {
+      if (II->getIntrinsicID() == Intrinsic::masked_load)
+        return II->getOperand(0);
+      if (II->getIntrinsicID() == Intrinsic::masked_store)
+        return II->getOperand(1);
+      llvm_unreachable("Unexpected IntrinsicInst");
+    };
+    auto MaskOp = [](const IntrinsicInst *II) {
+      if (II->getIntrinsicID() == Intrinsic::masked_load)
+        return II->getOperand(2);
+      if (II->getIntrinsicID() == Intrinsic::masked_store)
+        return II->getOperand(3);
+      llvm_unreachable("Unexpected IntrinsicInst");
+    };
+    auto ThruOp = [](const IntrinsicInst *II) {
+      if (II->getIntrinsicID() == Intrinsic::masked_load)
+        return II->getOperand(3);
+      llvm_unreachable("Unexpected IntrinsicInst");
+    };
+
+    if (PtrOp(Earlier) != PtrOp(Later))
+      return false;
+
+    Intrinsic::ID IDE = Earlier->getIntrinsicID();
+    Intrinsic::ID IDL = Later->getIntrinsicID();
+    // We could really use specific intrinsic classes for masked loads
+    // and stores in IntrinsicInst.h.
+    if (IDE == Intrinsic::masked_load && IDL == Intrinsic::masked_load) {
+      // Trying to replace later masked load with the earlier one.
+      // Check that the pointers are the same, and
+      // - masks and pass-throughs are the same, or
+      // - replacee's pass-through is "undef" and replacer's mask is a
+      //   super-set of the replacee's mask.
+      if (MaskOp(Earlier) == MaskOp(Later) && ThruOp(Earlier) == ThruOp(Later))
+        return true;
+      if (!isa<UndefValue>(ThruOp(Later)))
+        return false;
+      return IsSubmask(MaskOp(Later), MaskOp(Earlier));
+    }
+    if (IDE == Intrinsic::masked_store && IDL == Intrinsic::masked_load) {
+      // Trying to replace a load of a stored value with the store's value.
+      // Check that the pointers are the same, and
+      // - load's mask is a subset of store's mask, and
+      // - load's pass-through is "undef".
+      if (!IsSubmask(MaskOp(Later), MaskOp(Earlier)))
+        return false;
+      return isa<UndefValue>(ThruOp(Later));
+    }
+    if (IDE == Intrinsic::masked_load && IDL == Intrinsic::masked_store) {
+      // Trying to remove a store of the loaded value.
+      // Check that the pointers are the same, and
+      // - store's mask is a subset of the load's mask.
+      return IsSubmask(MaskOp(Later), MaskOp(Earlier));
+    }
+    if (IDE == Intrinsic::masked_store && IDL == Intrinsic::masked_store) {
+      // Trying to remove a dead store (earlier).
+      // Check that the pointers are the same,
+      // - the to-be-removed store's mask is a subset of the other store's
+      //   mask.
+      return IsSubmask(MaskOp(Earlier), MaskOp(Later));
+    }
+    return false;
+  }
+
   void removeMSSA(Instruction &Inst) {
     if (!MSSA)
       return;
@@ -978,6 +1125,17 @@
   Instruction *Matching = MemInstMatching ? MemInst.get() : InVal.DefInst;
   Instruction *Other = MemInstMatching ? InVal.DefInst : MemInst.get();
 
+  // Deal with non-target memory intrinsics.
+  bool MatchingNTI = isHandledNonTargetIntrinsic(Matching);
+  bool OtherNTI = isHandledNonTargetIntrinsic(Other);
+  if (OtherNTI != MatchingNTI)
+    return nullptr;
+  if (OtherNTI && MatchingNTI) {
+    if (!isNonTargetIntrinsicMatch(cast<IntrinsicInst>(InVal.DefInst),
+                                   cast<IntrinsicInst>(MemInst.get())))
+      return nullptr;
+  }
+
   if (!isOperatingOnInvariantMemAt(MemInst.get(), InVal.Generation) &&
       !isSameMemGeneration(InVal.Generation, CurrentGeneration, InVal.DefInst,
                            MemInst.get()))
@@ -985,6 +1143,37 @@
   return getOrCreateResult(Matching, Other->getType());
 }
 
+bool EarlyCSE::overridingStores(const ParseMemoryInst &Earlier,
+                                const ParseMemoryInst &Later) {
+  // Can we remove Earlier store because of Later store?
+
+  assert(Earlier.isUnordered() && !Earlier.isVolatile() &&
+         "Violated invariant");
+  if (Earlier.getPointerOperand() != Later.getPointerOperand())
+    return false;
+  if (Earlier.getMatchingId() != Later.getMatchingId())
+    return false;
+  // At the moment, we don't remove ordered stores, but do remove
+  // unordered atomic stores. There's no special requirement (for
+  // unordered atomics) about removing atomic stores only in favor of
+  // other atomic stores since we were going to execute the non-atomic
+  // one anyway and the atomic one might never have become visible.
+  if (!Earlier.isUnordered() || !Later.isUnordered())
+    return false;
+
+  // Deal with non-target memory intrinsics.
+  bool ENTI = isHandledNonTargetIntrinsic(Earlier.get());
+  bool LNTI = isHandledNonTargetIntrinsic(Later.get());
+  if (ENTI && LNTI)
+    return isNonTargetIntrinsicMatch(cast<IntrinsicInst>(Earlier.get()),
+                                     cast<IntrinsicInst>(Later.get()));
+
+  // Because of the check above, at least one of them is false.
+  // For now disallow matching intrinsics with non-intrinsics,
+  // so assume that the stores match if neither is an intrinsic.
+  return ENTI == LNTI;
+}
+
 bool EarlyCSE::processNode(DomTreeNode *Node) {
   bool Changed = false;
   BasicBlock *BB = Node->getBlock();
@@ -1320,17 +1509,8 @@
     if (MemInst.isValid() && MemInst.isStore()) {
       // We do a trivial form of DSE if there are two stores to the same
       // location with no intervening loads. Delete the earlier store.
-      // At the moment, we don't remove ordered stores, but do remove
-      // unordered atomic stores. There's no special requirement (for
-      // unordered atomics) about removing atomic stores only in favor of
-      // other atomic stores since we were going to execute the non-atomic
-      // one anyway and the atomic one might never have become visible.
       if (LastStore) {
-        ParseMemoryInst LastStoreMemInst(LastStore, TTI);
-        assert(LastStoreMemInst.isUnordered() &&
-               !LastStoreMemInst.isVolatile() &&
-               "Violated invariant");
-        if (LastStoreMemInst.isMatchingMemLoc(MemInst)) {
+        if (overridingStores(ParseMemoryInst(LastStore, TTI), MemInst)) {
           LLVM_DEBUG(dbgs() << "EarlyCSE DEAD STORE: " << *LastStore
                             << "  due to: " << Inst << '\n');
           if (!DebugCounter::shouldExecute(CSECounter)) {
diff --git a/llvm/test/Transforms/EarlyCSE/masked-intrinsics-unequal-masks.ll b/llvm/test/Transforms/EarlyCSE/masked-intrinsics-unequal-masks.ll
--- a/llvm/test/Transforms/EarlyCSE/masked-intrinsics-unequal-masks.ll
+++ b/llvm/test/Transforms/EarlyCSE/masked-intrinsics-unequal-masks.ll
@@ -13,8 +13,7 @@
 define <4 x i32> @f3(<4 x i32>* %a0, <4 x i32> %a1) {
 ; CHECK-LABEL: @f3(
 ; CHECK-NEXT:    [[V0:%.*]] = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* [[A0:%.*]], i32 4, <4 x i1> , <4 x i32> [[A1:%.*]])
-; CHECK-NEXT:    [[V1:%.*]] = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* [[A0]], i32 4, <4 x i1> , <4 x i32> undef)
-; CHECK-NEXT:    [[V2:%.*]] = add <4 x i32> [[V0]], [[V1]]
+; CHECK-NEXT:    [[V2:%.*]] = add <4 x i32> [[V0]], [[V0]]
 ; CHECK-NEXT:    ret <4 x i32> [[V2]]
 ;
   %v0 = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %a0, i32 4, <4 x i1> , <4 x i32> %a1)
@@ -60,8 +59,7 @@
 ; Expect the first store to be removed.
 define void @f6(<4 x i32> %a0, <4 x i32>* %a1) {
 ; CHECK-LABEL: @f6(
-; CHECK-NEXT:    call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> [[A0:%.*]], <4 x i32>* [[A1:%.*]], i32 4, <4 x i1> )
-; CHECK-NEXT:    call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> [[A0]], <4 x i32>* [[A1]], i32 4, <4 x i1> )
+; CHECK-NEXT:    call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> [[A0:%.*]], <4 x i32>* [[A1:%.*]], i32 4, <4 x i1> )
 ; CHECK-NEXT:    ret void
 ;
   call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> %a0, <4 x i32>* %a1, i32 4, <4 x i1> )
@@ -90,7 +88,6 @@
 define <4 x i32> @f8(<4 x i32>* %a0, <4 x i32> %a1) {
 ; CHECK-LABEL: @f8(
 ; CHECK-NEXT:    [[V0:%.*]] = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* [[A0:%.*]], i32 4, <4 x i1> , <4 x i32> [[A1:%.*]])
-; CHECK-NEXT:    call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> [[V0]], <4 x i32>* [[A0]], i32 4, <4 x i1> )
 ; CHECK-NEXT:    ret <4 x i32> [[V0]]
 ;
   %v0 = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %a0, i32 4, <4 x i1> , <4 x i32> %a1)
@@ -119,8 +116,7 @@
 define <4 x i32> @fa(<4 x i32> %a0, <4 x i32>* %a1) {
 ; CHECK-LABEL: @fa(
 ; CHECK-NEXT:    call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> [[A0:%.*]], <4 x i32>* [[A1:%.*]], i32 4, <4 x i1> )
-; CHECK-NEXT:    [[V0:%.*]] = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* [[A1]], i32 4, <4 x i1> , <4 x i32> undef)
-; CHECK-NEXT:    ret <4 x i32> [[V0]]
+; CHECK-NEXT:    ret <4 x i32> [[A0]]
 ;
   call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> %a0, <4 x i32>* %a1, i32 4, <4 x i1> )
   %v0 = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %a1, i32 4, <4 x i1> , <4 x i32> undef)
diff --git a/llvm/test/Transforms/EarlyCSE/masked-intrinsics.ll b/llvm/test/Transforms/EarlyCSE/masked-intrinsics.ll
--- a/llvm/test/Transforms/EarlyCSE/masked-intrinsics.ll
+++ b/llvm/test/Transforms/EarlyCSE/masked-intrinsics.ll
@@ -5,8 +5,7 @@
 ; CHECK-LABEL: @f0(
 ; CHECK-NEXT:    [[V0:%.*]] = icmp eq <128 x i8> [[A1:%.*]], [[A2:%.*]]
 ; CHECK-NEXT:    call void @llvm.masked.store.v128i8.p0v128i8(<128 x i8> [[A1]], <128 x i8>* [[A0:%.*]], i32 4, <128 x i1> [[V0]])
-; CHECK-NEXT:    [[V1:%.*]] = call <128 x i8> @llvm.masked.load.v128i8.p0v128i8(<128 x i8>* [[A0]], i32 4, <128 x i1> [[V0]], <128 x i8> undef)
-; CHECK-NEXT:    ret <128 x i8> [[V1]]
+; CHECK-NEXT:    ret <128 x i8> [[A1]]
 ;
   %v0 = icmp eq <128 x i8> %a1, %a2
   call void @llvm.masked.store.v128i8.p0v128i8(<128 x i8> %a1, <128 x i8>* %a0, i32 4, <128 x i1> %v0)
@@ -18,7 +17,6 @@
 ; CHECK-LABEL: @f1(
 ; CHECK-NEXT:    [[V0:%.*]] = icmp eq <128 x i8> [[A1:%.*]], [[A2:%.*]]
 ; CHECK-NEXT:    [[V1:%.*]] = call <128 x i8> @llvm.masked.load.v128i8.p0v128i8(<128 x i8>* [[A0:%.*]], i32 4, <128 x i1> [[V0]], <128 x i8> undef)
-; CHECK-NEXT:    call void @llvm.masked.store.v128i8.p0v128i8(<128 x i8> [[V1]], <128 x i8>* [[A0]], i32 4, <128 x i1> [[V0]])
 ; CHECK-NEXT:    ret <128 x i8> [[V1]]
 ;
   %v0 = icmp eq <128 x i8> %a1, %a2
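
Illustrative sketch (not taken from the patch or its test files; the function name @submask_forward_example and the specific mask constants are made up): the store-to-load forwarding case handled by isNonTargetIntrinsicMatch applies when both calls use the same pointer, the load's mask is a submask of the store's mask, and the load's pass-through is undef, in which case EarlyCSE may forward the stored value and drop the load.

declare void @llvm.masked.store.v4i32.p0v4i32(<4 x i32>, <4 x i32>*, i32, <4 x i1>)
declare <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>*, i32, <4 x i1>, <4 x i32>)

define <4 x i32> @submask_forward_example(<4 x i32>* %p, <4 x i32> %val) {
  ; Store all four lanes of %val.
  call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> %val, <4 x i32>* %p, i32 4, <4 x i1> <i1 true, i1 true, i1 true, i1 true>)
  ; Reload only lanes 0 and 2 with an undef pass-through: the load's mask is a
  ; submask of the store's mask, so %v can be replaced by %val.
  %v = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %p, i32 4, <4 x i1> <i1 true, i1 false, i1 true, i1 false>, <4 x i32> undef)
  ret <4 x i32> %v
}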