diff --git a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
--- a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
@@ -234,6 +234,7 @@
     case Intrinsic::memset_element_unordered_atomic:
     case Intrinsic::init_trampoline:
     case Intrinsic::lifetime_end:
+    case Intrinsic::masked_store:
       return true;
     }
   }
@@ -257,8 +258,8 @@
 /// Return a Location stored to by the specified instruction. If isRemovable
 /// returns true, this function and getLocForRead completely describe the memory
 /// operations for this instruction.
-static MemoryLocation getLocForWrite(Instruction *Inst) {
-
+static MemoryLocation getLocForWrite(Instruction *Inst,
+                                     const TargetLibraryInfo &TLI) {
   if (StoreInst *SI = dyn_cast<StoreInst>(Inst))
     return MemoryLocation::get(SI);

@@ -274,6 +275,8 @@
       return MemoryLocation(); // Unhandled intrinsic.
     case Intrinsic::init_trampoline:
       return MemoryLocation(II->getArgOperand(0));
+    case Intrinsic::masked_store:
+      return MemoryLocation::getForArgument(II, 1, TLI);
     case Intrinsic::lifetime_end: {
       uint64_t Len = cast<ConstantInt>(II->getArgOperand(0))->getZExtValue();
       return MemoryLocation(II->getArgOperand(1), Len);
@@ -325,6 +328,7 @@
     case Intrinsic::memcpy_element_unordered_atomic:
     case Intrinsic::memmove_element_unordered_atomic:
     case Intrinsic::memset_element_unordered_atomic:
+    case Intrinsic::masked_store:
       return true;
     }
   }
@@ -370,9 +374,10 @@
 }

 /// Return the pointer that is being written to.
-static Value *getStoredPointerOperand(Instruction *I) {
+static Value *getStoredPointerOperand(Instruction *I,
+                                      const TargetLibraryInfo &TLI) {
   //TODO: factor this to reuse getLocForWrite
-  MemoryLocation Loc = getLocForWrite(I);
+  MemoryLocation Loc = getLocForWrite(I, TLI);
   assert(Loc.Ptr &&
          "unable to find pointer written for analyzable instruction?");
   // TODO: most APIs don't expect const Value *
@@ -487,6 +492,24 @@
   return OW_MaybePartial;
 }

+static OverwriteResult isMaskedStoreOverwrite(Instruction *Later,
+                                              Instruction *Earlier) {
+  auto *IIL = dyn_cast<IntrinsicInst>(Later);
+  auto *IIE = dyn_cast<IntrinsicInst>(Earlier);
+  if (IIL == nullptr || IIE == nullptr)
+    return OW_Unknown;
+  if (IIL->getIntrinsicID() != Intrinsic::masked_store ||
+      IIE->getIntrinsicID() != Intrinsic::masked_store)
+    return OW_Unknown;
+  // Pointers.
+  if (IIL->getArgOperand(1) != IIE->getArgOperand(1))
+    return OW_Unknown;
+  // Masks.
+  if (IIL->getArgOperand(3) != IIE->getArgOperand(3))
+    return OW_Unknown;
+  return OW_Complete;
+}
+
 /// Return 'OW_Complete' if a store to the 'Later' location completely
 /// overwrites a store to the 'Earlier' location, 'OW_End' if the end of the
 /// 'Earlier' location is completely overwritten by 'Later', 'OW_Begin' if the
@@ -796,7 +819,7 @@
       break;

     Value *DepPointer =
-        getUnderlyingObject(getStoredPointerOperand(Dependency));
+        getUnderlyingObject(getStoredPointerOperand(Dependency, *TLI));

     // Check for aliasing.
     if (!AA->isMustAlias(F->getArgOperand(0), DepPointer))
@@ -902,7 +925,7 @@
     if (hasAnalyzableMemoryWrite(&*BBI, *TLI) && isRemovable(&*BBI)) {
       // See through pointer-to-pointer bitcasts
       SmallVector<const Value *, 4> Pointers;
-      getUnderlyingObjects(getStoredPointerOperand(&*BBI), Pointers);
+      getUnderlyingObjects(getStoredPointerOperand(&*BBI, *TLI), Pointers);

       // Stores to stack values are valid candidates for removal.
       bool AllDead = true;
@@ -1119,11 +1142,12 @@
 }

 static bool removePartiallyOverlappedStores(const DataLayout &DL,
-                                            InstOverlapIntervalsTy &IOL) {
+                                            InstOverlapIntervalsTy &IOL,
+                                            const TargetLibraryInfo &TLI) {
   bool Changed = false;
   for (auto OI : IOL) {
     Instruction *EarlierWrite = OI.first;
-    MemoryLocation Loc = getLocForWrite(EarlierWrite);
+    MemoryLocation Loc = getLocForWrite(EarlierWrite, TLI);
     assert(isRemovable(EarlierWrite) && "Expect only removable instruction");

     const Value *Ptr = Loc.Ptr->stripPointerCasts();
@@ -1284,7 +1308,7 @@
       continue;

     // Figure out what location is being stored to.
-    MemoryLocation Loc = getLocForWrite(Inst);
+    MemoryLocation Loc = getLocForWrite(Inst, *TLI);

     // If we didn't get a useful location, fail.
     if (!Loc.Ptr)
@@ -1308,7 +1332,7 @@
       Instruction *DepWrite = InstDep.getInst();
       if (!hasAnalyzableMemoryWrite(DepWrite, *TLI))
         break;
-      MemoryLocation DepLoc = getLocForWrite(DepWrite);
+      MemoryLocation DepLoc = getLocForWrite(DepWrite, *TLI);
       // If we didn't get a useful location, or if it isn't a size, bail out.
       if (!DepLoc.Ptr)
         break;
@@ -1352,6 +1376,11 @@
         int64_t InstWriteOffset, DepWriteOffset;
         OverwriteResult OR = isOverwrite(Loc, DepLoc, DL, *TLI, DepWriteOffset,
                                          InstWriteOffset, *AA, BB.getParent());
+        if (OR == OW_Unknown) {
+          // isOverwrite punts on MemoryLocations with an imprecise size, such
+          // as masked stores. Handle this here, somewhat inelegantly.
+          OR = isMaskedStoreOverwrite(Inst, DepWrite);
+        }
         if (OR == OW_MaybePartial)
           OR = isPartialOverwrite(Loc, DepLoc, DepWriteOffset, InstWriteOffset,
                                   DepWrite, IOL);
@@ -1433,7 +1462,7 @@
   }

   if (EnablePartialOverwriteTracking)
-    MadeChange |= removePartiallyOverlappedStores(DL, IOL);
+    MadeChange |= removePartiallyOverlappedStores(DL, IOL, *TLI);

   // If this block ends in a return, unwind, or unreachable, all allocas are
   // dead at its end, which means stores to them are also dead.
@@ -2494,7 +2523,7 @@

   if (EnablePartialOverwriteTracking)
     for (auto &KV : State.IOLs)
-      MadeChange |= removePartiallyOverlappedStores(State.DL, KV.second);
+      MadeChange |= removePartiallyOverlappedStores(State.DL, KV.second, TLI);

   MadeChange |= State.eliminateDeadWritesAtEndOfFunction();
   return MadeChange;
diff --git a/llvm/test/Transforms/DeadStoreElimination/masked-dead-store.ll b/llvm/test/Transforms/DeadStoreElimination/masked-dead-store.ll
--- a/llvm/test/Transforms/DeadStoreElimination/masked-dead-store.ll
+++ b/llvm/test/Transforms/DeadStoreElimination/masked-dead-store.ll
@@ -9,26 +9,24 @@
 ; CHECK-NEXT:    [[V1:%.*]] = load i8*, i8** [[V0]], align 4, [[TBAA0:!tbaa !.*]]
 ; CHECK-NEXT:    [[V2:%.*]] = getelementptr i8, i8* [[V1]], i32 [[A3:%.*]]
 ; CHECK-NEXT:    [[V3:%.*]] = bitcast i8* [[V2]] to <128 x i8>*
-; CHECK-NEXT:    tail call void @llvm.masked.store.v128i8.p0v128i8(<128 x i8> , <128 x i8>* [[V3]], i32 32, <128 x i1> ), [[TBAA3:!tbaa !.*]]
 ; CHECK-NEXT:    [[V6:%.*]] = getelementptr inbounds i8*, i8** [[A1:%.*]], i32 [[A4:%.*]]
-; CHECK-NEXT:    [[V7:%.*]] = load i8*, i8** [[V6]], align 4, [[TBAA6:!tbaa !.*]]
+; CHECK-NEXT:    [[V7:%.*]] = load i8*, i8** [[V6]], align 4, [[TBAA3:!tbaa !.*]]
 ; CHECK-NEXT:    [[V8:%.*]] = getelementptr i8, i8* [[V7]], i32 [[A5:%.*]]
 ; CHECK-NEXT:    [[V9:%.*]] = bitcast i8* [[V8]] to <128 x i8>*
-; CHECK-NEXT:    [[V10:%.*]] = tail call <128 x i8> @llvm.masked.load.v128i8.p0v128i8(<128 x i8>* [[V9]], i32 32, <128 x i1> , <128 x i8> undef), [[TBAA8:!tbaa !.*]]
+; CHECK-NEXT:    [[V10:%.*]] = tail call <128 x i8> @llvm.masked.load.v128i8.p0v128i8(<128 x i8>* [[V9]], i32 32, <128 x i1> , <128 x i8> undef), [[TBAA5:!tbaa !.*]]
 ; CHECK-NEXT:    [[V11:%.*]] = shufflevector <128 x i8> [[V10]], <128 x i8> undef, <32 x i32> 
 ; CHECK-NEXT:    [[V14:%.*]] = shufflevector <32 x i8> [[V11]], <32 x i8> undef, <128 x i32> 
-; CHECK-NEXT:    tail call void @llvm.masked.store.v128i8.p0v128i8(<128 x i8> [[V14]], <128 x i8>* [[V3]], i32 32, <128 x i1> ), [[TBAA3]]
 ; CHECK-NEXT:    [[V16:%.*]] = shufflevector <128 x i8> [[V14]], <128 x i8> undef, <32 x i32> 
 ; CHECK-NEXT:    [[V17:%.*]] = getelementptr inbounds i8*, i8** [[A1]], i32 [[A6:%.*]]
-; CHECK-NEXT:    [[V18:%.*]] = load i8*, i8** [[V17]], align 4, [[TBAA6]]
+; CHECK-NEXT:    [[V18:%.*]] = load i8*, i8** [[V17]], align 4, [[TBAA3]]
 ; CHECK-NEXT:    [[V19:%.*]] = getelementptr i8, i8* [[V18]], i32 [[A7:%.*]]
 ; CHECK-NEXT:    [[V20:%.*]] = bitcast i8* [[V19]] to <128 x i8>*
-; CHECK-NEXT:    [[V21:%.*]] = tail call <128 x i8> @llvm.masked.load.v128i8.p0v128i8(<128 x i8>* [[V20]], i32 32, <128 x i1> , <128 x i8> undef), [[TBAA8]]
+; CHECK-NEXT:    [[V21:%.*]] = tail call <128 x i8> @llvm.masked.load.v128i8.p0v128i8(<128 x i8>* [[V20]], i32 32, <128 x i1> , <128 x i8> undef), [[TBAA5]]
 ; CHECK-NEXT:    [[V22:%.*]] = shufflevector <128 x i8> [[V21]], <128 x i8> undef, <32 x i32> 
 ; CHECK-NEXT:    [[V23:%.*]] = icmp ugt <32 x i8> [[V16]], [[V22]]
 ; CHECK-NEXT:    [[V24:%.*]] = select <32 x i1> [[V23]], <32 x i8> [[V16]], <32 x i8> [[V22]]
 ; CHECK-NEXT:    [[V25:%.*]] = shufflevector <32 x i8> [[V24]], <32 x i8> undef, <128 x i32> 
-; CHECK-NEXT:    tail call void @llvm.masked.store.v128i8.p0v128i8(<128 x i8> [[V25]], <128 x i8>* [[V3]], i32 32, <128 x i1> ), [[TBAA3]]
+; CHECK-NEXT:    tail call void @llvm.masked.store.v128i8.p0v128i8(<128 x i8> [[V25]], <128 x i8>* [[V3]], i32 32, <128 x i1> ), [[TBAA8:!tbaa !.*]]
 ; CHECK-NEXT:    ret i32 0
 ;
 b0:
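
For illustration, this is roughly the IR pattern the new isMaskedStoreOverwrite check recognizes; a minimal sketch, assuming typed pointers as in the test above (the function @two_stores, the <4 x i32> element type, and the value names are invented for this example, not taken from the patch):

  declare void @llvm.masked.store.v4i32.p0v4i32(<4 x i32>, <4 x i32>*, i32, <4 x i1>)

  define void @two_stores(<4 x i32>* %p, <4 x i32> %a, <4 x i32> %b, <4 x i1> %m) {
    ; Same pointer %p and same mask %m: every lane written by the first store
    ; is rewritten by the second, so the first store is dead (OW_Complete).
    call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> %a, <4 x i32>* %p, i32 16, <4 x i1> %m)
    call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> %b, <4 x i32>* %p, i32 16, <4 x i1> %m)
    ret void
  }

With this patch, opt -dse should delete the first call. If the pointers or masks differ, isMaskedStoreOverwrite returns OW_Unknown and both stores are kept, since isOverwrite cannot reason about the imprecisely sized masked locations.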