Index: include/llvm/Transforms/Utils/SimplifyLibCalls.h
===================================================================
--- include/llvm/Transforms/Utils/SimplifyLibCalls.h
+++ include/llvm/Transforms/Utils/SimplifyLibCalls.h
@@ -159,6 +159,7 @@
                                 int StreamArg = -1);
   Value *optimizePrintF(CallInst *CI, IRBuilder<> &B);
   Value *optimizeSPrintF(CallInst *CI, IRBuilder<> &B);
+  Value *optimizeSnPrintF(CallInst *CI, IRBuilder<> &B);
   Value *optimizeFPrintF(CallInst *CI, IRBuilder<> &B);
   Value *optimizeFWrite(CallInst *CI, IRBuilder<> &B);
   Value *optimizeFPuts(CallInst *CI, IRBuilder<> &B);
@@ -172,6 +173,7 @@
                       SmallVectorImpl<CallInst *> &SinCosCalls);
   Value *optimizePrintFString(CallInst *CI, IRBuilder<> &B);
   Value *optimizeSPrintFString(CallInst *CI, IRBuilder<> &B);
+  Value *optimizeSnPrintFString(CallInst *CI, IRBuilder<> &B);
   Value *optimizeFPrintFString(CallInst *CI, IRBuilder<> &B);
 
   /// hasFloatVersion - Checks if there is a float version of the specified
Index: lib/Transforms/Utils/SimplifyLibCalls.cpp
===================================================================
--- lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -1912,6 +1912,110 @@
   return nullptr;
 }
 
