diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -19,11 +19,46 @@
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Support/MathExtras.h"
+#include "mlir/Transforms/InliningUtils.h"
 
 using namespace mlir;
 using namespace mlir::scf;
 
 //===----------------------------------------------------------------------===//
+// SCFDialect Dialect Interfaces
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct SCFInlinerInterface : public DialectInlinerInterface {
+  using DialectInlinerInterface::DialectInlinerInterface;
+  /// We don't have any special restrictions on what can be inlined into
+  /// destination regions (e.g. while/conditional bodies). Always allow it.
+  bool isLegalToInline(Region *dest, Region *src,
+                       BlockAndValueMapping &valueMapping) const final {
+    return true;
+  }
+  /// Operations in scf dialect are always legal to inline since they are
+  /// pure.
+  bool isLegalToInline(Operation *, Region *,
+                       BlockAndValueMapping &) const final {
+    return true;
+  }
+  /// Forward each operand of an inlined `scf.yield` to the value it replaces;
+  /// other terminators are ignored. Required when the region has one block.
+  void handleTerminator(Operation *op,
+                        ArrayRef<Value> valuesToRepl) const final {
+    auto retValOp = dyn_cast<scf::YieldOp>(op);
+    if (!retValOp)
+      return;
+
+    for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
+      std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
+    }
+  }
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
 // SCFDialect
 //===----------------------------------------------------------------------===//
 
@@ -33,6 +68,7 @@
 #define GET_OP_LIST
 #include "mlir/Dialect/SCF/SCFOps.cpp.inc"
       >();
+  addInterfaces<SCFInlinerInterface>();
 }
 
 /// Default callback for IfOp builders. Inserts a yield without arguments.