diff --git a/llvm/lib/Transforms/Utils/InlineFunction.cpp b/llvm/lib/Transforms/Utils/InlineFunction.cpp
--- a/llvm/lib/Transforms/Utils/InlineFunction.cpp
+++ b/llvm/lib/Transforms/Utils/InlineFunction.cpp
@@ -1951,9 +1951,11 @@
 
   // The inliner does not know how to inline through calls with operand bundles
   // in general ...
+  Value *ConvergenceControlToken = nullptr;
   if (CB.hasOperandBundles()) {
     for (int i = 0, e = CB.getNumOperandBundles(); i != e; ++i) {
-      uint32_t Tag = CB.getOperandBundleAt(i).getTagID();
+      auto OBUse = CB.getOperandBundleAt(i);
+      uint32_t Tag = OBUse.getTagID();
       // ... but it knows how to inline through "deopt" operand bundles ...
       if (Tag == LLVMContext::OB_deopt)
         continue;
@@ -1964,11 +1966,39 @@
         continue;
       if (Tag == LLVMContext::OB_kcfi)
         continue;
+      // ... and through "convergencectrl" operand bundles: remember the token
+      // so the entry intrinsic in the inlined body can be replaced with it.
+      if (Tag == LLVMContext::OB_convergencectrl) {
+        ConvergenceControlToken = OBUse.Inputs[0].get();
+        continue;
+      }
 
       return InlineResult::failure("unsupported operand bundle");
     }
   }
 
+  // FIXME: The check below is redundant and incomplete. According to spec, if a
+  // convergent call is missing a token, then the caller is using uncontrolled
+  // convergence. If the callee has an entry intrinsic, then the callee is using
+  // controlled convergence, and the call cannot be inlined. A proper
+  // implementation of this check requires a whole new analysis that identifies
+  // convergence in every function. For now, we skip that and just do this one
+  // cursory check. The underlying assumption is that in a compiler flow that
+  // fully implements convergence control tokens, there is no mixing of
+  // controlled and uncontrolled convergent operations in the whole program.
+  if (CB.isConvergent()) {
+    auto *I = CalledFunc->getEntryBlock().getFirstNonPHI();
+    if (auto *IntrinsicCall = dyn_cast<IntrinsicInst>(I)) {
+      if (IntrinsicCall->getIntrinsicID() ==
+          Intrinsic::experimental_convergence_entry) {
+        if (!ConvergenceControlToken) {
+          return InlineResult::failure(
+              "convergent call needs convergencectrl operand");
+        }
+      }
+    }
+  }
+
   // If the call to the callee cannot throw, set the 'nounwind' flag on any
   // calls that we inline.
   bool MarkNoUnwind = CB.doesNotThrow();
@@ -2258,6 +2288,20 @@
             IFI.GetAssumptionCache(*Caller).registerAssumption(II);
   }
 
+  // The call site supplied a convergence control token. If FirstNewBlock
+  // begins (after PHIs) with llvm.experimental.convergence.entry, forward
+  // every use of that intrinsic to the call-site token and erase it.
+  if (ConvergenceControlToken) {
+    auto *I = FirstNewBlock->getFirstNonPHI();
+    if (auto *IntrinsicCall = dyn_cast<IntrinsicInst>(I)) {
+      if (IntrinsicCall->getIntrinsicID() ==
+          Intrinsic::experimental_convergence_entry) {
+        IntrinsicCall->replaceAllUsesWith(ConvergenceControlToken);
+        IntrinsicCall->eraseFromParent();
+      }
+    }
+  }
+
   // If there are any alloca instructions in the block that used to be the entry
   // block for the callee, move them to the entry block of the caller. First
   // calculate which instruction they should be inserted before. We insert the
