diff --git a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
--- a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
+++ b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
@@ -683,29 +683,52 @@
   public:
     ParseMemoryInst(Instruction *Inst, const TargetTransformInfo &TTI)
       : Inst(Inst) {
-      if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst))
+      if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) {
+        IntrID = II->getIntrinsicID();
         if (TTI.getTgtMemIntrinsic(II, Info))
-          IsTargetMemInst = true;
+          return;
+        if (isHandledNonTargetIntrinsic(IntrID)) {
+          switch (IntrID) {
+          case Intrinsic::masked_load:
+            Info.PtrVal = Inst->getOperand(0);
+            Info.ReadMem = true;
+            Info.WriteMem = false;
+            Info.IsVolatile = false;
+            break;
+          case Intrinsic::masked_store:
+            Info.PtrVal = Inst->getOperand(1);
+            Info.ReadMem = false;
+            Info.WriteMem = true;
+            Info.IsVolatile = false;
+            break;
+          }
+        }
+      }
     }

+    Instruction *get() { return Inst; }
+    const Instruction *get() const { return Inst; }
+
     bool isLoad() const {
-      if (IsTargetMemInst) return Info.ReadMem;
+      if (IntrID != 0)
+        return Info.ReadMem;
       return isa<LoadInst>(Inst);
     }

     bool isStore() const {
-      if (IsTargetMemInst) return Info.WriteMem;
+      if (IntrID != 0)
+        return Info.WriteMem;
       return isa<StoreInst>(Inst);
     }

     bool isAtomic() const {
-      if (IsTargetMemInst)
+      if (IntrID != 0)
         return Info.Ordering != AtomicOrdering::NotAtomic;
       return Inst->isAtomic();
     }

     bool isUnordered() const {
-      if (IsTargetMemInst)
+      if (IntrID != 0)
         return Info.isUnordered();

       if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) {
@@ -718,7 +741,7 @@
     }

     bool isVolatile() const {
-      if (IsTargetMemInst)
+      if (IntrID != 0)
         return Info.IsVolatile;

       if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) {
@@ -748,44 +771,80 @@
     // field in the MemIntrinsicInfo structure. That field contains
     // non-negative values only.
     int getMatchingId() const {
-      if (IsTargetMemInst) return Info.MatchingId;
+      if (IntrID != 0)
+        return Info.MatchingId;
       return -1;
     }

     Value *getPointerOperand() const {
-      if (IsTargetMemInst) return Info.PtrVal;
+      if (IntrID != 0)
+        return Info.PtrVal;
       return getLoadStorePointerOperand(Inst);
     }

     bool mayReadFromMemory() const {
-      if (IsTargetMemInst) return Info.ReadMem;
+      if (IntrID != 0)
+        return Info.ReadMem;
       return Inst->mayReadFromMemory();
     }

     bool mayWriteToMemory() const {
-      if (IsTargetMemInst) return Info.WriteMem;
+      if (IntrID != 0)
+        return Info.WriteMem;
       return Inst->mayWriteToMemory();
     }

   private:
-    bool IsTargetMemInst = false;
+    Intrinsic::ID IntrID = 0;
     MemIntrinsicInfo Info;
     Instruction *Inst;
   };

+  // This function is to prevent accidentally passing a non-target
+  // intrinsic ID to TargetTransformInfo.
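+  // Currently the only intrinsics handled this way are llvm.masked.load and
+  // llvm.masked.store; the ParseMemoryInst constructor above and
+  // getMatchingValue() below rely on this predicate to tell when EarlyCSE
+  // itself (rather than TTI) understands the memory intrinsic.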
+  static bool isHandledNonTargetIntrinsic(Intrinsic::ID ID) {
+    switch (ID) {
+    case Intrinsic::masked_load:
+    case Intrinsic::masked_store:
+      return true;
+    }
+    return false;
+  }
+  static bool isHandledNonTargetIntrinsic(Value *V) {
+    if (auto *II = dyn_cast<IntrinsicInst>(V))
+      return isHandledNonTargetIntrinsic(II->getIntrinsicID());
+    return false;
+  }
+
   bool processNode(DomTreeNode *Node);

   bool handleBranchCondition(Instruction *CondInst, const BranchInst *BI,
                              const BasicBlock *BB, const BasicBlock *Pred);

+  Value *getMatchingValue(LoadValue &InVal, ParseMemoryInst &MemInst,
+                          unsigned CurrentGeneration, bool InValFirst);
+
   Value *getOrCreateResult(Value *Inst, Type *ExpectedType) const {
     if (auto *LI = dyn_cast<LoadInst>(Inst))
       return LI;
     if (auto *SI = dyn_cast<StoreInst>(Inst))
       return SI->getValueOperand();
     assert(isa<IntrinsicInst>(Inst) && "Instruction not supported");
-    return TTI.getOrCreateResultFromMemIntrinsic(cast<IntrinsicInst>(Inst),
-                                                 ExpectedType);
+    auto *II = cast<IntrinsicInst>(Inst);
+    if (isHandledNonTargetIntrinsic(II->getIntrinsicID()))
+      return getOrCreateResultNonTargetMemIntrinsic(II, ExpectedType);
+    return TTI.getOrCreateResultFromMemIntrinsic(II, ExpectedType);
+  }
+
+  Value *getOrCreateResultNonTargetMemIntrinsic(IntrinsicInst *II,
+                                                Type *ExpectedType) const {
+    switch (II->getIntrinsicID()) {
+    case Intrinsic::masked_load:
+      return II;
+    case Intrinsic::masked_store:
+      return II->getOperand(0);
+    }
+    return nullptr;
   }

   /// Return true if the instruction is known to only operate on memory
@@ -795,6 +854,87 @@
   bool isSameMemGeneration(unsigned EarlierGeneration, unsigned LaterGeneration,
                            Instruction *EarlierInst, Instruction *LaterInst);

+  bool isNonTargetIntrinsicMatch(IntrinsicInst *Earlier, IntrinsicInst *Later) {
+    auto IsSubmask = [](Value *Mask0, Value *Mask1) {
+      // Is Mask0 a submask of Mask1?
+      if (Mask0 == Mask1)
+        return true;
+      assert(!isa<UndefValue>(Mask0) && !isa<UndefValue>(Mask1));
+      auto *Vec0 = dyn_cast<ConstantVector>(Mask0);
+      auto *Vec1 = dyn_cast<ConstantVector>(Mask1);
+      if (!Vec0 || !Vec1 || Vec0->getType() != Vec1->getType())
+        return false;
+      for (int i = 0, e = Vec0->getNumOperands(); i != e; ++i) {
+        bool M0 = cast<ConstantInt>(Vec0->getOperand(i))->getZExtValue();
+        bool M1 = cast<ConstantInt>(Vec1->getOperand(i))->getZExtValue();
+        if (M0 && !M1)
+          return false;
+      }
+      return true;
+    };
+    auto PtrOp = [](IntrinsicInst *II) {
+      if (II->getIntrinsicID() == Intrinsic::masked_load)
+        return II->getOperand(0);
+      if (II->getIntrinsicID() == Intrinsic::masked_store)
+        return II->getOperand(1);
+      llvm_unreachable("Unexpected IntrinsicInst");
+    };
+    auto MaskOp = [](IntrinsicInst *II) {
+      if (II->getIntrinsicID() == Intrinsic::masked_load)
+        return II->getOperand(2);
+      if (II->getIntrinsicID() == Intrinsic::masked_store)
+        return II->getOperand(3);
+      llvm_unreachable("Unexpected IntrinsicInst");
+    };
+    auto ThruOp = [](IntrinsicInst *II) {
+      if (II->getIntrinsicID() == Intrinsic::masked_load)
+        return II->getOperand(3);
+      llvm_unreachable("Unexpected IntrinsicInst");
+    };
+
+    if (PtrOp(Earlier) != PtrOp(Later))
+      return false;
+
+    Intrinsic::ID IDE = Earlier->getIntrinsicID();
+    Intrinsic::ID IDL = Later->getIntrinsicID();
+    // We could really use specific intrinsic classes for masked loads
+    // and stores in IntrinsicInst.h.
+    if (IDE == Intrinsic::masked_load && IDL == Intrinsic::masked_load) {
+      // Trying to replace later masked load with the earlier one.
+      // Check that the pointers are the same, and
+      // - masks and pass-throughs are the same, or
+      // - replacee's pass-through is "undef" and replacer's mask is a
+      //   super-set of the replacee's mask.
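+      // A hypothetical illustration of the second case (the names, types and
+      // masks here are made up for the example):
+      //   %e = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %p,
+      //            i32 4, <4 x i1> <i1 1, i1 1, i1 1, i1 0>, <4 x i32> undef)
+      //   %l = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %p,
+      //            i32 4, <4 x i1> <i1 1, i1 0, i1 0, i1 0>, <4 x i32> undef)
+      // %l (the replacee) can be replaced with %e: %l's pass-through is undef
+      // and %l's mask is a submask of %e's mask.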
+      if (MaskOp(Earlier) == MaskOp(Later) && ThruOp(Earlier) == ThruOp(Later))
+        return true;
+      if (!isa<UndefValue>(ThruOp(Later)))
+        return false;
+      return IsSubmask(MaskOp(Later), MaskOp(Earlier));
+    }
+    if (IDE == Intrinsic::masked_store && IDL == Intrinsic::masked_load) {
+      // Trying to replace a load of a stored value with the store's value.
+      // Check that the pointers are the same, and
+      // - load's mask is a subset of store's mask, and
+      // - load's pass-through is "undef".
+      if (!IsSubmask(MaskOp(Later), MaskOp(Earlier)))
+        return false;
+      return isa<UndefValue>(ThruOp(Later));
+    }
+    if (IDE == Intrinsic::masked_load && IDL == Intrinsic::masked_store) {
+      // Trying to remove a store of the loaded value.
+      // Check that the pointers are the same, and
+      // - store's mask is a subset of the load's mask.
+      return IsSubmask(MaskOp(Later), MaskOp(Earlier));
+    }
+    if (IDE == Intrinsic::masked_store && IDL == Intrinsic::masked_store) {
+      // Trying to remove a dead store (Earlier).
+      // Check that the pointers are the same, and
+      // - the earlier store's mask is a subset of the later store's mask.
+      return IsSubmask(MaskOp(Earlier), MaskOp(Later));
+    }
+    return false;
+  }
+
   void removeMSSA(Instruction &Inst) {
     if (!MSSA)
       return;
@@ -940,6 +1080,41 @@
   return MadeChanges;
 }

+Value *EarlyCSE::getMatchingValue(LoadValue &InVal, ParseMemoryInst &MemInst,
+                                  unsigned CurrentGeneration, bool InValFirst) {
+  // InValFirst is false for loads, true for stores.
+
+  if (InVal.DefInst == nullptr)
+    return nullptr;
+  if (InVal.MatchingId != MemInst.getMatchingId())
+    return nullptr;
+  // We don't yet handle removing loads with ordering of any kind.
+  if (MemInst.isVolatile() || !MemInst.isUnordered())
+    return nullptr;
+  // We can't replace an atomic load with one which isn't also atomic.
+  if (MemInst.isLoad() && !InVal.IsAtomic && MemInst.isAtomic())
+    return nullptr;
+  Instruction *Earlier = InValFirst ? InVal.DefInst : MemInst.get();
+  Instruction *Later = InValFirst ? MemInst.get() : InVal.DefInst;
+
+  // Deal with non-target memory intrinsics.
+  bool EarlierNTI = isHandledNonTargetIntrinsic(Earlier);
+  bool LaterNTI = isHandledNonTargetIntrinsic(Later);
+  if (EarlierNTI != LaterNTI)
+    return nullptr;
+  if (EarlierNTI && LaterNTI) {
+    if (!isNonTargetIntrinsicMatch(cast<IntrinsicInst>(Earlier),
+                                   cast<IntrinsicInst>(Later)))
+      return nullptr;
+  }
+
+  if (!isOperatingOnInvariantMemAt(MemInst.get(), InVal.Generation) &&
+      !isSameMemGeneration(InVal.Generation, CurrentGeneration, InVal.DefInst,
+                           MemInst.get()))
+    return nullptr;
+  return getOrCreateResult(Later, Earlier->getType());
+}
+
 bool EarlyCSE::processNode(DomTreeNode *Node) {
   bool Changed = false;
   BasicBlock *BB = Node->getBlock();
@@ -1156,32 +1331,22 @@
       // we can assume the current load loads the same value as the dominating
       // load.
       LoadValue InVal = AvailableLoads.lookup(MemInst.getPointerOperand());
-      if (InVal.DefInst != nullptr &&
-          InVal.MatchingId == MemInst.getMatchingId() &&
-          // We don't yet handle removing loads with ordering of any kind.
-          !MemInst.isVolatile() && MemInst.isUnordered() &&
-          // We can't replace an atomic load with one which isn't also atomic.
-          InVal.IsAtomic >= MemInst.isAtomic() &&
-          (isOperatingOnInvariantMemAt(&Inst, InVal.Generation) ||
-           isSameMemGeneration(InVal.Generation, CurrentGeneration,
-                               InVal.DefInst, &Inst))) {
-        Value *Op = getOrCreateResult(InVal.DefInst, Inst.getType());
-        if (Op != nullptr) {
-          LLVM_DEBUG(dbgs() << "EarlyCSE CSE LOAD: " << Inst
-                            << " to: " << *InVal.DefInst << '\n');
-          if (!DebugCounter::shouldExecute(CSECounter)) {
-            LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n");
-            continue;
-          }
-          if (!Inst.use_empty())
-            Inst.replaceAllUsesWith(Op);
-          salvageKnowledge(&Inst, &AC);
-          removeMSSA(Inst);
-          Inst.eraseFromParent();
-          Changed = true;
-          ++NumCSELoad;
+      if (Value *Op = getMatchingValue(InVal, MemInst, CurrentGeneration,
+                                       false)) {
+        LLVM_DEBUG(dbgs() << "EarlyCSE CSE LOAD: " << Inst
+                          << " to: " << *InVal.DefInst << '\n');
+        if (!DebugCounter::shouldExecute(CSECounter)) {
+          LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n");
           continue;
         }
+        if (!Inst.use_empty())
+          Inst.replaceAllUsesWith(Op);
+        salvageKnowledge(&Inst, &AC);
+        removeMSSA(Inst);
+        Inst.eraseFromParent();
+        Changed = true;
+        ++NumCSELoad;
+        continue;
       }

       // Otherwise, remember that we have this instruction.
@@ -1251,13 +1416,8 @@
     if (MemInst.isValid() && MemInst.isStore()) {
       LoadValue InVal = AvailableLoads.lookup(MemInst.getPointerOperand());
       if (InVal.DefInst &&
-          InVal.DefInst == getOrCreateResult(&Inst, InVal.DefInst->getType()) &&
-          InVal.MatchingId == MemInst.getMatchingId() &&
-          // We don't yet handle removing stores with ordering of any kind.
-          !MemInst.isVolatile() && MemInst.isUnordered() &&
-          (isOperatingOnInvariantMemAt(&Inst, InVal.Generation) ||
-           isSameMemGeneration(InVal.Generation, CurrentGeneration,
-                               InVal.DefInst, &Inst))) {
+          InVal.DefInst == getMatchingValue(InVal, MemInst, CurrentGeneration,
+                                            true)) {
         // It is okay to have a LastStore to a different pointer here if MemorySSA
         // tells us that the load and store are from the same memory generation.
         // In that case, LastStore should keep its present value since we're
diff --git a/llvm/test/Transforms/EarlyCSE/masked-intrinsics.ll b/llvm/test/Transforms/EarlyCSE/masked-intrinsics.ll
--- a/llvm/test/Transforms/EarlyCSE/masked-intrinsics.ll
+++ b/llvm/test/Transforms/EarlyCSE/masked-intrinsics.ll
@@ -5,8 +5,7 @@
 ; CHECK-LABEL: @f0(
 ; CHECK-NEXT:    [[V0:%.*]] = icmp eq <128 x i8> [[A1:%.*]], [[A2:%.*]]
 ; CHECK-NEXT:    call void @llvm.masked.store.v128i8.p0v128i8(<128 x i8> [[A1]], <128 x i8>* [[A0:%.*]], i32 4, <128 x i1> [[V0]])
-; CHECK-NEXT:    [[V1:%.*]] = call <128 x i8> @llvm.masked.load.v128i8.p0v128i8(<128 x i8>* [[A0]], i32 4, <128 x i1> [[V0]], <128 x i8> undef)
-; CHECK-NEXT:    ret <128 x i8> [[V1]]
+; CHECK-NEXT:    ret <128 x i8> [[A1]]
 ;
   %v0 = icmp eq <128 x i8> %a1, %a2
   call void @llvm.masked.store.v128i8.p0v128i8(<128 x i8> %a1, <128 x i8>* %a0, i32 4, <128 x i1> %v0)
@@ -18,7 +17,6 @@
 ; CHECK-LABEL: @f1(
 ; CHECK-NEXT:    [[V0:%.*]] = icmp eq <128 x i8> [[A1:%.*]], [[A2:%.*]]
 ; CHECK-NEXT:    [[V1:%.*]] = call <128 x i8> @llvm.masked.load.v128i8.p0v128i8(<128 x i8>* [[A0:%.*]], i32 4, <128 x i1> [[V0]], <128 x i8> undef)
-; CHECK-NEXT:    call void @llvm.masked.store.v128i8.p0v128i8(<128 x i8> [[V1]], <128 x i8>* [[A0]], i32 4, <128 x i1> [[V0]])
 ; CHECK-NEXT:    ret <128 x i8> [[V1]]
 ;
   %v0 = icmp eq <128 x i8> %a1, %a2