Index: lib/Transforms/Utils/SimplifyLibCalls.cpp
===================================================================
--- lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -2402,12 +2402,9 @@
 }
 
 Value *LibCallSimplifier::optimizeSnPrintFString(CallInst *CI, IRBuilder<> &B) {
-  // Check for size
-  ConstantInt *Size = dyn_cast<ConstantInt>(CI->getArgOperand(1));
-  if (!Size)
-    return nullptr;
-
-  uint64_t N = Size->getZExtValue();
+  Value *Dst = CI->getArgOperand(0);
+  Value *SizeArg = CI->getArgOperand(1);
+  ConstantInt *Size = dyn_cast<ConstantInt>(SizeArg); // Null if non-constant.
   // Check for a fixed format string.
   StringRef FormatStr;
   if (!getConstantStringInfo(CI->getArgOperand(2), FormatStr))
@@ -2420,6 +2417,9 @@
     if (FormatStr.find('%') != StringRef::npos)
       return nullptr; // we found a format specifier, bail out.
 
+    if (!Size)
+      return nullptr;
+    uint64_t N = Size->getZExtValue();
     if (N == 0)
       return ConstantInt::get(CI->getType(), FormatStr.size());
     else if (N < FormatStr.size() + 1)
@@ -2428,7 +2428,7 @@
     // snprintf(dst, size, fmt) -> llvm.memcpy(align 1 dst, align 1 fmt,
     // strlen(fmt)+1)
     B.CreateMemCpy(
-        CI->getArgOperand(0), 1, CI->getArgOperand(2), 1,
+        Dst, 1, CI->getArgOperand(2), 1,
         ConstantInt::get(DL.getIntPtrType(CI->getContext()),
                          FormatStr.size() + 1)); // Copy the null byte.
     return ConstantInt::get(CI->getType(), FormatStr.size());
@@ -2441,6 +2441,9 @@
 
   // Decode the second character of the format string.
   if (FormatStr[1] == 'c') {
+    if (!Size)
+      return nullptr;
+    uint64_t N = Size->getZExtValue();
     if (N == 0)
       return ConstantInt::get(CI->getType(), 1);
     else if (N == 1)
@@ -2450,7 +2453,7 @@
     if (!CI->getArgOperand(3)->getType()->isIntegerTy())
       return nullptr;
     Value *V = B.CreateTrunc(CI->getArgOperand(3), B.getInt8Ty(), "char");
-    Value *Ptr = castToCStr(CI->getArgOperand(0), B);
+    Value *Ptr = castToCStr(Dst, B);
     B.CreateStore(V, Ptr);
     Ptr = B.CreateGEP(B.getInt8Ty(), Ptr, B.getInt32(1), "nul");
     B.CreateStore(B.getInt8(0), Ptr);
@@ -2461,15 +2464,41 @@
   if (FormatStr[1] == 's') {
     // snprintf(dest, size, "%s", str) to llvm.memcpy(dest, str, len+1, 1)
     StringRef Str;
-    if (!getConstantStringInfo(CI->getArgOperand(3), Str))
+    if (!getConstantStringInfo(CI->getArgOperand(3), Str)) {
+      // The source string is not a compile-time constant. If the snprintf
+      // result is unused and the size is provably non-zero, lower to
+      //   memccpy(d, s, '\0', size - 1); d[size - 1] = 0;
+      // memccpy availability is checked before any IR is built, so declining
+      // the fold leaves the function unchanged.
+      if (CI->use_empty() && isKnownNonZero(SizeArg, DL) &&
+          TLI->has(LibFunc_memccpy)) {
+        // SizeArg's type is whatever size_t lowers to (e.g. i32 on 32-bit
+        // targets), so the constant 1 must be built in that type, not i64.
+        Value *DecreasedSize =
+            B.CreateSub(SizeArg, ConstantInt::get(SizeArg->getType(), 1));
+        emitMemCCpy(Dst, CI->getArgOperand(3), B.getInt32('\0'),
+                    DecreasedSize, B, TLI);
+        // NOTE(review): this always writes d[size - 1], even when s is
+        // shorter than size - 1 characters and memccpy already copied its
+        // terminator; snprintf never writes past the terminator it emits.
+        Value *DstEnd = B.CreateGEP(B.getInt8Ty(), Dst, DecreasedSize);
+        B.CreateStore(B.getInt8(0), DstEnd);
+        // NOTE(review): Dst does not have CI's integer type, so this is only
+        // safe if callers never RAUW a call without uses -- confirm.
+        return Dst;
+      }
       return nullptr;
+    }
 
+    if (!Size)
+      return nullptr;
+    uint64_t N = Size->getZExtValue();
     if (N == 0)
       return ConstantInt::get(CI->getType(), Str.size());
     else if (N < Str.size() + 1)
       return nullptr;
 
-    B.CreateMemCpy(CI->getArgOperand(0), 1, CI->getArgOperand(3), 1,
+    B.CreateMemCpy(Dst, 1, CI->getArgOperand(3), 1,
                    ConstantInt::get(CI->getType(), Str.size() + 1));
 
     // The snprintf result is the unincremented number of bytes in the string.
Index: test/Transforms/InstCombine/snprintf-memccpy.ll
===================================================================
--- test/Transforms/InstCombine/snprintf-memccpy.ll
+++ test/Transforms/InstCombine/snprintf-memccpy.ll
@@ -0,0 +1,60 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -instcombine -S | FileCheck %s
+
+@.str = private constant [3 x i8] c"%s\00", align 1
+declare i32 @snprintf(i8*, i64, i8*, ...)
+
+; Positive tests: result unused and size provably non-zero -> memccpy + store.
+define void @test_string_to_buf_retval_nonzero_n(i8* %buf, i8* %str) {
+; CHECK-LABEL: @test_string_to_buf_retval_nonzero_n(
+; CHECK-NEXT:    [[MEMCCPY:%.*]] = call i8* @memccpy(i8* [[BUF:%.*]], i8* [[STR:%.*]], i32 0, i64 7)
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i8, i8* [[BUF]], i64 7
+; CHECK-NEXT:    store i8 0, i8* [[TMP1]], align 1
+; CHECK-NEXT:    ret void
+;
+  %call = call i32 (i8*, i64, i8*, ...) @snprintf(i8* %buf, i64 8, i8* getelementptr inbounds ([3 x i8], [3 x i8]* @.str, i64 0, i64 0), i8* %str)
+  ret void
+}
+
+define void @test_string_to_buf_retval_known_nonzero_n(i8* %buf, i64 %n, i8* %str) {
+; CHECK-LABEL: @test_string_to_buf_retval_known_nonzero_n(
+; CHECK-NEXT:    [[SIZE:%.*]] = shl i64 3, [[N:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = add i64 [[SIZE]], -1
+; CHECK-NEXT:    [[MEMCCPY:%.*]] = call i8* @memccpy(i8* [[BUF:%.*]], i8* [[STR:%.*]], i32 0, i64 [[TMP1]])
+; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr i8, i8* [[BUF]], i64 [[TMP1]]
+; CHECK-NEXT:    store i8 0, i8* [[TMP2]], align 1
+; CHECK-NEXT:    ret void
+;
+  %size = shl i64 3, %n
+  %call = call i32 (i8*, i64, i8*, ...) @snprintf(i8* %buf, i64 %size, i8* getelementptr inbounds ([3 x i8], [3 x i8]* @.str, i64 0, i64 0), i8* %str)
+  ret void
+}
+
+; Negative tests: no fold when the result is used or the size may be zero.
+
+define i32 @test_string_to_buf_retval_used_n_maybe_zero(i8* %buf, i64 %n, i8* %str) {
+; CHECK-LABEL: @test_string_to_buf_retval_used_n_maybe_zero(
+; CHECK-NEXT:    [[CALL:%.*]] = call i32 (i8*, i64, i8*, ...) @snprintf(i8* [[BUF:%.*]], i64 [[N:%.*]], i8* getelementptr inbounds ([3 x i8], [3 x i8]* @.str, i64 0, i64 0), i8* [[STR:%.*]])
+; CHECK-NEXT:    ret i32 [[CALL]]
+;
+  %call = call i32 (i8*, i64, i8*, ...) @snprintf(i8* %buf, i64 %n, i8* getelementptr inbounds ([3 x i8], [3 x i8]* @.str, i64 0, i64 0), i8* %str)
+  ret i32 %call
+}
+
+define void @test_string_to_buf_retval_unused_zero_n(i8* %buf, i8* %str) {
+; CHECK-LABEL: @test_string_to_buf_retval_unused_zero_n(
+; CHECK-NEXT:    [[CALL:%.*]] = call i32 (i8*, i64, i8*, ...) @snprintf(i8* [[BUF:%.*]], i64 0, i8* getelementptr inbounds ([3 x i8], [3 x i8]* @.str, i64 0, i64 0), i8* [[STR:%.*]])
+; CHECK-NEXT:    ret void
+;
+  %call = call i32 (i8*, i64, i8*, ...) @snprintf(i8* %buf, i64 0, i8* getelementptr inbounds ([3 x i8], [3 x i8]* @.str, i64 0, i64 0), i8* %str)
+  ret void
+}
+
+define void @test_string_to_buf_retval_unused_n_maybe_zero(i8* %buf, i64 %n, i8* %str) {
+; CHECK-LABEL: @test_string_to_buf_retval_unused_n_maybe_zero(
+; CHECK-NEXT:    [[CALL:%.*]] = call i32 (i8*, i64, i8*, ...) @snprintf(i8* [[BUF:%.*]], i64 [[N:%.*]], i8* getelementptr inbounds ([3 x i8], [3 x i8]* @.str, i64 0, i64 0), i8* [[STR:%.*]])
+; CHECK-NEXT:    ret void
+;
+  %call = call i32 (i8*, i64, i8*, ...) @snprintf(i8* %buf, i64 %n, i8* getelementptr inbounds ([3 x i8], [3 x i8]* @.str, i64 0, i64 0), i8* %str)
+  ret void
+}