diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2850,15 +2850,47 @@
     auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(callable);
     if (!callOp || !funcOp)
       return false;
-    return isLegalToInlineCallAttributes(callOp) &&
-           isLegalToInlineFuncAttributes(funcOp);
+    // TODO: Handle argument and result attributes.
+    if (funcOp.getArgAttrs() || funcOp.getResAttrs())
+      return false;
+    // TODO: Handle exceptions.
+    if (funcOp.getPersonality())
+      return false;
+    // LLVM's own inliner rejects callees whose garbage collector differs from
+    // the caller's. The caller cannot be updated from here, so only inline a
+    // callee that uses a garbage collector into a caller using the same one.
+    if (auto calleeGC = funcOp.getGarbageCollector()) {
+      auto callerOp = callOp->getParentOfType<LLVM::LLVMFuncOp>();
+      if (!callerOp || callerOp.getGarbageCollector() != calleeGC)
+        return false;
+    }
+    if (funcOp.getPassthrough()) {
+      // Refuse callees whose passthrough attributes must block inlining.
+      // TODO: Used attributes should not be passthrough.
+      DenseSet<StringAttr> disallowed(
+          {StringAttr::get(funcOp->getContext(), "noduplicate"),
+           StringAttr::get(funcOp->getContext(), "noinline"),
+           StringAttr::get(funcOp->getContext(), "optnone"),
+           StringAttr::get(funcOp->getContext(), "presplitcoroutine"),
+           StringAttr::get(funcOp->getContext(), "returns_twice"),
+           StringAttr::get(funcOp->getContext(), "strictfp")});
+      if (llvm::any_of(*funcOp.getPassthrough(), [&](Attribute attr) {
+            auto stringAttr = dyn_cast<StringAttr>(attr);
+            // Only plain string (flag) entries can match the denylist.
+            if (!stringAttr)
+              return false;
+            return disallowed.contains(stringAttr);
+          }))
+        return false;
+    }
+    return true;
   }
 
   bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final {
     return true;
   }
 
-  /// Conservative allowlist-based inlining of operations supported so far.
+  /// Conservative allowlist of operations supported so far.
   bool isLegalToInline(Operation *op, Region *, bool, IRMapping &) const final {
     if (isPure(op))
       return true;
@@ -2919,53 +2951,6 @@
     // which newly inlined block was previously the entry block of the callee.
     moveConstantAllocasToEntryBlock(inlinedBlocks);
   }
-
-private:
-  /// Returns true if all attributes of `callOp` are handled during inlining.
-  [[nodiscard]] static bool isLegalToInlineCallAttributes(LLVM::CallOp callOp) {
-    return all_of(callOp.getAttributeNames(), [&](StringRef attrName) {
-      return llvm::StringSwitch<bool>(attrName)
-          // TODO: Propagate and update branch weights.
-          .Case("branch_weights", !callOp.getBranchWeights())
-          .Case("callee", true)
-          .Case("fastmathFlags", true)
-          .Default(false);
-    });
-  }
-
-  /// Returns true if all attributes of `funcOp` are handled during inlining.
-  [[nodiscard]] static bool
-  isLegalToInlineFuncAttributes(LLVM::LLVMFuncOp funcOp) {
-    return all_of(funcOp.getAttributeNames(), [&](StringRef attrName) {
-      return llvm::StringSwitch<bool>(attrName)
-          .Case("CConv", true)
-          .Case("arg_attrs", ([&]() {
-                  if (!funcOp.getArgAttrs())
-                    return true;
-                  return llvm::all_of(funcOp.getArgAttrs().value(),
-                                      [&](Attribute) {
-                                        // TODO: Handle argument attributes.
-                                        return false;
-                                      });
-                })())
-          .Case("dso_local", true)
-          .Case("function_entry_count", true)
-          .Case("function_type", true)
-          // TODO: Once the garbage collector attribute is supported on
-          // LLVM::CallOp, make sure that the garbage collector matches.
-          .Case("garbageCollector", !funcOp.getGarbageCollector())
-          .Case("linkage", true)
-          .Case("memory", true)
-          .Case("passthrough", !funcOp.getPassthrough())
-          // Exception handling is not yet supported, so bail out if the
-          // personality is set.
-          .Case("personality", !funcOp.getPersonality())
-          // TODO: Handle result attributes.
-          .Case("res_attrs", !funcOp.getResAttrs())
-          .Case("sym_name", true)
-          .Default(false);
-    });
-  }
 };
 
 } // end anonymous namespace
