diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
@@ -17,6 +17,7 @@
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/Parser.h"
 #include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/raw_ostream.h"
 
@@ -24,6 +25,40 @@
 using namespace mlir;
 using namespace mlir::linalg;
 
+//===----------------------------------------------------------------------===//
+// LinalgDialect Dialect Interfaces
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Inliner hooks for Linalg: every legality query below returns true.
+struct LinalgInlinerInterface : public DialectInlinerInterface {
+  using DialectInlinerInterface::DialectInlinerInterface;
+
+  // We don't have any special restrictions on what can be inlined into
+  // destination regions (e.g. while/conditional bodies). Always allow it.
+  bool isLegalToInline(Region *dest, Region *src,
+                       BlockAndValueMapping &valueMapping) const final {
+    return true;
+  }
+  // Operations in Linalg dialect are always legal to inline.
+  bool isLegalToInline(Operation *, Region *,
+                       BlockAndValueMapping &) const final {
+    return true;
+  }
+  // Handle the given inlined terminator by replacing it with a new operation
+  // as necessary. Required when the region has only one block.
+  // Intentionally a no-op here: the body below is empty.
+  void handleTerminator(Operation *op,
+                        ArrayRef<Value> valuesToRepl) const final {}
+};
+
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// LinalgDialect
+//===----------------------------------------------------------------------===//
+
 void mlir::linalg::LinalgDialect::initialize() {
   addTypes<RangeType>();
   addOperations<
@@ -34,7 +69,9 @@
 #define GET_OP_LIST
 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
       >();
+  addInterfaces<LinalgInlinerInterface>();
 }
+
 Type mlir::linalg::LinalgDialect::parseType(DialectAsmParser &parser) const {
   // Parse the main keyword for the type.
   StringRef keyword;
diff --git a/mlir/test/Dialect/Linalg/inlining.mlir b/mlir/test/Dialect/Linalg/inlining.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/inlining.mlir
@@ -0,0 +1,31 @@
+// RUN: mlir-opt %s -inline | FileCheck %s
+
+// These tests verify that regions with operations from Linalg dialect
+// can be inlined.
+
+#accesses = [
+  affine_map<(i) -> (i)>,
+  affine_map<(i) -> (i)>
+]
+
+#trait = {
+  args_in = 1,
+  args_out = 1,
+  indexing_maps = #accesses,
+  iterator_types = ["parallel"]
+}
+
+func @inline_into(%arg0: memref<?xf32>) {
+  // CHECK: linalg.generic
+  call @inlined_fn(%arg0) : (memref<?xf32>) -> ()
+  return
+}
+
+func @inlined_fn(%arg0: memref<?xf32>) {
+  // CHECK: linalg.generic
+  linalg.generic #trait %arg0, %arg0 {
+    ^bb(%0 : f32, %1 : f32) :
+      linalg.yield %0 : f32
+  } : memref<?xf32>, memref<?xf32>
+  return
+}