diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2852,6 +2852,22 @@
 struct LLVMInlinerInterface : public DialectInlinerInterface {
   using DialectInlinerInterface::DialectInlinerInterface;
 
+  bool isLegalToInline(Operation *call, Operation *callable,
+                       bool wouldBeCloned) const final {
+    if (!wouldBeCloned)
+      return false;
+    auto callOp = dyn_cast<LLVM::CallOp>(call);
+    auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(callable);
+    if (!callOp || !funcOp)
+      return false;
+    return isLegalToInlineCallAttributes(callOp) &&
+           isLegalToInlineFuncAttributes(funcOp);
+  }
+
+  bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final {
+    return true;
+  }
+
   /// Conservative allowlist-based inlining of operations supported so far.
   bool isLegalToInline(Operation *op, Region *, bool, IRMapping &) const final {
     if (isPure(op))
@@ -2869,22 +2885,83 @@
             return false;
           return true;
         })
+        .Case<LLVM::CallOp>([](auto) { return true; })
         .Default([](auto) { return false; });
   }
 
-  /// Handle the given inlined terminator by replacing it with a new operation
-  /// as necessary. Required when the region has only one block.
-  void handleTerminator(Operation *op,
-                        ArrayRef<Value> valuesToRepl) const final {
-    // Only handle "llvm.return" here.
-    auto returnOp = dyn_cast<LLVM::ReturnOp>(op);
+  /// Handle the given inlined return by replacing it with a branch. This
+  /// overload is called when the inlined region has more than one block.
+  void handleTerminator(Operation *op, Block *newDest) const final {
+    // Only return needs to be handled here.
+    auto returnOp = dyn_cast<LLVM::ReturnOp>(op);
     if (!returnOp)
       return;
 
+    // Replace the return with a branch to the dest.
+    OpBuilder builder(op);
+    builder.create<LLVM::BrOp>(op->getLoc(), returnOp.getOperands(), newDest);
+    op->erase();
+  }
+
+  /// Handle the given inlined return by replacing the uses of the call with the
+  /// operands of the return. This overload is called when the inlined region
+  /// only contains one block.
+  void handleTerminator(Operation *op,
+                        ArrayRef<Value> valuesToRepl) const final {
+    // Return will be the only terminator present.
+    auto returnOp = cast<LLVM::ReturnOp>(op);
+
+    // Replace the values directly with the return operands.
     assert(returnOp.getNumOperands() == valuesToRepl.size());
-    for (const auto &it : llvm::enumerate(returnOp.getOperands()))
-      valuesToRepl[it.index()].replaceAllUsesWith(it.value());
+    for (const auto &[dst, src] :
+         llvm::zip(valuesToRepl, returnOp.getOperands()))
+      dst.replaceAllUsesWith(src);
+  }
+
+private:
+  /// Returns true if all attributes of `callOp` are handled during inlining.
+  [[nodiscard]] static bool isLegalToInlineCallAttributes(LLVM::CallOp callOp) {
+    return all_of(callOp.getAttributeNames(), [&](StringRef attrName) {
+      return llvm::StringSwitch<bool>(attrName)
+          // TODO: Propagate and update branch weights.
+          .Case("branch_weights", !callOp.getBranchWeights())
+          .Case("callee", true)
+          .Case("fastmathFlags", true)
+          .Default(false);
+    });
+  }
+
+  /// Returns true if all attributes of `funcOp` are handled during inlining.
+  [[nodiscard]] static bool
+  isLegalToInlineFuncAttributes(LLVM::LLVMFuncOp funcOp) {
+    return all_of(funcOp.getAttributeNames(), [&](StringRef attrName) {
+      return llvm::StringSwitch<bool>(attrName)
+          .Case("CConv", true)
+          .Case("arg_attrs", ([&]() {
+                  if (!funcOp.getArgAttrs())
+                    return true;
+                  return llvm::all_of(funcOp.getArgAttrs().value(),
+                                      [&](Attribute) {
+                                        // TODO: Handle argument attributes.
+                                        return false;
+                                      });
+                })())
+          .Case("dso_local", true)
+          .Case("function_entry_count", true)
+          .Case("function_type", true)
+          // TODO: Once the garbage collector attribute is supported on
+          // LLVM::CallOp, make sure that the garbage collector matches.
+          .Case("garbageCollector", !funcOp.getGarbageCollector())
+          .Case("linkage", true)
+          .Case("passthrough", !funcOp.getPassthrough())
+          // Exception handling is not yet supported, so bail out if the
+          // personality is set.
+          .Case("personality", !funcOp.getPersonality())
+          // TODO: Handle result attributes.
+          .Case("res_attrs", !funcOp.getResAttrs())
+          .Case("sym_name", true)
+          .Default(false);
+    });
   }
 };
 } // end anonymous namespace
diff --git a/mlir/test/Dialect/LLVMIR/inlining.mlir b/mlir/test/Dialect/LLVMIR/inlining.mlir
--- a/mlir/test/Dialect/LLVMIR/inlining.mlir
+++ b/mlir/test/Dialect/LLVMIR/inlining.mlir
@@ -41,7 +41,7 @@
   llvm.return
 }
 
-func.func private @with_mem_attr(%ptr : !llvm.ptr) -> () {
+func.func private @with_mem_attr(%ptr : !llvm.ptr) {
   %0 = llvm.mlir.constant(42 : i32) : i32
   // Do not inline load/store operations that carry attributes requiring
   // handling while inlining, until this is supported by the inliner.
@@ -52,7 +52,7 @@
 // CHECK-LABEL: func.func @test_not_inline
 // CHECK-NEXT: call @with_mem_attr
 // CHECK-NEXT: return
-func.func @test_not_inline(%ptr : !llvm.ptr) -> () {
+func.func @test_not_inline(%ptr : !llvm.ptr) {
   call @with_mem_attr(%ptr) : (!llvm.ptr) -> ()
   return
 }
@@ -69,4 +69,134 @@
 func.func @llvm_ret(%arg0 : i32) -> i32 {
   %res = call @func(%arg0) : (i32) -> (i32)
   return %res : i32
+}
+// Include all function attributes that don't prevent inlining
+llvm.func internal fastcc @callee() -> (i32) attributes { function_entry_count = 42 : i64, dso_local } {
+  %0 = llvm.mlir.constant(42 : i32) : i32
+  llvm.return %0 : i32
+}
+
+// CHECK-LABEL: llvm.func @caller
+// CHECK-NEXT: %[[CST:.+]] = llvm.mlir.constant
+// CHECK-NEXT: llvm.return %[[CST]]
+llvm.func @caller() -> (i32) {
+  // Include all call attributes that don't prevent inlining.
+  %0 = llvm.call @callee() { fastmathFlags = #llvm.fastmath<fast> } : () -> (i32)
+  llvm.return %0 : i32
+}
+
+// -----
+
+llvm.func @foo() -> (i32) attributes { passthrough = ["noinline"] } {
+  %0 = llvm.mlir.constant(0 : i32) : i32
+  llvm.return %0 : i32
+}
+
+llvm.func @bar() -> (i32) attributes { passthrough = ["noinline"] } {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  llvm.return %0 : i32
+}
+
+llvm.func @callee_with_multiple_blocks(%cond: i1) -> (i32) {
+  llvm.cond_br %cond, ^bb1, ^bb2
+^bb1:
+  %0 = llvm.call @foo() : () -> (i32)
+  llvm.br ^bb3(%0: i32)
+^bb2:
+  %1 = llvm.call @bar() : () -> (i32)
+  llvm.br ^bb3(%1: i32)
+^bb3(%arg: i32):
+  llvm.return %arg : i32
+}
+
+// CHECK-LABEL: llvm.func @caller
+// CHECK-NEXT: llvm.cond_br {{.+}}, ^[[BB1:.+]], ^[[BB2:.+]]
+// CHECK-NEXT: ^[[BB1]]:
+// CHECK-NEXT: llvm.call @foo
+// CHECK-NEXT: llvm.br ^[[BB3:[a-zA-Z0-9_]+]]
+// CHECK-NEXT: ^[[BB2]]:
+// CHECK-NEXT: llvm.call @bar
+// CHECK-NEXT: llvm.br ^[[BB3]]
+// CHECK-NEXT: ^[[BB3]]
+// CHECK-NEXT: llvm.br ^[[BB4:[a-zA-Z0-9_]+]]
+// CHECK-NEXT: ^[[BB4]]
+// CHECK-NEXT: llvm.return
+llvm.func @caller(%cond: i1) -> (i32) {
+  %0 = llvm.call @callee_with_multiple_blocks(%cond) : (i1) -> (i32)
+  llvm.return %0 : i32
+}
+
+// -----
+
+llvm.func @personality() -> i32
+
+llvm.func @callee() -> (i32) attributes { personality = @personality } {
+  %0 = llvm.mlir.constant(42 : i32) : i32
+  llvm.return %0 : i32
+}
+
+// CHECK-LABEL: llvm.func @caller
+// CHECK-NEXT: llvm.call @callee
+// CHECK-NEXT: return
+llvm.func @caller() -> (i32) {
+  %0 = llvm.call @callee() : () -> (i32)
+  llvm.return %0 : i32
+}
+
+// -----
+
+llvm.func @callee() -> (i32) attributes { passthrough = ["foo"] } {
+  %0 = llvm.mlir.constant(42 : i32) : i32
+  llvm.return %0 : i32
+}
+
+// CHECK-LABEL: llvm.func @caller
+// CHECK-NEXT: llvm.call @callee
+// CHECK-NEXT: return
+llvm.func @caller() -> (i32) {
+  %0 = llvm.call @callee() : () -> (i32)
+  llvm.return %0 : i32
+}
+
+// -----
+
+llvm.func @callee() -> (i32) attributes { garbageCollector = "foo" } {
+  %0 = llvm.mlir.constant(42 : i32) : i32
+  llvm.return %0 : i32
+}
+
+// CHECK-LABEL: llvm.func @caller
+// CHECK-NEXT: llvm.call @callee
+// CHECK-NEXT: return
+llvm.func @caller() -> (i32) {
+  %0 = llvm.call @callee() : () -> (i32)
+  llvm.return %0 : i32
+}
+
+// -----
+
+llvm.func @callee(%ptr : !llvm.ptr {llvm.byval = !llvm.ptr}) -> (!llvm.ptr) {
+  llvm.return %ptr : !llvm.ptr
+}
+
+// CHECK-LABEL: llvm.func @caller
+// CHECK-NEXT: llvm.call @callee
+// CHECK-NEXT: return
+llvm.func @caller(%ptr : !llvm.ptr) -> (!llvm.ptr) {
+  %0 = llvm.call @callee(%ptr) : (!llvm.ptr) -> (!llvm.ptr)
+  llvm.return %0 : !llvm.ptr
+}
+
+// -----
+
+llvm.func @callee() {
+  llvm.return
+}
+
+// CHECK-LABEL: llvm.func @caller
+// CHECK-NEXT: llvm.call @callee
+// CHECK-NEXT: llvm.return
+llvm.func @caller() {
+  llvm.call @callee() { branch_weights = dense<42> : vector<1xi32> } : () -> ()
+  llvm.return
 }
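For reference, a minimal sketch (not part of the patch) of the behavior the new interface hooks enable, distilled from the first new test case above. The RUN line of inlining.mlir is not shown in this hunk, so the mlir-opt -split-input-file -inline driver named in the comment below is an assumption about how the test file is exercised rather than something this diff introduces.

    // Assumed driver: mlir-opt -split-input-file -inline input.mlir
    llvm.func @callee() -> (i32) {
      %0 = llvm.mlir.constant(42 : i32) : i32
      llvm.return %0 : i32
    }

    // No unsupported call or function attributes are present, so the
    // isLegalToInline hooks accept the call and the inliner may rewrite
    // @caller to return the constant directly, mirroring the CHECK lines of
    // the first new test:
    //
    //   llvm.func @caller() -> (i32) {
    //     %0 = llvm.mlir.constant(42 : i32) : i32
    //     llvm.return %0 : i32
    //   }
    llvm.func @caller() -> (i32) {
      %0 = llvm.call @callee() : () -> (i32)
      llvm.return %0 : i32
    }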