diff --git a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
--- a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
+++ b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
@@ -132,7 +132,7 @@
     eraseFromParent(I);
   }
 
-  Value *foldMallocMemset(CallInst *Memset, IRBuilderBase &B);
+  Value *foldAllocMemset(CallInst *Memset, IRBuilderBase &B);
 
 public:
   LibCallSimplifier(
diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
--- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -1155,59 +1155,127 @@
   return CI->getArgOperand(0);
 }
 
-/// Fold memset[_chk](malloc(n), 0, n) --> calloc(1, n).
-Value *LibCallSimplifier::foldMallocMemset(CallInst *Memset, IRBuilderBase &B) {
+/// Fold memset(malloc(n), 0, n) --> calloc(1, n).
+/// Fold memset(calloc(n, m), 0, k) --> calloc(n, m).
+Value *LibCallSimplifier::foldAllocMemset(CallInst *Memset, IRBuilderBase &B) {
+  // The memset may be an llvm.memset intrinsic rather than a libcall.
+  if (auto *MemSetInstr = dyn_cast<MemSetInst>(Memset)) {
+    // Volatile memsets must be preserved.
+    if (MemSetInstr->isVolatile())
+      return nullptr;
+    // calloc only allocates in the default address space.
+    if (MemSetInstr->getRawDest()->getType()->getPointerAddressSpace() != 0)
+      return nullptr;
+  }
+
   // This has to be a memset of zeros (bzero).
   auto *FillValue = dyn_cast<ConstantInt>(Memset->getArgOperand(1));
   if (!FillValue || FillValue->getZExtValue() != 0)
     return nullptr;
 
-  // TODO: We should handle the case where the malloc has more than one use.
-  // This is necessary to optimize common patterns such as when the result of
-  // the malloc is checked against null or when a memset intrinsic is used in
-  // place of a memset library call.
-  auto *Malloc = dyn_cast<CallInst>(Memset->getArgOperand(0));
-  if (!Malloc || !Malloc->hasOneUse())
+  auto canRemoveMemsetAfterAllocCall = [](CallInst *Alloc, CallInst *Memset) {
+    if (Alloc->hasOneUse())
+      return true;
+    BasicBlock *AllocBB = Alloc->getParent(), *MemsetBB = Memset->getParent();
+    // If the alloc call and the memset are in different blocks, AllocBB must
+    // end in a null check whose non-null successor is MemsetBB:
+    //   %cmp = icmp eq i8* %call, null
+    //   br i1 %cmp, label %other, label %MemsetBB
+    if (AllocBB != MemsetBB) {
+      if (MemsetBB->getSinglePredecessor() != AllocBB)
+        return false;
+      Value *Op = Memset->getArgOperand(0);
+      Instruction *TI = AllocBB->getTerminator();
+      ICmpInst::Predicate Pred;
+      BasicBlock *TrueBB, *FalseBB;
+      if (!match(TI,
+                 m_Br(m_ICmp(Pred,
+                             m_CombineOr(m_Specific(Op),
+                                         m_Specific(Op->stripPointerCasts())),
+                             m_Zero()),
+                      TrueBB, FalseBB)))
+        return false;
+      if (Pred != ICmpInst::ICMP_EQ && Pred != ICmpInst::ICMP_NE)
+        return false;
+      if (MemsetBB != (Pred == ICmpInst::ICMP_EQ ? FalseBB : TrueBB))
+        return false;
+    }
+    // Conservatively scan for instructions that may write to memory, to make
+    // sure the allocated buffer is not written between the alloc call and the
+    // memset. No sophisticated alias analysis is performed here.
+    // Scan AllocBB starting at the instruction after the alloc call.
+    for (BasicBlock::iterator Scan = ++(Alloc->getIterator()),
+                              End = AllocBB->end();
+         Scan != End; ++Scan) {
+      if (&*Scan == Memset)
+        break;
+      if (Scan->mayWriteToMemory())
+        return false;
+    }
+    // Scan MemsetBB from its start, stopping at the memset itself.
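+    // (Rationale: calloc zeroes the buffer at allocation time, so if any
+    // instruction could write to the memory between the allocation and the
+    // memset, removing the memset would leave that later data in place
+    // instead of zeroing it, changing observable behavior.)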
+    if (AllocBB != MemsetBB)
+      for (BasicBlock::iterator Scan = MemsetBB->begin(),
+                                End = Memset->getIterator();
+           Scan != End; ++Scan) {
+        if (Scan->mayWriteToMemory())
+          return false;
+      }
+    return true;
+  };
+
+  auto *Alloc = dyn_cast<CallInst>(Memset->getArgOperand(0));
+  if (!Alloc)
     return nullptr;
 
-  // Is the inner call really malloc()?
-  Function *InnerCallee = Malloc->getCalledFunction();
+  Function *InnerCallee = Alloc->getCalledFunction();
   if (!InnerCallee)
     return nullptr;
 
   LibFunc Func;
+  // Is the inner call really malloc() or calloc()?
   if (!TLI->getLibFunc(*InnerCallee, Func) || !TLI->has(Func) ||
-      Func != LibFunc_malloc)
+      (Func != LibFunc_malloc && Func != LibFunc_calloc))
     return nullptr;
 
-  // The memset must cover the same number of bytes that are malloc'd.
-  if (Memset->getArgOperand(2) != Malloc->getArgOperand(0))
+  // For malloc, the memset must cover the entire allocation; calloc already
+  // zeroes the whole buffer, so the memset length does not matter for it.
+  if (Func == LibFunc_malloc &&
+      (Memset->getArgOperand(2) != Alloc->getArgOperand(0)))
     return nullptr;
 
-  // Replace the malloc with a calloc. We need the data layout to know what the
-  // actual size of a 'size_t' parameter is.
-  B.SetInsertPoint(Malloc->getParent(), ++Malloc->getIterator());
-  const DataLayout &DL = Malloc->getModule()->getDataLayout();
+  // Handle the case where the malloc/calloc has more than one use. This is
+  // necessary to optimize common patterns such as when the result of the
+  // allocation is checked against null or when a memset intrinsic is used in
+  // place of a memset library call.
+  if (!canRemoveMemsetAfterAllocCall(Alloc, Memset))
+    return nullptr;
+
+  // The memset of a calloc is redundant; replace the memset with the existing
+  // calloc's result.
+  if (Func == LibFunc_calloc)
+    return Memset->getArgOperand(0);
+
+  // The memset is redundant and the malloc can be replaced with a calloc.
+  // We need the data layout to know what the actual size of a 'size_t'
+  // parameter is.
+  B.SetInsertPoint(Alloc->getParent(), ++Alloc->getIterator());
+  const DataLayout &DL = Alloc->getModule()->getDataLayout();
   IntegerType *SizeType = DL.getIntPtrType(B.GetInsertBlock()->getContext());
-  if (Value *Calloc = emitCalloc(ConstantInt::get(SizeType, 1),
-                                 Malloc->getArgOperand(0),
-                                 Malloc->getAttributes(), B, *TLI)) {
-    substituteInParent(Malloc, Calloc);
+  if (Value *Calloc =
+          emitCalloc(ConstantInt::get(SizeType, 1), Alloc->getArgOperand(0),
+                     Alloc->getAttributes(), B, *TLI)) {
+    substituteInParent(Alloc, Calloc);
     return Calloc;
   }
-
   return nullptr;
 }
 
 Value *LibCallSimplifier::optimizeMemSet(CallInst *CI, IRBuilderBase &B) {
   Value *Size = CI->getArgOperand(2);
   annotateNonNullAndDereferenceable(CI, 0, Size, DL);
 
+  if (auto *Calloc = foldAllocMemset(CI, B))
+    return Calloc;
+
   if (isa<IntrinsicInst>(CI))
     return nullptr;
 
-  if (auto *Calloc = foldMallocMemset(CI, B))
-    return Calloc;
-
   // memset(p, v, n) -> llvm.memset(align 1 p, v, n)
   Value *Val = B.CreateIntCast(CI->getArgOperand(1), B.getInt8Ty(), false);
   CallInst *NewCI = B.CreateMemSet(CI->getArgOperand(0), Val, Size, Align(1));
@@ -3052,7 +3120,7 @@
     return optimizeLog(CI, Builder);
   case Intrinsic::sqrt:
     return optimizeSqrt(CI, Builder);
-  // TODO: Use foldMallocMemset() with memset intrinsic.
+  // The memset intrinsic is handled in optimizeMemSet() via foldAllocMemset().
   case Intrinsic::memset:
     return optimizeMemSet(CI, Builder);
   case Intrinsic::memcpy:
@@ -3275,7 +3343,7 @@
 
 Value *FortifiedLibCallSimplifier::optimizeMemSetChk(CallInst *CI,
                                                      IRBuilderBase &B) {
-  // TODO: Try foldMallocMemset() here.
+  // TODO: Try foldAllocMemset() here.
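+  // (memset_chk carries an extra object-size argument, so folding here would
+  // first need a size-vs-length check along the lines of the
+  // isFortifiedCallFoldable() test below.)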
   if (isFortifiedCallFoldable(CI, 3, 2)) {
     Value *Val = B.CreateIntCast(CI->getArgOperand(1), B.getInt8Ty(), false);
diff --git a/llvm/test/Transforms/InstCombine/memset-1.ll b/llvm/test/Transforms/InstCombine/memset-1.ll
--- a/llvm/test/Transforms/InstCombine/memset-1.ll
+++ b/llvm/test/Transforms/InstCombine/memset-1.ll
@@ -8,6 +8,7 @@
 declare i8* @memset(i8*, i32, i32)
 declare void @llvm.memset.p0i8.i32(i8* nocapture writeonly, i8, i32, i32, i1)
 declare noalias i8* @malloc(i32) #1
+declare noalias i8* @calloc(i32, i32) #1
 
 ; Check memset(mem1, val, size) -> llvm.memset(mem1, val, size, 1).
 
@@ -31,12 +32,11 @@
   ret i8* %call2
 }
 
-; FIXME: A memset intrinsic should be handled similarly to a memset() libcall.
+; A memset intrinsic is handled similarly to a memset() libcall.
 
-define i8* @malloc_and_memset_intrinsic(i32 %n) #0 {
-; CHECK-LABEL: @malloc_and_memset_intrinsic(
-; CHECK-NEXT: [[CALL:%.*]] = call i8* @malloc(i32 [[N:%.*]])
-; CHECK-NEXT: call void @llvm.memset.p0i8.i32(i8* align 1 [[CALL]], i8 0, i32 [[N]], i1 false)
+define i8* @pr25892_lite2(i32 %n) #0 {
+; CHECK-LABEL: @pr25892_lite2(
+; CHECK-NEXT: [[CALL:%.*]] = call i8* @calloc(i32 1, i32 [[N:%.*]])
 ; CHECK-NEXT: ret i8* [[CALL]]
 ;
   %call = call i8* @malloc(i32 %n)
@@ -57,17 +57,47 @@
   ret i8* %call2
 }
 
-; FIXME: memset(malloc(x), 0, x) -> calloc(1, x)
-; This doesn't fire currently because the malloc has more than one use.
+; memset(malloc(x), 0, x) --> calloc(1, x)
 
 define float* @pr25892(i32 %size) #0 {
 ; CHECK-LABEL: @pr25892(
 ; CHECK-NEXT: entry:
+; CHECK-NEXT: [[CALL:%.*]] = call i8* @calloc(i32 1, i32 [[SIZE:%.*]])
+; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8* [[CALL]], null
+; CHECK-NEXT: br i1 [[CMP]], label [[CLEANUP:%.*]], label [[IF_END:%.*]]
+; CHECK: if.end:
+; CHECK-NEXT: [[BC:%.*]] = bitcast i8* [[CALL]] to float*
+; CHECK-NEXT: br label [[CLEANUP]]
+; CHECK: cleanup:
+; CHECK-NEXT: [[RETVAL_0:%.*]] = phi float* [ [[BC]], [[IF_END]] ], [ null, [[ENTRY:%.*]] ]
+; CHECK-NEXT: ret float* [[RETVAL_0]]
+;
+entry:
+  %call = tail call i8* @malloc(i32 %size) #1
+  %cmp = icmp eq i8* %call, null
+  br i1 %cmp, label %cleanup, label %if.end
+if.end:
+  %bc = bitcast i8* %call to float*
+  %call2 = tail call i8* @memset(i8* nonnull %call, i32 0, i32 %size) #1
+  br label %cleanup
+cleanup:
+  %retval.0 = phi float* [ %bc, %if.end ], [ null, %entry ]
+  ret float* %retval.0
+}
+
+declare void @clobber_memory(float*)
+
+; Don't perform the PR25892 optimization when memory may be clobbered between
+; the malloc and the memset.
+
+define float* @pr25892_with_clobbering(i32 %size) #0 {
+; CHECK-LABEL: @pr25892_with_clobbering(
+; CHECK-NEXT: entry:
 ; CHECK-NEXT: [[CALL:%.*]] = tail call i8* @malloc(i32 [[SIZE:%.*]]) [[ATTR0]]
 ; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8* [[CALL]], null
 ; CHECK-NEXT: br i1 [[CMP]], label [[CLEANUP:%.*]], label [[IF_END:%.*]]
 ; CHECK: if.end:
 ; CHECK-NEXT: [[BC:%.*]] = bitcast i8* [[CALL]] to float*
+; CHECK-NEXT: tail call void @clobber_memory(float* nonnull [[BC]])
 ; CHECK-NEXT: call void @llvm.memset.p0i8.i32(i8* nonnull align 1 [[CALL]], i8 0, i32 [[SIZE]], i1 false) [[ATTR0]]
 ; CHECK-NEXT: br label [[CLEANUP]]
 ; CHECK: cleanup:
@@ -80,6 +110,7 @@
   br i1 %cmp, label %cleanup, label %if.end
 if.end:
   %bc = bitcast i8* %call to float*
+  tail call void @clobber_memory(float* %bc)
   %call2 = tail call i8* @memset(i8* nonnull %call, i32 0, i32 %size) #1
   br label %cleanup
 cleanup:
@@ -102,6 +133,52 @@
   ret i8* %memset
 }
 
+; memset(calloc(n, m), 0, n * m) --> calloc(n, m)
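+; (calloc already zero-initializes the whole allocation, so a memset of zero
+; over it is redundant even when the memset length differs from n * m, as
+; long as nothing writes the memory in between; test_remove_memset2 below
+; clears only 24 of the 32 allocated bytes.)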
+
+define i8* @test_remove_memset1() {
+; CHECK-LABEL: @test_remove_memset1(
+; CHECK-NEXT: [[PTR:%.*]] = call dereferenceable_or_null(24) i8* @calloc(i32 1, i32 24)
+; CHECK-NEXT: ret i8* [[PTR]]
+;
+  %1 = call i8* @calloc(i32 1, i32 24) #1
+  %2 = call i8* @memset(i8* nonnull %1, i32 0, i32 24)
+  ret i8* %2
+}
+
+; memset(calloc(n, m), 0, k) --> calloc(n, m)
+
+define i8* @test_remove_memset2() {
+; CHECK-LABEL: @test_remove_memset2(
+; CHECK-NEXT: [[PTR:%.*]] = call dereferenceable_or_null(32) i8* @calloc(i32 8, i32 4)
+; CHECK-NEXT: ret i8* [[PTR]]
+;
+  %1 = alloca float*, align 8
+  %2 = call i8* @calloc(i32 8, i32 4) #1
+  %3 = bitcast i8* %2 to float*
+  store float* %3, float** %1, align 8
+  %4 = load float*, float** %1, align 8
+  %5 = bitcast float* %4 to i8*
+  %6 = call i8* @memset(i8* nonnull %5, i32 0, i32 24)
+  ret i8* %6
+}
+
+; 1. memset(malloc(x), 0, x) --> calloc(1, x)
+; 2. memset(calloc(1, x), 0, x) --> calloc(1, x)
+
+define i8* @test_remove_memsets_and_emit_calloc() {
+; CHECK-LABEL: @test_remove_memsets_and_emit_calloc(
+; CHECK-NEXT: [[PTR:%.*]] = call dereferenceable_or_null(42) i8* @calloc(i32 1, i32 42)
+; CHECK-NEXT: ret i8* [[PTR]]
+;
+  %1 = alloca i8*, align 8
+  %2 = call i8* @malloc(i32 42) #1
+  store i8* %2, i8** %1, align 8
+  %3 = load i8*, i8** %1, align 8
+  %m1 = call i8* @memset(i8* nonnull %3, i32 0, i32 42)
+  %4 = load i8*, i8** %1, align 8
+  %m2 = call i8* @memset(i8* nonnull %4, i32 0, i32 42)
+  %5 = load i8*, i8** %1, align 8
+  ret i8* %5
+}
+
 define i8* @memset_size_select(i1 %b, i8* %ptr) {
 ; CHECK-LABEL: @memset_size_select(
 ; CHECK-NEXT: [[SIZE:%.*]] = select i1 [[B:%.*]], i32 10, i32 50