diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
@@ -17,6 +17,7 @@
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/Parser.h"
 #include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/InliningUtils.h"
 
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/raw_ostream.h"
@@ -24,6 +25,40 @@
 using namespace mlir;
 using namespace mlir::linalg;
 
+//===----------------------------------------------------------------------===//
+// LinalgDialect Interfaces
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Inliner interface for the Linalg dialect. Both `isLegalToInline` hooks
+/// return true unconditionally, and `handleTerminator` is a no-op.
+struct LinalgInlinerInterface : public DialectInlinerInterface {
+  using DialectInlinerInterface::DialectInlinerInterface;
+
+  // We don't have any special restrictions on what can be inlined into
+  // destination regions (e.g. while/conditional bodies). Always allow it.
+  bool isLegalToInline(Region *dest, Region *src,
+                       BlockAndValueMapping &valueMapping) const final {
+    return true;
+  }
+  // Operations in the Linalg dialect are always legal to inline.
+  bool isLegalToInline(Operation *, Region *,
+                       BlockAndValueMapping &) const final {
+    return true;
+  }
+  // Handle the given inlined terminator. Required when the region has only one
+  // block; this override does nothing, so `valuesToRepl` is left untouched.
+  void handleTerminator(Operation *op,
+                        ArrayRef<Value> valuesToRepl) const final {}
+};
+
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// LinalgDialect
+//===----------------------------------------------------------------------===//
+
 void mlir::linalg::LinalgDialect::initialize() {
   addTypes<RangeType>();
   addOperations<
@@ -34,7 +69,9 @@
 #define GET_OP_LIST
 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
       >();
+  addInterfaces<LinalgInlinerInterface>();
 }
+
 Type mlir::linalg::LinalgDialect::parseType(DialectAsmParser &parser) const {
   // Parse the main keyword for the type.
   StringRef keyword;