diff --git a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
--- a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
+++ b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
@@ -847,27 +847,28 @@
     if (isSafeToPromote) {
       if (StructType *STy = dyn_cast<StructType>(ByValTy)) {
         if (MaxElements > 0 && STy->getNumElements() > MaxElements) {
-          LLVM_DEBUG(dbgs() << "argpromotion disable promoting argument '"
-                            << PtrArg->getName()
+          LLVM_DEBUG(dbgs() << "ArgPromotion disables passing the elements of"
+                            << " the argument '" << PtrArg->getName()
                             << "' because it would require adding more"
                             << " than " << MaxElements
                             << " arguments to the function.\n");
-          continue;
-        }
-
-        SmallVector<Type *, 4> Types;
-        append_range(Types, STy->elements());
-
-        // If all the elements are single-value types, we can promote it.
-        bool AllSimple =
-            all_of(Types, [](Type *Ty) { return Ty->isSingleValueType(); });
-
-        // Safe to transform, don't even bother trying to "promote" it.
-        // Passing the elements as a scalar will allow sroa to hack on
-        // the new alloca we introduce.
-        if (AllSimple && areTypesABICompatible(Types, *F, TTI)) {
-          ByValArgsToTransform.insert(PtrArg);
-          continue;
+          // Don't `continue` here: the argument may still be promoted through
+          // its pointer by the checks that follow this block.
+        } else {
+          SmallVector<Type *, 4> Types;
+          append_range(Types, STy->elements());
+
+          // If all the elements are single-value types, we can promote it.
+          bool AllSimple =
+              all_of(Types, [](Type *Ty) { return Ty->isSingleValueType(); });
+
+          // Safe to transform, don't even bother trying to "promote" it.
+          // Passing the elements as a scalar will allow sroa to hack on
+          // the new alloca we introduce.
+          if (AllSimple && areTypesABICompatible(Types, *F, TTI)) {
+            ByValArgsToTransform.insert(PtrArg);
+            continue;
+          }
         }
       }
     }
diff --git a/llvm/test/Transforms/ArgumentPromotion/byval-through-pointer-promotion.ll b/llvm/test/Transforms/ArgumentPromotion/byval-through-pointer-promotion.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/ArgumentPromotion/byval-through-pointer-promotion.ll
@@ -0,0 +1,55 @@
+; RUN: opt -sroa -argpromotion -S %s | FileCheck %s
+
+; %struct.A has four elements, more than argpromotion's default MaxElements
+; limit, so the byval "pass the elements" path is rejected for @callee. The
+; argument must still be promoted through the regular pointer-load path.
+
+%struct.A = type { float, [12 x i8], i64, [8 x i8] }
+
+define internal float @callee(%struct.A* byval(%struct.A) align 32 %0) {
+; CHECK-LABEL: define {{[^@]+}}@callee
+; CHECK-SAME: (float [[ARG_0:%.*]], i64 [[ARG_1:%.*]]) {
+; CHECK-NEXT:    [[SUM:%.*]] = fadd float 0.000000e+00, [[ARG_0]]
+; CHECK-NEXT:    [[COEFF:%.*]] = uitofp i64 [[ARG_1]] to float
+; CHECK-NEXT:    [[RES:%.*]] = fmul float [[SUM]], [[COEFF]]
+; CHECK-NEXT:    ret float [[RES]]
+;
+  %2 = alloca float, align 4
+  store float 0.000000e+00, float* %2, align 4
+  %3 = load float, float* %2, align 4
+  %4 = getelementptr inbounds %struct.A, %struct.A* %0, i32 0, i32 0
+  %5 = load float, float* %4, align 32
+  %6 = fadd float %3, %5
+  %7 = getelementptr inbounds %struct.A, %struct.A* %0, i32 0, i32 2
+  %8 = load i64, i64* %7, align 16
+  %9 = uitofp i64 %8 to float
+  %10 = fmul float %6, %9
+  ret float %10
+}
+
+define float @caller(float %0) {
+; CHECK-LABEL: define {{[^@]+}}@caller
+; CHECK-SAME: (float [[ARG_0:%.*]]) {
+; CHECK-NEXT:    [[TMP_0:%.*]] = alloca %struct.A, align 32
+; CHECK-NEXT:    [[FL_PTR_0:%.*]] = getelementptr inbounds %struct.A, %struct.A* [[TMP_0]], i32 0, i32 0
+; CHECK-NEXT:    store float [[ARG_0]], float* [[FL_PTR_0]], align 32
+; CHECK-NEXT:    [[I64_PTR_0:%.*]] = getelementptr inbounds %struct.A, %struct.A* [[TMP_0]], i32 0, i32 2
+; CHECK-NEXT:    store i64 2, i64* [[I64_PTR_0]], align 16
+; CHECK-NEXT:    [[FL_PTR_1:%.*]] = getelementptr %struct.A, %struct.A* [[TMP_0]], i64 0, i32 0
+; CHECK-NEXT:    [[FL_VAL:%.*]] = load float, float* [[FL_PTR_1]], align 32
+; CHECK-NEXT:    [[I64_PTR_1:%.*]] = getelementptr %struct.A, %struct.A* [[TMP_0]], i64 0, i32 2
+; CHECK-NEXT:    [[I64_VAL:%.*]] = load i64, i64* [[I64_PTR_1]], align 16
+; CHECK-NEXT:    [[RES:%.*]] = call noundef float @callee(float [[FL_VAL]], i64 [[I64_VAL]])
+; CHECK-NEXT:    ret float [[RES]]
+;
+  %2 = alloca float, align 4
+  %3 = alloca %struct.A, align 32
+  store float %0, float* %2, align 4
+  %4 = getelementptr inbounds %struct.A, %struct.A* %3, i32 0, i32 0
+  %5 = load float, float* %2, align 4
+  store float %5, float* %4, align 32
+  %6 = getelementptr inbounds %struct.A, %struct.A* %3, i32 0, i32 2
+  store i64 2, i64* %6, align 16
+  %7 = call noundef float @callee(%struct.A* byval(%struct.A) align 32 %3)
+  ret float %7
+}