diff --git a/llvm/test/Transforms/Inline/convergence-inline.ll b/llvm/test/Transforms/Inline/convergence-inline.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/Inline/convergence-inline.ll
@@ -0,0 +1,193 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes='cgscc(inline)' -S %s | FileCheck %s
+
+define void @nonconvergent_callee() alwaysinline {
+; CHECK-LABEL: @nonconvergent_callee(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TOKEN:%.*]] = call token @llvm.experimental.convergence.anchor()
+; CHECK-NEXT:    call void @f(i32 0) [ "convergencectrl"(token [[TOKEN]]) ]
+; CHECK-NEXT:    ret void
+;
+entry:
+  %token = call token @llvm.experimental.convergence.anchor()
+  call void @f(i32 0) [ "convergencectrl"(token %token) ]
+  ret void
+}
+
+define void @convergent_callee(i32 %v) convergent alwaysinline {
+; CHECK-LABEL: @convergent_callee(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TOKEN:%.*]] = call token @llvm.experimental.convergence.entry()
+; CHECK-NEXT:    call void @f(i32 [[V:%.*]]) [ "convergencectrl"(token [[TOKEN]]) ]
+; CHECK-NEXT:    ret void
+;
+entry:
+  %token = call token @llvm.experimental.convergence.entry()
+  call void @f(i32 %v) [ "convergencectrl"(token %token) ]
+  ret void
+}
+
+define void @test_nonconvergent() {
+; CHECK-LABEL: @test_nonconvergent(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TOKEN_I:%.*]] = call token @llvm.experimental.convergence.anchor()
+; CHECK-NEXT:    call void @f(i32 0) [ "convergencectrl"(token [[TOKEN_I]]) ]
+; CHECK-NEXT:    ret void
+;
+entry:
+  call void @nonconvergent_callee()
+  ret void
+}
+
+define void @test_convergent_basic(i1 %cond) {
+; CHECK-LABEL: @test_convergent_basic(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TOKEN:%.*]] = call token @llvm.experimental.convergence.anchor()
+; CHECK-NEXT:    br i1 [[COND:%.*]], label [[THEN:%.*]], label [[END:%.*]]
+; CHECK:       then:
+; CHECK-NEXT:    call void @f(i32 0) [ "convergencectrl"(token [[TOKEN]]) ]
+; CHECK-NEXT:    br label [[END]]
+; CHECK:       end:
+; CHECK-NEXT:    ret void
+;
+entry:
+  %token = call token @llvm.experimental.convergence.anchor()
+  br i1 %cond, label %then, label %end
+
+then:
+  call void @convergent_callee(i32 0) [ "convergencectrl"(token %token) ]
+  br label %end
+
+end:
+  ret void
+}
+
+define void @test_convergent_no_token(i1 %cond) convergent {
+; CHECK-LABEL: @test_convergent_no_token(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    call void @convergent_callee(i32 0)
+; CHECK-NEXT:    ret void
+;
+entry:
+  call void @convergent_callee(i32 0)
+  ret void
+}
+
+define void @test_convergent_multiple() convergent {
+; CHECK-LABEL: @test_convergent_multiple(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TOKEN:%.*]] = call token @llvm.experimental.convergence.entry()
+; CHECK-NEXT:    call void @f(i32 0) [ "convergencectrl"(token [[TOKEN]]) ]
+; CHECK-NEXT:    call void @f(i32 1) [ "convergencectrl"(token [[TOKEN]]) ]
+; CHECK-NEXT:    call void @f(i32 2) [ "convergencectrl"(token [[TOKEN]]) ]
+; CHECK-NEXT:    ret void
+;
+entry:
+  %token = call token @llvm.experimental.convergence.entry()
+  call void @convergent_callee(i32 0) [ "convergencectrl"(token %token) ]
+  call void @convergent_callee(i32 1) [ "convergencectrl"(token %token) ]
+  call void @convergent_callee(i32 2) [ "convergencectrl"(token %token) ]
+  ret void
+}
+
+define void @test_convergent_loop(i1 %cond) {
+; CHECK-LABEL: @test_convergent_loop(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TOKEN:%.*]] = call token @llvm.experimental.convergence.anchor()
+; CHECK-NEXT:    br i1 [[COND:%.*]], label [[HDR:%.*]], label [[END:%.*]]
+; CHECK:       hdr:
+; CHECK-NEXT:    [[TOK_LOOP:%.*]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[TOKEN]]) ]
+; CHECK-NEXT:    call void @f(i32 0) [ "convergencectrl"(token [[TOK_LOOP]]) ]
+; CHECK-NEXT:    br i1 [[COND]], label [[HDR]], label [[END]]
+; CHECK:       end:
+; CHECK-NEXT:    ret void
+;
+entry:
+  %token = call token @llvm.experimental.convergence.anchor()
+  br i1 %cond, label %hdr, label %end
+
+hdr:
+  %tok.loop = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %token) ]
+  call void @convergent_callee(i32 0) [ "convergencectrl"(token %tok.loop) ]
+  br i1 %cond, label %hdr, label %end
+
+end:
+  ret void
+}
+
+define void @make_indirect_call(ptr %f, i32 %x) convergent alwaysinline {
+; CHECK-LABEL: @make_indirect_call(
+; CHECK-NEXT:    [[TOKEN:%.*]] = call token @llvm.experimental.convergence.entry()
+; CHECK-NEXT:    call void [[F:%.*]](i32 [[X:%.*]]) #[[ATTR2:[0-9]+]] [ "convergencectrl"(token [[TOKEN]]) ]
+; CHECK-NEXT:    ret void
+;
+  %token = call token @llvm.experimental.convergence.entry()
+  call void %f(i32 %x) convergent [ "convergencectrl"(token %token) ]
+  ret void
+}
+
+define void @test_indirect_call() convergent {
+; CHECK-LABEL: @test_indirect_call(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TOKEN:%.*]] = call token @llvm.experimental.convergence.entry()
+; CHECK-NEXT:    call void @f(i32 0) [ "convergencectrl"(token [[TOKEN]]) ]
+; CHECK-NEXT:    ret void
+;
+entry:
+  %token = call token @llvm.experimental.convergence.entry()
+  call void @make_indirect_call(ptr @convergent_callee, i32 0) [ "convergencectrl"(token %token) ]
+  ret void
+}
+
+define void @recurse() convergent alwaysinline {
+; CHECK-LABEL: @recurse(
+; CHECK-NEXT:    [[TOKEN:%.*]] = call token @llvm.experimental.convergence.entry()
+; CHECK-NEXT:    call void @recurse() [ "convergencectrl"(token [[TOKEN]]) ]
+; CHECK-NEXT:    ret void
+;
+  %token = call token @llvm.experimental.convergence.entry()
+  call void @recurse() [ "convergencectrl"(token %token) ]
+  ret void
+}
+
+define void @test_recursive_call() convergent {
+; CHECK-LABEL: @test_recursive_call(
+; CHECK-NEXT:    [[TOKEN:%.*]] = call token @llvm.experimental.convergence.entry()
+; CHECK-NEXT:    call void @recurse() [ "convergencectrl"(token [[TOKEN]]) ]
+; CHECK-NEXT:    ret void
+;
+  %token = call token @llvm.experimental.convergence.entry()
+  call void @recurse() [ "convergencectrl"(token %token) ]
+  ret void
+}
+
+define i32 @outer_g(i32 %x) convergent alwaysinline {
+; CHECK-LABEL: @outer_g(
+; CHECK-NEXT:    [[TOKEN:%.*]] = call token @llvm.experimental.convergence.entry()
+; CHECK-NEXT:    [[Y:%.*]] = call i32 @g(i32 [[X:%.*]]) [ "convergencectrl"(token [[TOKEN]]) ]
+; CHECK-NEXT:    ret i32 [[Y]]
+;
+  %token = call token @llvm.experimental.convergence.entry()
+  %y = call i32 @g(i32 %x) [ "convergencectrl"(token %token) ]
+  ret i32 %y
+}
+
+define void @test_two_calls() convergent {
+; CHECK-LABEL: @test_two_calls(
+; CHECK-NEXT:    [[TOKEN:%.*]] = call token @llvm.experimental.convergence.entry()
+; CHECK-NEXT:    [[Y_I:%.*]] = call i32 @g(i32 23) [ "convergencectrl"(token [[TOKEN]]) ]
+; CHECK-NEXT:    call void @f(i32 [[Y_I]]) [ "convergencectrl"(token [[TOKEN]]) ]
+; CHECK-NEXT:    ret void
+;
+  %token = call token @llvm.experimental.convergence.entry()
+  %x = call i32 @outer_g(i32 23) [ "convergencectrl"(token %token) ]
+  call void @convergent_callee(i32 %x) [ "convergencectrl"(token %token) ]
+  ret void
+}
+
+declare void @f(i32) convergent
+declare i32 @g(i32) convergent
+
+declare token @llvm.experimental.convergence.entry()
+declare token @llvm.experimental.convergence.anchor()
+declare token @llvm.experimental.convergence.loop()