diff --git a/mlir/docs/Canonicalization.md b/mlir/docs/Canonicalization.md --- a/mlir/docs/Canonicalization.md +++ b/mlir/docs/Canonicalization.md @@ -51,7 +51,7 @@ * `constant-like` operations are uniqued and hoisted into the entry block of the first parent barrier region. This is a region that is either isolated from above, e.g. the entry block of a function, or one marked as a barrier - via the `shouldMaterializeInto` method on the `OpFolderDialectInterface`. + via the `shouldMaterializeInto` method on the `DialectFoldInterface`. ## Defining Canonicalizations @@ -170,6 +170,10 @@ `Attribute` value returned, but it is important to ensure that the `Attribute` representation of a specific `Type` is consistent. +When the `fold` hook on an operation is not successful, the dialect can +provide a fallback by implementing the `DialectFoldInterface` and overriding +the fold hook. + #### Generating Constants from Attributes When a `fold` method returns an `Attribute` as the result, it signifies that diff --git a/mlir/include/mlir/Interfaces/FoldInterfaces.h b/mlir/include/mlir/Interfaces/FoldInterfaces.h --- a/mlir/include/mlir/Interfaces/FoldInterfaces.h +++ b/mlir/include/mlir/Interfaces/FoldInterfaces.h @@ -15,9 +15,10 @@ namespace mlir { class Attribute; class OpFoldResult; +class Region; -/// Define a fold interface to allow for dialects to opt-in specific -/// folding for operations they define. +/// Define a fold interface to allow for dialects to control specific aspects +/// of the folding behavior for operations they define. class DialectFoldInterface : public DialectInterface::Base { public: @@ -33,6 +34,13 @@ SmallVectorImpl &results) const { return failure(); } + + /// Registered hook to check if the given region, which is attached to an + /// operation that is *not* isolated from above, should be used when + /// materializing constants. 
The folder will generally materialize constants + /// into the top-level isolated region, this allows for materializing into a + /// lower level ancestor region if it is more profitable/correct. + virtual bool shouldMaterializeInto(Region *region) const { return false; } }; } // end namespace mlir diff --git a/mlir/include/mlir/Transforms/FoldUtils.h b/mlir/include/mlir/Transforms/FoldUtils.h --- a/mlir/include/mlir/Transforms/FoldUtils.h +++ b/mlir/include/mlir/Transforms/FoldUtils.h @@ -17,29 +17,12 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectInterface.h" +#include "mlir/Interfaces/FoldInterfaces.h" namespace mlir { class Operation; class Value; -//===--------------------------------------------------------------------===// -// Operation Folding Interface -//===--------------------------------------------------------------------===// - -/// This class defines a dialect interface used to assist the operation folder. -/// It provides hooks for materializing and folding operations. -class OpFolderDialectInterface - : public DialectInterface::Base { -public: - OpFolderDialectInterface(Dialect *dialect) : Base(dialect) {} - - /// Registered hook to check if the given region, which is attached to an - /// operation that is *not* isolated from above, should be used when - /// materializing constants. The folder will generally materialize constants - /// into the top-level isolated region, this allows for materializing into a - /// lower level ancestor region if it is more profitable/correct. - virtual bool shouldMaterializeInto(Region *region) const { return false; } -}; //===--------------------------------------------------------------------===// // OperationFolder @@ -153,7 +136,7 @@ DenseMap> referencedDialects; /// A collection of dialect folder interfaces. 
- DialectInterfaceCollection interfaces; + DialectInterfaceCollection interfaces; }; } // end namespace mlir diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp --- a/mlir/lib/Transforms/Utils/FoldUtils.cpp +++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp @@ -22,9 +22,9 @@ /// Given an operation, find the parent region that folded constants should be /// inserted into. -static Region *getInsertionRegion( - DialectInterfaceCollection &interfaces, - Block *insertionBlock) { +static Region * +getInsertionRegion(DialectInterfaceCollection &interfaces, + Block *insertionBlock) { while (Region *region = insertionBlock->getParent()) { // Insert in this region for any of the following scenarios: // * The parent is unregistered, or is known to be isolated from above. diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -52,8 +52,8 @@ } }; -struct TestOpFolderDialectInterface : public OpFolderDialectInterface { - using OpFolderDialectInterface::OpFolderDialectInterface; +struct TestDialectFoldInterface : public DialectFoldInterface { + using DialectFoldInterface::DialectFoldInterface; /// Registered hook to check if the given region, which is attached to an /// operation that is *not* isolated from above, should be used when @@ -135,7 +135,7 @@ #define GET_OP_LIST #include "TestOps.cpp.inc" >(); - addInterfaces(); addTypes(); allowUnknownOperations();