diff --git a/mlir/test/Dialect/LLVMIR/inlining.mlir b/mlir/test/Dialect/LLVMIR/inlining.mlir
--- a/mlir/test/Dialect/LLVMIR/inlining.mlir
+++ b/mlir/test/Dialect/LLVMIR/inlining.mlir
@@ -83,7 +83,7 @@
 // CHECK-NEXT: llvm.return %[[CST]]
 llvm.func @caller() -> (i32) {
   // Include all call attributes that don't prevent inlining.
-  %0 = llvm.call @callee() { fastmathFlags = #llvm.fastmath<nnan, ninf> } : () -> (i32)
+  %0 = llvm.call @callee() { fastmathFlags = #llvm.fastmath<nnan, ninf>, branch_weights = dense<42> : vector<1xi32> } : () -> (i32)
   llvm.return %0 : i32
 }
 
@@ -147,32 +147,72 @@
 
 // -----
 
-llvm.func @callee() -> (i32) attributes { passthrough = ["foo"] } {
-  %0 = llvm.mlir.constant(42 : i32) : i32
-  llvm.return %0 : i32
+llvm.func @callee() attributes { passthrough = ["foo", "bar"] } {
+  llvm.return
 }
 
 // CHECK-LABEL: llvm.func @caller
-// CHECK-NEXT: llvm.call @callee
-// CHECK-NEXT: return
-llvm.func @caller() -> (i32) {
-  %0 = llvm.call @callee() : () -> (i32)
-  llvm.return %0 : i32
+// CHECK-NEXT: llvm.return
+llvm.func @caller() {
+  llvm.call @callee() : () -> ()
+  llvm.return
 }
 
 // -----
 
+llvm.func @callee_noinline() attributes { passthrough = ["noinline"] }
+llvm.func @callee_optnone() attributes { passthrough = ["optnone"] }
+llvm.func @callee_noduplicate() attributes { passthrough = ["noduplicate"] }
+llvm.func @callee_presplitcoroutine() attributes { passthrough = ["presplitcoroutine"] }
+llvm.func @callee_returns_twice() attributes { passthrough = ["returns_twice"] }
+llvm.func @callee_strictfp() attributes { passthrough = ["strictfp"] }
+
+// CHECK-LABEL: llvm.func @caller
+// CHECK-NEXT: llvm.call @callee_noinline
+// CHECK-NEXT: llvm.call @callee_optnone
+// CHECK-NEXT: llvm.call @callee_noduplicate
+// CHECK-NEXT: llvm.call @callee_presplitcoroutine
+// CHECK-NEXT: llvm.call @callee_returns_twice
+// CHECK-NEXT: llvm.call @callee_strictfp
+// CHECK-NEXT: llvm.return
+llvm.func @caller() {
+  llvm.call @callee_noinline() : () -> ()
+  llvm.call @callee_optnone() : () -> ()
+  llvm.call @callee_noduplicate() : () -> ()
+  llvm.call @callee_presplitcoroutine() : () -> ()
+  llvm.call @callee_returns_twice() : () -> ()
+  llvm.call @callee_strictfp() : () -> ()
+  llvm.return
+}
+
+// -----
+
 llvm.func @callee() -> (i32) attributes { garbageCollector = "foo" } {
   %0 = llvm.mlir.constant(42 : i32) : i32
   llvm.return %0 : i32
 }
 
 // CHECK-LABEL: llvm.func @caller
 // CHECK-NEXT: llvm.call @callee
 // CHECK-NEXT: return
 llvm.func @caller() -> (i32) {
   %0 = llvm.call @callee() : () -> (i32)
   llvm.return %0 : i32
 }
+
+// -----
+
+llvm.func @callee() -> (i32) attributes { garbageCollector = "foo" } {
+  %0 = llvm.mlir.constant(42 : i32) : i32
+  llvm.return %0 : i32
+}
+
+// CHECK-LABEL: llvm.func @caller
+// CHECK-NEXT: %[[CST:.+]] = llvm.mlir.constant
+// CHECK-NEXT: llvm.return %[[CST]]
+llvm.func @caller() -> (i32) attributes { garbageCollector = "foo" } {
+  %0 = llvm.call @callee() : () -> (i32)
+  llvm.return %0 : i32
+}
 
 // -----
@@ -191,20 +231,6 @@
 
 // -----
 
-llvm.func @callee() {
-  llvm.return
-}
-
-// CHECK-LABEL: llvm.func @caller
-// CHECK-NEXT: llvm.call @callee
-// CHECK-NEXT: llvm.return
-llvm.func @caller() {
-  llvm.call @callee() { branch_weights = dense<42> : vector<1xi32> } : () -> ()
-  llvm.return
-}
-
-// -----
-
 llvm.func @static_alloca() -> f32 {
   %0 = llvm.mlir.constant(4 : i32) : i32
   %1 = llvm.alloca %0 x f32 : (i32) -> !llvm.ptr