Index: include/llvm/Transforms/Utils/BuildLibCalls.h =================================================================== --- include/llvm/Transforms/Utils/BuildLibCalls.h +++ include/llvm/Transforms/Utils/BuildLibCalls.h @@ -128,6 +128,16 @@ /// Emit a call to the calloc function. Value *emitCalloc(Value *Num, Value *Size, const AttributeList &Attrs, IRBuilder<> &B, const TargetLibraryInfo &TLI); + + /// Emit a call to the strcat function to the builder, for the specified + /// pointer arguments. + Value *emitStrCat(Value *Dst, Value *Src, IRBuilder<> &B, + const TargetLibraryInfo *TLI, StringRef Name = "strcat"); + + /// Emit a call to the strncat function to the builder, for the specified + /// pointer arguments and length. + Value *emitStrNCat(Value *Dst, Value *Src, Value *Len, IRBuilder<> &B, + const TargetLibraryInfo *TLI, StringRef Name = "strncat"); } #endif Index: lib/Transforms/Utils/BuildLibCalls.cpp =================================================================== --- lib/Transforms/Utils/BuildLibCalls.cpp +++ lib/Transforms/Utils/BuildLibCalls.cpp @@ -1074,3 +1074,36 @@ return CI; } + +Value *llvm::emitStrCat(Value *Dst, Value *Src, IRBuilder<> &B, + const TargetLibraryInfo *TLI, StringRef Name) { + if (!TLI->has(LibFunc_strcat)) + return nullptr; + + Module *M = B.GetInsertBlock()->getModule(); + Type *I8Ptr = B.getInt8PtrTy(); + Value *StrCat = M->getOrInsertFunction(Name, I8Ptr, I8Ptr, I8Ptr); + inferLibFuncAttributes(*M->getFunction(Name), *TLI); + CallInst *CI = + B.CreateCall(StrCat, {castToCStr(Dst, B), castToCStr(Src, B)}, Name); + if (const Function *F = dyn_cast<Function>(StrCat->stripPointerCasts())) + CI->setCallingConv(F->getCallingConv()); + return CI; +} + +Value *llvm::emitStrNCat(Value *Dst, Value *Src, Value *Len, IRBuilder<> &B, + const TargetLibraryInfo *TLI, StringRef Name) { + if (!TLI->has(LibFunc_strncat)) + return nullptr; + + Module *M = B.GetInsertBlock()->getModule(); + Type *I8Ptr = B.getInt8PtrTy(); + Value *StrNCat = + M->getOrInsertFunction(Name, I8Ptr, I8Ptr, I8Ptr, Len->getType()); + inferLibFuncAttributes(*M->getFunction(Name), *TLI); + CallInst *CI = B.CreateCall( + StrNCat, {castToCStr(Dst, B), castToCStr(Src, B), Len}, "strncat"); + if (const Function *F = dyn_cast<Function>(StrNCat->stripPointerCasts())) + CI->setCallingConv(F->getCallingConv()); + return CI; +} \ No newline at end of file Index: lib/Transforms/Utils/SimplifyLibCalls.cpp =================================================================== --- lib/Transforms/Utils/SimplifyLibCalls.cpp +++ lib/Transforms/Utils/SimplifyLibCalls.cpp @@ -1838,6 +1838,20 @@ if (!getConstantStringInfo(CI->getArgOperand(1), FormatStr)) return nullptr; + // sprintf(buf, "%s%s", buf, str) -> strcat(buf, str) + if (FormatStr == "%s%s" && CI->getNumArgOperands() == 4) { + Value *Buf = CI->getArgOperand(0); + Value *Src = CI->getArgOperand(2); + if (Buf == Src) { + Value *V = emitStrCat(Buf, CI->getArgOperand(3), B, TLI); + if (CI->use_empty()) + return V; + + Value *Len = emitStrLen(V, B, DL, TLI); + return B.CreateIntCast(Len, CI->getType(), false); + } + } + // If we just have a format string (nothing else crazy) transform it. if (CI->getNumArgOperands() == 2) { // Make sure there's no % in the constant array. We could try to handle @@ -1855,39 +1869,40 @@ // The remaining optimizations require the format string to be "%s" or "%c" // and have an extra operand. - if (FormatStr.size() != 2 || FormatStr[0] != '%' || - CI->getNumArgOperands() < 3) - return nullptr; + if (FormatStr.size() == 2 && FormatStr[0] == '%' && + CI->getNumArgOperands() == 3) { - // Decode the second character of the format string. - if (FormatStr[1] == 'c') { - // sprintf(dst, "%c", chr) --> *(i8*)dst = chr; *((i8*)dst+1) = 0 - if (!CI->getArgOperand(2)->getType()->isIntegerTy()) - return nullptr; - Value *V = B.CreateTrunc(CI->getArgOperand(2), B.getInt8Ty(), "char"); - Value *Ptr = castToCStr(CI->getArgOperand(0), B); - B.CreateStore(V, Ptr); - Ptr = B.CreateGEP(B.getInt8Ty(), Ptr, B.getInt32(1), "nul"); - B.CreateStore(B.getInt8(0), Ptr); + // Decode the second character of the format string. + if (FormatStr[1] == 'c') { + // sprintf(dst, "%c", chr) --> *(i8*)dst = chr; *((i8*)dst+1) = 0 + if (!CI->getArgOperand(2)->getType()->isIntegerTy()) + return nullptr; + Value *V = B.CreateTrunc(CI->getArgOperand(2), B.getInt8Ty(), "char"); + Value *Ptr = castToCStr(CI->getArgOperand(0), B); + B.CreateStore(V, Ptr); + Ptr = B.CreateGEP(B.getInt8Ty(), Ptr, B.getInt32(1), "nul"); + B.CreateStore(B.getInt8(0), Ptr); - return ConstantInt::get(CI->getType(), 1); - } + return ConstantInt::get(CI->getType(), 1); + } - if (FormatStr[1] == 's') { - // sprintf(dest, "%s", str) -> llvm.memcpy(dest, str, strlen(str)+1, 1) - if (!CI->getArgOperand(2)->getType()->isPointerTy()) - return nullptr; + if (FormatStr[1] == 's') { + // sprintf(dest, "%s", str) -> llvm.memcpy(dest, str, strlen(str)+1, 1) + if (!CI->getArgOperand(2)->getType()->isPointerTy()) + return nullptr; - Value *Len = emitStrLen(CI->getArgOperand(2), B, DL, TLI); - if (!Len) - return nullptr; - Value *IncLen = - B.CreateAdd(Len, ConstantInt::get(Len->getType(), 1), "leninc"); - B.CreateMemCpy(CI->getArgOperand(0), 1, CI->getArgOperand(2), 1, IncLen); + Value *Len = emitStrLen(CI->getArgOperand(2), B, DL, TLI); + if (!Len) + return nullptr; + Value *IncLen = + B.CreateAdd(Len, ConstantInt::get(Len->getType(), 1), "leninc"); + B.CreateMemCpy(CI->getArgOperand(0), 1, CI->getArgOperand(2), 1, IncLen); - // The sprintf result is the unincremented number of bytes in the string. - return B.CreateIntCast(Len, CI->getType(), false); + // The sprintf result is the unincremented number of bytes in the string. + return B.CreateIntCast(Len, CI->getType(), false); + } } + return nullptr; } @@ -1918,6 +1933,21 @@ if (!getConstantStringInfo(CI->getArgOperand(2), FormatStr)) return nullptr; + // snprintf(buf, n, "%s%s", buf, str) -> strncat(buf, str, n) + if (FormatStr == "%s%s" && CI->getNumArgOperands() == 5) { + Value *Buf = CI->getArgOperand(0); + Value *N = CI->getArgOperand(1); + Value *Src = CI->getArgOperand(3); + if (Buf == Src) { + Value *V = emitStrNCat(Buf, CI->getArgOperand(4), N, B, TLI); + if (CI->use_empty()) + return V; + + Value *Len = emitStrLen(V, B, DL, TLI); + return B.CreateIntCast(Len, CI->getType(), false); + } + } + // Check for size ConstantInt *Size = dyn_cast<ConstantInt>(CI->getArgOperand(1)); if (!Size) Index: test/Transforms/InstCombine/concat.ll =================================================================== --- test/Transforms/InstCombine/concat.ll +++ test/Transforms/InstCombine/concat.ll @@ -0,0 +1,71 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt < %s -instcombine -S | FileCheck %s + +@.str = private unnamed_addr constant [5 x i8] c"%s%s\00", align 1 + +declare i32 @sprintf(i8* nocapture, i8* nocapture readonly, ...) +declare i32 @snprintf(i8* nocapture, i32, i8* nocapture readonly, ...) +declare i8* @strcat(i8*, i8* nocapture readonly) +declare i32 @strlen(i8* nocapture) +declare i8* @strncat(i8*, i8* nocapture readonly, i32) + + +define i32 @sprintf_concat_ok_return(i8* %str, i8* %str2) { +; CHECK-LABEL: @sprintf_concat_ok_return( +; CHECK-NEXT: [[STRCAT:%.*]] = call i8* @strcat(i8* [[STR:%.*]], i8* [[STR2:%.*]]) +; CHECK-NEXT: [[STRLEN:%.*]] = call i64 bitcast (i32 (i8*)* @strlen to i64 (i8*)*)(i8* [[STRCAT]]) +; CHECK-NEXT: [[TMP1:%.*]] = trunc i64 [[STRLEN]] to i32 +; CHECK-NEXT: ret i32 [[TMP1]] +; + %call = tail call i32 (i8*, i8*, ...) @sprintf(i8* %str, i8* getelementptr inbounds ([5 x i8], [5 x i8]* @.str, i32 0, i32 0), i8* %str, i8* %str2) #2 + ret i32 %call +} + +define void @sprintf_concat_ok(i8* %str, i8* %str2) { +; CHECK-LABEL: @sprintf_concat_ok( +; CHECK-NEXT: [[STRCAT:%.*]] = call i8* @strcat(i8* [[STR:%.*]], i8* [[STR2:%.*]]) +; CHECK-NEXT: ret void +; + %call = tail call i32 (i8*, i8*, ...) @sprintf(i8* %str, i8* getelementptr inbounds ([5 x i8], [5 x i8]* @.str, i32 0, i32 0), i8* %str, i8* %str2) #2 + ret void +} + +define i32 @snprintf_concat_ok_return(i8* %str, i8* %str2, i32 %n) { +; CHECK-LABEL: @snprintf_concat_ok_return( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[STRNCAT:%.*]] = call i8* @strncat(i8* [[STR:%.*]], i8* [[STR2:%.*]], i32 [[N:%.*]]) +; CHECK-NEXT: [[STRLEN:%.*]] = call i64 bitcast (i32 (i8*)* @strlen to i64 (i8*)*)(i8* [[STRNCAT]]) +; CHECK-NEXT: [[TMP0:%.*]] = trunc i64 [[STRLEN]] to i32 +; CHECK-NEXT: ret i32 [[TMP0]] +; +entry: + %call = tail call i32 (i8*, i32, i8*, ...) @snprintf(i8* %str, i32 %n, i8* getelementptr inbounds ([5 x i8], [5 x i8]* @.str, i32 0, i32 0), i8* %str, i8* %str2) #2 + ret i32 %call +} + +define void @snprintf_concat_ok(i8* %str, i8* %str2, i32 %n) { +; CHECK-LABEL: @snprintf_concat_ok( +; CHECK-NEXT: [[STRNCAT:%.*]] = call i8* @strncat(i8* [[STR:%.*]], i8* [[STR2:%.*]], i32 [[N:%.*]]) +; CHECK-NEXT: ret void +; + %call = tail call i32 (i8*, i32, i8*, ...) @snprintf(i8* %str, i32 %n, i8* getelementptr inbounds ([5 x i8], [5 x i8]* @.str, i32 0, i32 0), i8* %str, i8* %str2) #2 + ret void +} + +define void @snprintf_const_fmt(i8* %str, i8* %str2, i32 %n, i8* nocapture readonly %fmt) { +; CHECK-LABEL: @snprintf_const_fmt( +; CHECK-NEXT: [[CALL:%.*]] = tail call i32 (i8*, i32, i8*, ...) @snprintf(i8* [[STR:%.*]], i32 [[N:%.*]], i8* [[FMT:%.*]], i8* [[STR]], i8* [[STR2:%.*]]) +; CHECK-NEXT: ret void +; + %call = tail call i32 (i8*, i32, i8*, ...) @snprintf(i8* %str, i32 %n, i8* %fmt, i8* %str, i8* %str2) #2 + ret void +} + +define void @sprintf_const_fmt(i8* %str, i8* %str2, i8* nocapture readonly %fmt) { +; CHECK-LABEL: @sprintf_const_fmt( +; CHECK-NEXT: [[CALL:%.*]] = tail call i32 (i8*, i8*, ...) @sprintf(i8* [[STR:%.*]], i8* [[FMT:%.*]], i8* [[STR]], i8* [[STR2:%.*]]) +; CHECK-NEXT: ret void +; + %call = tail call i32 (i8*, i8*, ...) @sprintf(i8* %str, i8* %fmt, i8* %str, i8* %str2) #2 + ret void +}