+/// Try to fold a call to snprintf(dst, size, fmt, ...) whose size and format
+/// string are both constants. Handled forms:
+///   snprintf(dst, size, "literal text without a percent sign")
+///   snprintf(dst, size, "%c", chr)
+///   snprintf(dst, size, "%s", "constant string")
+/// With size == 0 nothing is written and only the return value is folded;
+/// otherwise the call is folded only if the whole output, including the
+/// terminating nul, fits in size bytes. Returns the folded return value (the
+/// output length excluding the nul), or nullptr if the call is not folded.
+Value *LibCallSimplifier::optimizeSnPrintFString(CallInst *CI, IRBuilder<> &B) {
+  // Check for a fixed format string.
+  StringRef FormatStr;
+  if (!getConstantStringInfo(CI->getArgOperand(2), FormatStr))
+    return nullptr;
+
+  // Check for a constant size.
+  ConstantInt *Size = dyn_cast<ConstantInt>(CI->getArgOperand(1));
+  if (!Size)
+    return nullptr;
+
+  uint64_t N = Size->getZExtValue();
+
+  // If we just have a format string (nothing else crazy) transform it.
+  if (CI->getNumArgOperands() == 3) {
+    // Make sure there's no % in the constant array.  We could try to handle
+    // %% -> % in the future if we cared.
+    if (FormatStr.find('%') != StringRef::npos)
+      return nullptr; // we found a format specifier, bail out.
+
+    // With a zero size nothing is written; only the length is reported.
+    if (N == 0)
+      return ConstantInt::get(CI->getType(), FormatStr.size());
+    // Leave truncating calls alone: the text plus its nul must fit.
+    if (N < FormatStr.size() + 1)
+      return nullptr;
+
+    // snprintf(dst, size, fmt) -> llvm.memcpy(align 1 dst, align 1 fmt,
+    // strlen(fmt)+1)
+    B.CreateMemCpy(
+        CI->getArgOperand(0), 1, CI->getArgOperand(2), 1,
+        ConstantInt::get(DL.getIntPtrType(CI->getContext()),
+                         FormatStr.size() + 1)); // Copy the null byte.
+    return ConstantInt::get(CI->getType(), FormatStr.size());
+  }
+
+  // The remaining optimizations require the format string to be "%s" or "%c"
+  // and have an extra operand.
+  if (FormatStr.size() == 2 && FormatStr[0] == '%' &&
+      CI->getNumArgOperands() == 4) {
+
+    // Decode the second character of the format string.
+    if (FormatStr[1] == 'c') {
+      // "%c" always produces exactly one character.
+      if (N == 0)
+        return ConstantInt::get(CI->getType(), 1);
+      // A size of 1 only leaves room for the nul; leave the call alone.
+      if (N == 1)
+        return nullptr;
+
+      // snprintf(dst, size, "%c", chr) --> *(i8*)dst = chr; *((i8*)dst+1) = 0
+      if (!CI->getArgOperand(3)->getType()->isIntegerTy())
+        return nullptr;
+      Value *V = B.CreateTrunc(CI->getArgOperand(3), B.getInt8Ty(), "char");
+      Value *Ptr = castToCStr(CI->getArgOperand(0), B);
+      B.CreateStore(V, Ptr);
+      Ptr = B.CreateGEP(B.getInt8Ty(), Ptr, B.getInt32(1), "nul");
+      B.CreateStore(B.getInt8(0), Ptr);
+
+      return ConstantInt::get(CI->getType(), 1);
+    }
+
+    if (FormatStr[1] == 's') {
+      // snprintf(dest, size, "%s", str) to llvm.memcpy(dest, str, len+1, 1)
+      StringRef Str;
+      if (!getConstantStringInfo(CI->getArgOperand(3), Str))
+        return nullptr;
+
+      if (N == 0)
+        return ConstantInt::get(CI->getType(), Str.size());
+      if (N < Str.size() + 1)
+        return nullptr;
+
+      // Copy from the string argument (operand 3). Operand 2 is the 3-byte
+      // "%s" format itself, so reading Str.size() + 1 bytes from it would
+      // copy the wrong text and could read past its end.
+      B.CreateMemCpy(CI->getArgOperand(0), 1, CI->getArgOperand(3), 1,
+                     ConstantInt::get(DL.getIntPtrType(CI->getContext()),
+                                      Str.size() + 1));
+
+      // snprintf returns the length of the string, not counting the nul.
+      return ConstantInt::get(CI->getType(), Str.size());
+    }
+  }
+  return nullptr;
+}
+
+Value *LibCallSimplifier::optimizeSnPrintF(CallInst *CI, IRBuilder<> &B) {
+  if (Value *V = optimizeSnPrintFString(CI, B)) {
+    return V;
+  }
+
+  return nullptr;
+}
+
 Value *LibCallSimplifier::optimizeFPrintFString(CallInst *CI, IRBuilder<> &B) {
   optimizeErrorReporting(CI, B, 0);
 
@@ -2318,6 +2422,8 @@
       return optimizePrintF(CI, Builder);
     case LibFunc_sprintf:
       return optimizeSPrintF(CI, Builder);
+    case LibFunc_snprintf:
+      return optimizeSnPrintF(CI, Builder);
     case LibFunc_fprintf:
       return optimizeFPrintF(CI, Builder);
     case LibFunc_fwrite:
Index: test/Transforms/InstCombine/snprintf.ll
===================================================================
--- test/Transforms/InstCombine/snprintf.ll
+++ test/Transforms/InstCombine/snprintf.ll
@@ -0,0 +1,131 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -instcombine -S | FileCheck %s
+
+@.str = private unnamed_addr constant [4 x i8] c"str\00", align 1
+@.str.1 = private unnamed_addr constant [3 x i8] c"%%\00", align 1
+@.str.2 = private unnamed_addr constant [3 x i8] c"%c\00", align 1
+@.str.3 = private unnamed_addr constant [3 x i8] c"%s\00", align 1
+
+
+define dso_local void @test_not_const_fmt(i8* %buf, i8* %fmt) #0 {
+; CHECK-LABEL: @test_not_const_fmt(
+; CHECK-NEXT:    [[CALL:%.*]] = call i32 (i8*, i64, i8*, ...) @snprintf(i8* [[BUF:%.*]], i64 32, i8* [[FMT:%.*]])
+; CHECK-NEXT:    ret void
+;
+  %call = call i32 (i8*, i64, i8*, ...) @snprintf(i8* %buf, i64 32, i8* %fmt) #2
+  ret void
+}
+
+
+declare dso_local i32 @snprintf(i8*, i64, i8*, ...) #1
+
+
+define dso_local void @test_not_const_size(i8* %buf, i64 %size) #0 {
+; CHECK-LABEL: @test_not_const_size(
+; CHECK-NEXT:    [[CALL:%.*]] = call i32 (i8*, i64, i8*, ...) @snprintf(i8* [[BUF:%.*]], i64 [[SIZE:%.*]], i8* getelementptr inbounds ([4 x i8], [4 x i8]* @.str, i64 0, i64 0))
+; CHECK-NEXT:    ret void
+;
+  %call = call i32 (i8*, i64, i8*, ...) @snprintf(i8* %buf, i64 %size, i8* getelementptr inbounds ([4 x i8], [4 x i8]* @.str, i64 0, i64 0)) #2
+  ret void
+}
+
+
+define dso_local i32 @test_return_value(i8* %buf) #0 {
+; CHECK-LABEL: @test_return_value(
+; CHECK-NEXT:    ret i32 3
+;
+  %call = call i32 (i8*, i64, i8*, ...) @snprintf(i8* %buf, i64 0, i8* getelementptr inbounds ([4 x i8], [4 x i8]* @.str, i64 0, i64 0)) #2
+  ret i32 %call
+}
+
+define dso_local void @test_percentage(i8* %buf) #0 {
+; CHECK-LABEL: @test_percentage(
+; CHECK-NEXT:    [[CALL:%.*]] = call i32 (i8*, i64, i8*, ...) @snprintf(i8* [[BUF:%.*]], i64 32, i8* getelementptr inbounds ([3 x i8], [3 x i8]* @.str.1, i64 0, i64 0))
+; CHECK-NEXT:    ret void
+;
+  %call = call i32 (i8*, i64, i8*, ...) @snprintf(i8* %buf, i64 32, i8* getelementptr inbounds ([3 x i8], [3 x i8]* @.str.1, i64 0, i64 0)) #2
+  ret void
+}
+
+define dso_local i32 @test_null_buf_return_value() #0 {
+; CHECK-LABEL: @test_null_buf_return_value(
+; CHECK-NEXT:    ret i32 3
+;
+  %call = call i32 (i8*, i64, i8*, ...) @snprintf(i8* null, i64 0, i8* getelementptr inbounds ([4 x i8], [4 x i8]* @.str, i64 0, i64 0)) #2
+  ret i32 %call
+}
+
+define dso_local i32 @test_percentage_return_value() #0 {
+; CHECK-LABEL: @test_percentage_return_value(
+; CHECK-NEXT:    [[CALL:%.*]] = call i32 (i8*, i64, i8*, ...) @snprintf(i8* null, i64 0, i8* getelementptr inbounds ([3 x i8], [3 x i8]* @.str.1, i64 0, i64 0))
+; CHECK-NEXT:    ret i32 [[CALL]]
+;
+  %call = call i32 (i8*, i64, i8*, ...) @snprintf(i8* null, i64 0, i8* getelementptr inbounds ([3 x i8], [3 x i8]* @.str.1, i64 0, i64 0)) #3
+  ret i32 %call
+}
+
+
+define dso_local void @test_correct_copy(i8* %buf) #0 {
+; CHECK-LABEL: @test_correct_copy(
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast i8* [[BUF:%.*]] to i32*
+; CHECK-NEXT:    store i32 7500915, i32* [[TMP1]], align 1
+; CHECK-NEXT:    ret void
+;
+  %call = call i32 (i8*, i64, i8*, ...) @snprintf(i8* %buf, i64 32, i8* getelementptr inbounds ([4 x i8], [4 x i8]* @.str, i64 0, i64 0)) #2
+  ret void
+}
+
+define dso_local i32 @test_char_zero_size(i8* %buf) #0 {
+; CHECK-LABEL: @test_char_zero_size(
+; CHECK-NEXT:    ret i32 1
+;
+  %call = call i32 (i8*, i64, i8*, ...) @snprintf(i8* %buf, i64 0, i8* getelementptr inbounds ([3 x i8], [3 x i8]* @.str.2, i64 0, i64 0), i32 65) #2
+  ret i32 %call
+}
+
+define dso_local i32 @test_char_wrong_size(i8* %buf) #0 {
+; CHECK-LABEL: @test_char_wrong_size(
+; CHECK-NEXT:    [[CALL:%.*]] = call i32 (i8*, i64, i8*, ...) @snprintf(i8* [[BUF:%.*]], i64 1, i8* getelementptr inbounds ([3 x i8], [3 x i8]* @.str.2, i64 0, i64 0), i32 65)
+; CHECK-NEXT:    ret i32 [[CALL]]
+;
+  %call = call i32 (i8*, i64, i8*, ...) @snprintf(i8* %buf, i64 1, i8* getelementptr inbounds ([3 x i8], [3 x i8]* @.str.2, i64 0, i64 0), i32 65) #2
+  ret i32 %call
+}
+
+define dso_local i32 @test_char_ok_size(i8* %buf) #0 {
+; CHECK-LABEL: @test_char_ok_size(
+; CHECK-NEXT:    store i8 65, i8* [[BUF:%.*]], align 1
+; CHECK-NEXT:    [[NUL:%.*]] = getelementptr i8, i8* [[BUF]], i64 1
+; CHECK-NEXT:    store i8 0, i8* [[NUL]], align 1
+; CHECK-NEXT:    ret i32 1
+;
+  %call = call i32 (i8*, i64, i8*, ...) @snprintf(i8* %buf, i64 32, i8* getelementptr inbounds ([3 x i8], [3 x i8]* @.str.2, i64 0, i64 0), i32 65) #2
+  ret i32 %call
+}
+
+define dso_local i32 @test_str_zero_size(i8* %buf) #0 {
+; CHECK-LABEL: @test_str_zero_size(
+; CHECK-NEXT:    ret i32 3
+;
+  %call = call i32 (i8*, i64, i8*, ...) @snprintf(i8* %buf, i64 0, i8* getelementptr inbounds ([3 x i8], [3 x i8]* @.str.3, i64 0, i64 0), i8* getelementptr inbounds ([4 x i8], [4 x i8]* @.str, i64 0, i64 0)) #2
+  ret i32 %call
+}
+
+define dso_local i32 @test_str_wrong_size(i8* %buf) #0 {
+; CHECK-LABEL: @test_str_wrong_size(
+; CHECK-NEXT:    [[CALL:%.*]] = call i32 (i8*, i64, i8*, ...) @snprintf(i8* [[BUF:%.*]], i64 1, i8* getelementptr inbounds ([3 x i8], [3 x i8]* @.str.3, i64 0, i64 0), i8* getelementptr inbounds ([4 x i8], [4 x i8]* @.str, i64 0, i64 0))
+; CHECK-NEXT:    ret i32 [[CALL]]
+;
+  %call = call i32 (i8*, i64, i8*, ...) @snprintf(i8* %buf, i64 1, i8* getelementptr inbounds ([3 x i8], [3 x i8]* @.str.3, i64 0, i64 0), i8* getelementptr inbounds ([4 x i8], [4 x i8]* @.str, i64 0, i64 0)) #2
+  ret i32 %call
+}
+
+define dso_local i32 @test_str_ok_size(i8* %buf) #0 {
+; CHECK-LABEL: @test_str_ok_size(
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast i8* [[BUF:%.*]] to i32*
+; CHECK-NEXT:    store i32 7500915, i32* [[TMP1]], align 1
+; CHECK-NEXT:    ret i32 3
+;
+  %call = call i32 (i8*, i64, i8*, ...) @snprintf(i8* %buf, i64 32, i8* getelementptr inbounds ([3 x i8], [3 x i8]* @.str.3, i64 0, i64 0), i8* getelementptr inbounds ([4 x i8], [4 x i8]* @.str, i64 0, i64 0)) #2
+  ret i32 %call
+}