diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -1671,17 +1671,26 @@ detect_has_single_result_fold::value, AbstractOperation::FoldHookFn> getFoldHookFnImpl() { - return &foldSingleResultHook; - } - /// The internal implementation of `getFoldHookFn` above that is invoked if - /// the operation is not single result and defines a `fold` method. + return [](Operation *op, ArrayRef operands, + SmallVectorImpl &results) { + return foldSingleResultHook(op, operands, results); + }; + } + /// The internal implementation of `getFoldHookFn` above that + /// is invoked if the operation is not single result and + /// defines a `fold` method. template static std::enable_if_t, Traits...>::value && detect_has_fold::value, AbstractOperation::FoldHookFn> getFoldHookFnImpl() { - return &foldHook; + // Windows fails to compile when giving only a pointer to the function. + // This is why we create a unique_function with a lambda instead. + return [](Operation *op, ArrayRef operands, + SmallVectorImpl &results) { + return foldHook(op, operands, results); + }; } /// The internal implementation of `getFoldHookFn` above that is invoked if /// the operation does not define a `fold` method. @@ -1691,7 +1700,11 @@ AbstractOperation::FoldHookFn> getFoldHookFnImpl() { // In this case, we only need to fold the traits of the operation. - return &op_definition_impl::foldTraits; + return [](Operation *op, ArrayRef operands, + SmallVectorImpl &results) { + return op_definition_impl::foldTraits(op, operands, + results); + }; } /// Return the result of folding a single result operation that defines a /// `fold` method. @@ -1735,7 +1748,8 @@ } /// Implementation of `GetHasTraitFn` static AbstractOperation::HasTraitFn getHasTraitFn() { - return &op_definition_impl::hasTrait; + return + [](TypeID id) { return op_definition_impl::hasTrait(id); }; } /// Implementation of `ParseAssemblyFn` AbstractOperation hook. static AbstractOperation::ParseAssemblyFn getParseAssemblyFn() { @@ -1751,7 +1765,10 @@ static std::enable_if_t::value, AbstractOperation::PrintAssemblyFn> getPrintAssemblyFnImpl() { - return &OpState::print; + // We need to get the right overload before passing the function + // to the unique_function constructor. + void (*fun)(Operation *, OpAsmPrinter &) = &OpState::print; + return fun; } /// The internal implementation of `getPrintAssemblyFn` that is invoked when /// the concrete operation defines a `print` method. diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -67,14 +67,17 @@ /// the concrete operation types. class AbstractOperation { public: - using GetCanonicalizationPatternsFn = void (*)(RewritePatternSet &, - MLIRContext *); - using FoldHookFn = LogicalResult (*)(Operation *, ArrayRef, - SmallVectorImpl &); - using HasTraitFn = bool (*)(TypeID); - using ParseAssemblyFn = ParseResult (*)(OpAsmParser &, OperationState &); - using PrintAssemblyFn = void (*)(Operation *, OpAsmPrinter &); - using VerifyInvariantsFn = LogicalResult (*)(Operation *); + using GetCanonicalizationPatternsFn = + llvm::unique_function; + using FoldHookFn = llvm::unique_function, SmallVectorImpl &) const>; + using HasTraitFn = llvm::unique_function; + using ParseAssemblyFn = + llvm::unique_function; + using PrintAssemblyFn = + llvm::unique_function; + using VerifyInvariantsFn = + llvm::unique_function; /// This is the name of the operation. const Identifier name; @@ -89,7 +92,7 @@ ParseResult parseAssembly(OpAsmParser &parser, OperationState &result) const; /// Return the static hook for parsing this operation assembly. - ParseAssemblyFn getParseAssemblyFn() const { return parseAssemblyFn; } + const ParseAssemblyFn &getParseAssemblyFn() const { return parseAssemblyFn; } /// This hook implements the AsmPrinter for this operation. void printAssembly(Operation *op, OpAsmPrinter &p) const { @@ -175,20 +178,21 @@ /// Register a new operation in a Dialect object. /// The use of this method is in general discouraged in favor of /// 'insert(dialect)'. - static void insert(StringRef name, Dialect &dialect, TypeID typeID, - ParseAssemblyFn parseAssembly, - PrintAssemblyFn printAssembly, - VerifyInvariantsFn verifyInvariants, FoldHookFn foldHook, - GetCanonicalizationPatternsFn getCanonicalizationPatterns, - detail::InterfaceMap &&interfaceMap, HasTraitFn hasTrait); + static void + insert(StringRef name, Dialect &dialect, TypeID typeID, + ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly, + VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook, + GetCanonicalizationPatternsFn &&getCanonicalizationPatterns, + detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait); private: AbstractOperation(StringRef name, Dialect &dialect, TypeID typeID, - ParseAssemblyFn parseAssembly, - PrintAssemblyFn printAssembly, - VerifyInvariantsFn verifyInvariants, FoldHookFn foldHook, - GetCanonicalizationPatternsFn getCanonicalizationPatterns, - detail::InterfaceMap &&interfaceMap, HasTraitFn hasTrait); + ParseAssemblyFn &&parseAssembly, + PrintAssemblyFn &&printAssembly, + VerifyInvariantsFn &&verifyInvariants, + FoldHookFn &&foldHook, + GetCanonicalizationPatternsFn &&getCanonicalizationPatterns, + detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait); /// A map of interfaces that were registered to this operation. detail::InterfaceMap interfaceMap; diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -696,13 +696,15 @@ void AbstractOperation::insert( StringRef name, Dialect &dialect, TypeID typeID, - ParseAssemblyFn parseAssembly, PrintAssemblyFn printAssembly, - VerifyInvariantsFn verifyInvariants, FoldHookFn foldHook, - GetCanonicalizationPatternsFn getCanonicalizationPatterns, - detail::InterfaceMap &&interfaceMap, HasTraitFn hasTrait) { - AbstractOperation opInfo( - name, dialect, typeID, parseAssembly, printAssembly, verifyInvariants, - foldHook, getCanonicalizationPatterns, std::move(interfaceMap), hasTrait); + ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly, + VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook, + GetCanonicalizationPatternsFn &&getCanonicalizationPatterns, + detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait) { + AbstractOperation opInfo(name, dialect, typeID, std::move(parseAssembly), + std::move(printAssembly), + std::move(verifyInvariants), std::move(foldHook), + std::move(getCanonicalizationPatterns), + std::move(interfaceMap), std::move(hasTrait)); auto &impl = dialect.getContext()->getImpl(); assert(impl.multiThreadedExecutionContext == 0 && @@ -717,16 +719,18 @@ AbstractOperation::AbstractOperation( StringRef name, Dialect &dialect, TypeID typeID, - ParseAssemblyFn parseAssembly, PrintAssemblyFn printAssembly, - VerifyInvariantsFn verifyInvariants, FoldHookFn foldHook, - GetCanonicalizationPatternsFn getCanonicalizationPatterns, - detail::InterfaceMap &&interfaceMap, HasTraitFn hasTrait) + ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly, + VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook, + GetCanonicalizationPatternsFn &&getCanonicalizationPatterns, + detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait) : name(Identifier::get(name, dialect.getContext())), dialect(dialect), typeID(typeID), interfaceMap(std::move(interfaceMap)), - foldHookFn(foldHook), - getCanonicalizationPatternsFn(getCanonicalizationPatterns), - hasTraitFn(hasTrait), parseAssemblyFn(parseAssembly), - printAssemblyFn(printAssembly), verifyInvariantsFn(verifyInvariants) {} + foldHookFn(std::move(foldHook)), + getCanonicalizationPatternsFn(std::move(getCanonicalizationPatterns)), + hasTraitFn(std::move(hasTrait)), + parseAssemblyFn(std::move(parseAssembly)), + printAssemblyFn(std::move(printAssembly)), + verifyInvariantsFn(std::move(verifyInvariants)) {} //===----------------------------------------------------------------------===// // AbstractType diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -1803,7 +1803,8 @@ // This is the actual hook for the custom op parsing, usually implemented by // the op itself (`Op::parse()`). We retrieve it either from the // AbstractOperation or from the Dialect. - std::function parseAssemblyFn; + llvm::function_ref + parseAssemblyFn; bool isIsolatedFromAbove = false; if (opDefinition) {