diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp
@@ -128,10 +128,36 @@
   return allocaOp;
 }
 
+/// Returns true if the given argument or result attribute is supported by the
+/// inliner, false otherwise. This is a deny-list: every attribute that is not
+/// explicitly rejected below is considered supported.
+static bool isArgOrResAttrSupported(NamedAttribute attr) {
+  if (attr.getName() == LLVM::LLVMDialect::getAlignAttrName())
+    return false;
+  if (attr.getName() == LLVM::LLVMDialect::getInAllocaAttrName())
+    return false;
+  if (attr.getName() == LLVM::LLVMDialect::getNoAliasAttrName())
+    return false;
+  return true;
+}
+
 namespace {
 struct LLVMInlinerInterface : public DialectInlinerInterface {
-  using DialectInlinerInterface::DialectInlinerInterface;
+  // No inheriting `using` constructor on purpose: a constructor inherited
+  // from the base class would skip the initializer below and leave
+  // `disallowedFunctionAttrs` empty, silently allowing e.g. `noinline`.
+  explicit LLVMInlinerInterface(Dialect *dialect)
+      : DialectInlinerInterface(dialect),
+        // Cache set of StringAttrs for fast lookup in `isLegalToInline`.
+        disallowedFunctionAttrs({
+            StringAttr::get(dialect->getContext(), "noduplicate"),
+            StringAttr::get(dialect->getContext(), "noinline"),
+            StringAttr::get(dialect->getContext(), "optnone"),
+            StringAttr::get(dialect->getContext(), "presplitcoroutine"),
+            StringAttr::get(dialect->getContext(), "returns_twice"),
+            StringAttr::get(dialect->getContext(), "strictfp"),
+        }) {}
 
   bool isLegalToInline(Operation *call, Operation *callable,
                        bool wouldBeCloned) const final {
     if (!wouldBeCloned)
@@ -149,24 +175,28 @@
       return false;
     }
     if (auto attrs = funcOp.getArgAttrs()) {
-      for (Attribute attr : *attrs) {
-        auto attrDict = cast<DictionaryAttr>(attr);
+      for (DictionaryAttr attrDict : attrs->getAsRange<DictionaryAttr>()) {
         for (NamedAttribute attr : attrDict) {
-          if (attr.getName() == LLVM::LLVMDialect::getByValAttrName())
-            continue;
-          // TODO: Handle all argument attributes;
-          LLVM_DEBUG(llvm::dbgs() << "Cannot inline " << funcOp.getSymName()
-                                  << ": unhandled argument attribute \""
-                                  << attr.getName() << "\"\n");
-          return false;
+          if (!isArgOrResAttrSupported(attr)) {
+            LLVM_DEBUG(llvm::dbgs() << "Cannot inline " << funcOp.getSymName()
+                                    << ": unhandled argument attribute "
+                                    << attr.getName() << "\n");
+            return false;
+          }
         }
       }
     }
-    // TODO: Handle result attributes;
-    if (funcOp.getResAttrs()) {
-      LLVM_DEBUG(llvm::dbgs() << "Cannot inline " << funcOp.getSymName()
-                              << ": unhandled result attribute\n");
-      return false;
+    if (auto attrs = funcOp.getResAttrs()) {
+      for (DictionaryAttr attrDict : attrs->getAsRange<DictionaryAttr>()) {
+        for (NamedAttribute attr : attrDict) {
+          if (!isArgOrResAttrSupported(attr)) {
+            LLVM_DEBUG(llvm::dbgs() << "Cannot inline " << funcOp.getSymName()
+                                    << ": unhandled result attribute "
+                                    << attr.getName() << "\n");
+            return false;
+          }
+        }
+      }
     }
     // TODO: Handle exceptions.
     if (funcOp.getPersonality()) {
@@ -176,18 +206,11 @@
     }
     if (funcOp.getPassthrough()) {
       // TODO: Used attributes should not be passthrough.
-      DenseSet<StringAttr> disallowed(
-          {StringAttr::get(funcOp->getContext(), "noduplicate"),
-           StringAttr::get(funcOp->getContext(), "noinline"),
-           StringAttr::get(funcOp->getContext(), "optnone"),
-           StringAttr::get(funcOp->getContext(), "presplitcoroutine"),
-           StringAttr::get(funcOp->getContext(), "returns_twice"),
-           StringAttr::get(funcOp->getContext(), "strictfp")});
       if (llvm::any_of(*funcOp.getPassthrough(), [&](Attribute attr) {
             auto stringAttr = dyn_cast<StringAttr>(attr);
             if (!stringAttr)
               return false;
-            if (disallowed.contains(stringAttr)) {
+            if (disallowedFunctionAttrs.contains(stringAttr)) {
               LLVM_DEBUG(llvm::dbgs()
                          << "Cannot inline " << funcOp.getSymName()
                          << ": found disallowed function attribute "
@@ -284,6 +307,11 @@
     // which newly inlined block was previously the entry block of the callee.
     moveConstantAllocasToEntryBlock(inlinedBlocks);
   }
+
+  // Keeping this (immutable) state on the interface allows us to look up
+  // StringAttrs instead of looking up strings, since StringAttrs are bound to
+  // the current context and thus cannot be initialized as static fields.
+  const DenseSet<StringAttr> disallowedFunctionAttrs;
 };
 
 } // end anonymous namespace
diff --git a/mlir/test/Dialect/LLVMIR/inlining.mlir b/mlir/test/Dialect/LLVMIR/inlining.mlir
--- a/mlir/test/Dialect/LLVMIR/inlining.mlir
+++ b/mlir/test/Dialect/LLVMIR/inlining.mlir
@@ -396,3 +396,43 @@
   llvm.call @with_byval_arg(%ptr) : (!llvm.ptr) -> ()
   llvm.return
 }
+
+// -----
+
+llvm.func @ignored_attrs(%ptr : !llvm.ptr { llvm.inreg, llvm.nocapture, llvm.nofree, llvm.preallocated = i32, llvm.returned, llvm.alignstack = 32 : i64, llvm.writeonly, llvm.noundef, llvm.nonnull }, %x : i32 { llvm.zeroext }) -> (!llvm.ptr { llvm.noundef, llvm.inreg, llvm.nonnull }) {
+  llvm.return %ptr : !llvm.ptr
+}
+
+// CHECK-LABEL: @test_ignored_attrs
+// CHECK-NOT: llvm.call
+// CHECK-NEXT: llvm.return
+llvm.func @test_ignored_attrs(%ptr : !llvm.ptr, %x : i32) {
+  llvm.call @ignored_attrs(%ptr, %x) : (!llvm.ptr, i32) -> (!llvm.ptr)
+  llvm.return
+}
+
+// -----
+
+llvm.func @disallowed_arg_attr(%ptr : !llvm.ptr { llvm.align = 16 : i32 }) {
+  llvm.return
+}
+
+// CHECK-LABEL: @test_disallow_arg_attr
+// CHECK-NEXT: llvm.call
+llvm.func @test_disallow_arg_attr(%ptr : !llvm.ptr) {
+  llvm.call @disallowed_arg_attr(%ptr) : (!llvm.ptr) -> ()
+  llvm.return
+}
+
+// -----
+
+llvm.func @disallowed_res_attr(%ptr : !llvm.ptr) -> (!llvm.ptr { llvm.noalias }) {
+  llvm.return %ptr : !llvm.ptr
+}
+
+// CHECK-LABEL: @test_disallow_res_attr
+// CHECK-NEXT: llvm.call
+llvm.func @test_disallow_res_attr(%ptr : !llvm.ptr) {
+  llvm.call @disallowed_res_attr(%ptr) : (!llvm.ptr) -> (!llvm.ptr)
+  llvm.return
+}