diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h --- a/mlir/include/mlir/IR/Dialect.h +++ b/mlir/include/mlir/IR/Dialect.h @@ -23,12 +23,6 @@ class OpBuilder; class Type; -using DialectConstantDecodeHook = - std::function; -using DialectConstantFoldHook = std::function, SmallVectorImpl &)>; -using DialectExtractElementHook = - std::function)>; using DialectAllocatorFunction = std::function; /// Dialects are groups of MLIR operations and behavior associated with the @@ -63,38 +57,6 @@ /// These are represented with OpaqueType. bool allowsUnknownTypes() const { return unknownTypesAllowed; } - //===--------------------------------------------------------------------===// - // Constant Hooks - //===--------------------------------------------------------------------===// - - /// Registered fallback constant fold hook for the dialect. Like the constant - /// fold hook of each operation, it attempts to constant fold the operation - /// with the specified constant operand values - the elements in "operands" - /// will correspond directly to the operands of the operation, but may be null - /// if non-constant. If constant folding is successful, this fills in the - /// `results` vector. If not, this returns failure and `results` is - /// unspecified. - DialectConstantFoldHook constantFoldHook = - [](Operation *op, ArrayRef operands, - SmallVectorImpl &results) { return failure(); }; - - /// Registered hook to decode opaque constants associated with this - /// dialect. The hook function attempts to decode an opaque constant tensor - /// into a tensor with non-opaque content. If decoding is successful, this - /// method returns false and sets 'output' attribute. If not, it returns true - /// and leaves 'output' unspecified. The default hook fails to decode. 
- DialectConstantDecodeHook decodeHook = - [](const OpaqueElementsAttr input, ElementsAttr &output) { return true; }; - - /// Registered hook to extract an element from an opaque constant associated - /// with this dialect. If element has been successfully extracted, this - /// method returns that element. If not, it returns an empty attribute. - /// The default hook fails to extract an element. - DialectExtractElementHook extractElementHook = - [](const OpaqueElementsAttr input, ArrayRef index) { - return Attribute(); - }; - /// Registered hook to materialize a single constant operation from a given /// attribute value with the desired resultant type. This method should use /// the provided builder to create the operation without changing the diff --git a/mlir/include/mlir/IR/DialectHooks.h b/mlir/include/mlir/IR/DialectHooks.h deleted file mode 100644 --- a/mlir/include/mlir/IR/DialectHooks.h +++ /dev/null @@ -1,90 +0,0 @@ -//===- DialectHooks.h - MLIR DialectHooks mechanism -------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file defines abstraction and registration mechanism for dialect hooks. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_IR_DIALECT_HOOKS_H -#define MLIR_IR_DIALECT_HOOKS_H - -#include "mlir/IR/Dialect.h" -#include "llvm/Support/raw_ostream.h" - -namespace mlir { -using DialectHooksSetter = std::function; - -/// Dialect hooks allow external components to register their functions to -/// be called for specific tasks specialized per dialect, such as decoding -/// of opaque constants. 
To register concrete dialect hooks, one should -/// define a DialectHooks subclass and use it as a template -/// argument to DialectHooksRegistration. For example, -/// class MyHooks : public DialectHooks {...}; -/// static DialectHooksRegistration hooksReg; -/// The subclass should override DialectHook methods for supported hooks. -class DialectHooks { -public: - // Returns hook to constant fold an operation. - DialectConstantFoldHook getConstantFoldHook() { return nullptr; } - // Returns hook to decode opaque constant tensor. - DialectConstantDecodeHook getDecodeHook() { return nullptr; } - // Returns hook to extract an element of an opaque constant tensor. - DialectExtractElementHook getExtractElementHook() { return nullptr; } - -private: - /// Registers a function that will set hooks in the registered dialects. - /// Registrations are deduplicated by dialect TypeID and only the first - /// registration will be used. - static void registerDialectHooksSetter(TypeID typeID, - const DialectHooksSetter &function); - template - friend void registerDialectHooks(StringRef dialectName); -}; - -void registerDialectHooksSetter(TypeID typeID, - const DialectHooksSetter &function); - -/// Utility to register dialect hooks. Client can register their dialect hooks -/// with the global registry by calling -/// registerDialectHooks("dialect_namespace"); -template -void registerDialectHooks(StringRef dialectName) { - DialectHooks::registerDialectHooksSetter( - TypeID::get(), [dialectName](MLIRContext *ctx) { - Dialect *dialect = ctx->getRegisteredDialect(dialectName); - if (!dialect) { - llvm::errs() << "error: cannot register hooks for unknown dialect '" - << dialectName << "'\n"; - abort(); - } - // Set hooks. 
- ConcreteHooks hooks; - if (auto h = hooks.getConstantFoldHook()) - dialect->constantFoldHook = h; - if (auto h = hooks.getDecodeHook()) - dialect->decodeHook = h; - if (auto h = hooks.getExtractElementHook()) - dialect->extractElementHook = h; - }); -} - -/// DialectHooksRegistration provides a global initializer that registers -/// a dialect hooks setter routine. -/// Usage: -/// -/// // At namespace scope. -/// static DialectHooksRegistration Unused("dialect_namespace"); -template struct DialectHooksRegistration { - DialectHooksRegistration(StringRef dialectName) { - registerDialectHooks(dialectName); - } -}; - -} // namespace mlir - -#endif diff --git a/mlir/include/mlir/Interfaces/ConstantFoldInterfaces.h b/mlir/include/mlir/Interfaces/ConstantFoldInterfaces.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Interfaces/ConstantFoldInterfaces.h @@ -0,0 +1,40 @@ +//===- ConstantFoldInterfaces.h - ConstantFold Interfaces -------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_INTERFACES_CONSTANTFOLDINTERFACES_H_ +#define MLIR_INTERFACES_CONSTANTFOLDINTERFACES_H_ + +#include "mlir/IR/DialectInterface.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/SmallVector.h" + +namespace mlir { +class Attribute; +class OpFoldResult; + +/// Define a constant fold interface to allow for dialects to opt-in specific +/// constant folding for operations they define. +class DialectConstantFoldInterface + : public DialectInterface::Base { +public: + /// Registered fallback constant fold for the dialect. 
Like the constant fold + /// hook of each operation, it attempts to constant fold the operation with + /// the specified constant operand values - the elements in "operands" will + /// correspond directly to the operands of the operation, but may be null if + /// non-constant. If constant folding is successful, this fills in the + /// `results` vector. If not, this returns failure and `results` is + /// unspecified. + virtual LogicalResult + constantFold(Operation *op, ArrayRef operands, + SmallVectorImpl &results) const { + return failure(); + } +}; + +} // end namespace mlir + +#endif // MLIR_INTERFACES_CONSTANTFOLDINTERFACES_H_ diff --git a/mlir/include/mlir/Interfaces/DecodeAttributesInterfaces.h b/mlir/include/mlir/Interfaces/DecodeAttributesInterfaces.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Interfaces/DecodeAttributesInterfaces.h @@ -0,0 +1,35 @@ +//===- DecodeAttributesInterfaces.h - DecodeAttributes Interfaces -*- C++ -*-=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_INTERFACES_DECODEATTRIBUTESINTERFACES_H_ +#define MLIR_INTERFACES_DECODEATTRIBUTESINTERFACES_H_ + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/DialectInterface.h" +#include "mlir/Support/LogicalResult.h" + +namespace mlir { + +/// Define an interface to decode an opaque constant tensor. +class DialectDecodeAttributesInterface + : public DialectInterface::Base { +public: + /// Registered hook to decode opaque constants associated with this + /// dialect. The hook function attempts to decode an opaque constant tensor + /// into a tensor with non-opaque content. If decoding is successful, this + /// method returns success() and sets 'output' attribute. 
If not, it returns + /// failure() and leaves 'output' unspecified. The default hook fails to + /// decode. + virtual LogicalResult decode(OpaqueElementsAttr input, + ElementsAttr &output) const { + return failure(); + } +}; + +} // end namespace mlir + +#endif // MLIR_INTERFACES_DECODEATTRIBUTESINTERFACES_H_ diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -14,6 +14,7 @@ #include "mlir/IR/Function.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/Types.h" +#include "mlir/Interfaces/DecodeAttributesInterfaces.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/Endian.h" @@ -1227,17 +1228,20 @@ /// element, then a null attribute is returned. Attribute OpaqueElementsAttr::getValue(ArrayRef index) const { assert(isValidIndex(index) && "expected valid multi-dimensional index"); - if (Dialect *dialect = getDialect()) - return dialect->extractElementHook(*this, index); return Attribute(); } Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; } bool OpaqueElementsAttr::decode(ElementsAttr &result) { - if (auto *d = getDialect()) - return d->decodeHook(*this, result); - return true; + auto *d = getDialect(); + if (!d) + return true; + auto *interface = + d->getRegisteredInterface(); + if (!interface) + return true; + return failed(interface->decode(*this, result)); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp --- a/mlir/lib/IR/Dialect.cpp +++ b/mlir/lib/IR/Dialect.cpp @@ -8,7 +8,6 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/Diagnostics.h" -#include "mlir/IR/DialectHooks.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/DialectInterface.h" #include "mlir/IR/MLIRContext.h" @@ -31,10 +30,6 @@ static llvm::ManagedStatic> dialectRegistry; -/// Registry for functions that set dialect hooks. 
-static llvm::ManagedStatic> - dialectHooksRegistry; - void Dialect::registerDialectAllocator( TypeID typeID, const DialectAllocatorFunction &function) { assert(function && @@ -42,24 +37,11 @@ dialectRegistry->insert({typeID, function}); } -/// Registers a function to set specific hooks for a specific dialect, typically -/// used through the DialectHooksRegistration template. -void DialectHooks::registerDialectHooksSetter( - TypeID typeID, const DialectHooksSetter &function) { - assert( - function && - "Attempting to register an empty dialect hooks initialization function"); - - dialectHooksRegistry->insert({typeID, function}); -} - /// Registers all dialects and hooks from the global registries with the /// specified MLIRContext. void mlir::registerAllDialects(MLIRContext *context) { for (const auto &it : *dialectRegistry) it.second(context); - for (const auto &it : *dialectHooksRegistry) - it.second(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -13,6 +13,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/StandardTypes.h" #include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/ConstantFoldInterfaces.h" #include using namespace mlir; @@ -570,11 +571,12 @@ if (!dialect) return failure(); - SmallVector constants; - if (failed(dialect->constantFoldHook(this, operands, constants))) + auto *interface = + dialect->getRegisteredInterface(); + if (!interface) return failure(); - results.assign(constants.begin(), constants.end()); - return success(); + + return interface->constantFold(this, operands, results); } /// Emit an error with the op name prefixed, like "'dim' op " which is