diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -1744,41 +1744,21 @@ /// hooks. /// Implementation of `FoldHookFn` OperationName hook. static OperationName::FoldHookFn getFoldHookFn() { - return getFoldHookFnImpl(); - } - /// The internal implementation of `getFoldHookFn` above that is invoked if - /// the operation is single result and defines a `fold` method. - template - static std::enable_if_t, - Traits...>::value && - detect_has_single_result_fold::value, - OperationName::FoldHookFn> - getFoldHookFnImpl() { - return [](Operation *op, ArrayRef operands, - SmallVectorImpl &results) { - return foldSingleResultHook(op, operands, results); - }; - } - /// The internal implementation of `getFoldHookFn` above that is invoked if - /// the operation is not single result and defines a `fold` method. - template - static std::enable_if_t, - Traits...>::value && - detect_has_fold::value, - OperationName::FoldHookFn> - getFoldHookFnImpl() { - return [](Operation *op, ArrayRef operands, - SmallVectorImpl &results) { - return foldHook(op, operands, results); - }; - } - /// The internal implementation of `getFoldHookFn` above that is invoked if - /// the operation does not define a `fold` method. - template - static std::enable_if_t::value && - !detect_has_fold::value, - OperationName::FoldHookFn> - getFoldHookFnImpl() { + // If the operation has the OneResult trait AND defines the single-result form of `fold`, use the single-result fold hook. + if constexpr (llvm::is_one_of, + Traits...>::value && + detect_has_single_result_fold::value) + return [](Operation *op, ArrayRef operands, + SmallVectorImpl &results) { + return foldSingleResultHook(op, operands, results); + }; + // Otherwise, if the operation defines a `fold` method matched by `detect_has_fold`, use the general fold hook. This also covers OneResult ops whose `fold` is not the single-result form, since the branch above requires both conditions.
+ if constexpr (detect_has_fold::value) + return [](Operation *op, ArrayRef operands, + SmallVectorImpl &results) { + return foldHook(op, operands, results); + }; + // Otherwise no usable `fold` method was detected, so only the traits are folded. NOTE(review): a non-OneResult op that defines only the single-result `fold` form also lands here; the removed SFINAE overloads appear to have rejected that combination at compile time -- confirm this relaxation is intended. return [](Operation *op, ArrayRef operands, SmallVectorImpl &results) { // In this case, we only need to fold the traits of the operation. @@ -1837,31 +1817,16 @@ } /// Implementation of `PrintAssemblyFn` OperationName hook. static OperationName::PrintAssemblyFn getPrintAssemblyFn() { - return getPrintAssemblyFnImpl(); - } - /// The internal implementation of `getPrintAssemblyFn` that is invoked when - /// the concrete operation does not define a `print` method. - template - static std::enable_if_t::value, - OperationName::PrintAssemblyFn> - getPrintAssemblyFnImpl() { + if constexpr (detect_has_print::value) // The concrete op defines `print`: print the op name first, then delegate to the op's own `print`; otherwise fall through to the generic `OpState::print`. + return [](Operation *op, OpAsmPrinter &p, StringRef defaultDialect) { + OpState::printOpName(op, p, defaultDialect); + return cast(op).print(p); + }; return [](Operation *op, OpAsmPrinter &printer, StringRef defaultDialect) { return OpState::print(op, printer, defaultDialect); }; } - /// The internal implementation of `getPrintAssemblyFn` that is invoked when - /// the concrete operation defines a `print` method. - template - static std::enable_if_t::value, - OperationName::PrintAssemblyFn> - getPrintAssemblyFnImpl() { - return &printAssembly; - } - static void printAssembly(Operation *op, OpAsmPrinter &p, - StringRef defaultDialect) { - OpState::printOpName(op, p, defaultDialect); - return cast(op).print(p); - } + /// Implementation of `PopulateDefaultAttrsFn` OperationName hook. static OperationName::PopulateDefaultAttrsFn getPopulateDefaultAttrsFn() { return ConcreteType::populateDefaultAttrs;