diff --git a/mlir/docs/Tutorials/Toy/Ch-4.md b/mlir/docs/Tutorials/Toy/Ch-4.md --- a/mlir/docs/Tutorials/Toy/Ch-4.md +++ b/mlir/docs/Tutorials/Toy/Ch-4.md @@ -64,14 +64,15 @@ /// This hook checks to see if the given callable operation is legal to inline /// into the given call. For Toy this hook can simply return true, as the Toy /// Call operation is always inlinable. - bool isLegalToInline(Operation *call, Operation *callable) const final { + bool isLegalToInline(Operation *call, Operation *callable, + bool wouldBeCloned) const final { return true; } /// This hook checks to see if the given operation is legal to inline into the /// given region. For Toy this hook can simply return true, as all Toy /// operations are inlinable. - bool isLegalToInline(Operation *, Region *, + bool isLegalToInline(Operation *, Region *, bool, BlockAndValueMapping &) const final { return true; } diff --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp @@ -35,12 +35,13 @@ //===--------------------------------------------------------------------===// /// All call operations within toy can be inlined. - bool isLegalToInline(Operation *call, Operation *callable) const final { + bool isLegalToInline(Operation *call, Operation *callable, + bool wouldBeCloned) const final { return true; } /// All operations within toy can be inlined. - bool isLegalToInline(Operation *, Region *, + bool isLegalToInline(Operation *, Region *, bool, BlockAndValueMapping &) const final { return true; } diff --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp @@ -35,12 +35,13 @@ //===--------------------------------------------------------------------===// /// All call operations within toy can be inlined. 
- bool isLegalToInline(Operation *call, Operation *callable) const final { + bool isLegalToInline(Operation *call, Operation *callable, + bool wouldBeCloned) const final { return true; } /// All operations within toy can be inlined. - bool isLegalToInline(Operation *, Region *, + bool isLegalToInline(Operation *, Region *, bool, BlockAndValueMapping &) const final { return true; } diff --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp @@ -35,12 +35,13 @@ //===--------------------------------------------------------------------===// /// All call operations within toy can be inlined. - bool isLegalToInline(Operation *call, Operation *callable) const final { + bool isLegalToInline(Operation *call, Operation *callable, + bool wouldBeCloned) const final { return true; } /// All operations within toy can be inlined. - bool isLegalToInline(Operation *, Region *, + bool isLegalToInline(Operation *, Region *, bool, BlockAndValueMapping &) const final { return true; } diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp @@ -36,12 +36,13 @@ //===--------------------------------------------------------------------===// /// All call operations within toy can be inlined. - bool isLegalToInline(Operation *call, Operation *callable) const final { + bool isLegalToInline(Operation *call, Operation *callable, + bool wouldBeCloned) const final { return true; } /// All operations within toy can be inlined. 
- bool isLegalToInline(Operation *, Region *, + bool isLegalToInline(Operation *, Region *, bool, BlockAndValueMapping &) const final { return true; } diff --git a/mlir/include/mlir/Transforms/InliningUtils.h b/mlir/include/mlir/Transforms/InliningUtils.h --- a/mlir/include/mlir/Transforms/InliningUtils.h +++ b/mlir/include/mlir/Transforms/InliningUtils.h @@ -50,27 +50,36 @@ /// Returns true if the given operation 'callable', that implements the /// 'CallableOpInterface', can be inlined into the position given call /// operation 'call', that is registered to the current dialect and implements - /// the `CallOpInterface`. - virtual bool isLegalToInline(Operation *call, Operation *callable) const { + /// the `CallOpInterface`. 'wouldBeCloned' is set to true if the region of the + /// given 'callable' is set to be cloned during the inlining process, or false + /// if the region is set to be moved in-place (i.e. no duplicates would be + /// created). + virtual bool isLegalToInline(Operation *call, Operation *callable, + bool wouldBeCloned) const { return false; } /// Returns true if the given region 'src' can be inlined into the region /// 'dest' that is attached to an operation registered to the current dialect. - /// 'valueMapping' contains any remapped values from within the 'src' region. - /// This can be used to examine what values will replace entry arguments into - /// the 'src' region for example. - virtual bool isLegalToInline(Region *dest, Region *src, + /// 'wouldBeCloned' is set to true if the given 'src' region is set to be + /// cloned during the inlining process, or false if the region is set to be + /// moved in-place (i.e. no duplicates would be created). 'valueMapping' + /// contains any remapped values from within the 'src' region. This can be + /// used to examine what values will replace entry arguments into the 'src' + /// region for example.
+ virtual bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, BlockAndValueMapping &valueMapping) const { return false; } /// Returns true if the given operation 'op', that is registered to this /// dialect, can be inlined into the given region, false otherwise. - /// 'valueMapping' contains any remapped values from within the 'src' region. - /// This can be used to examine what values may potentially replace the - /// operands to 'op'. - virtual bool isLegalToInline(Operation *op, Region *dest, + /// 'wouldBeCloned' is set to true if the given 'op' is set to be cloned + /// during the inlining process, or false if the operation is set to be moved + /// in-place (i.e. no duplicates would be created). 'valueMapping' contains + /// any remapped values from within the 'src' region. This can be used to + /// examine what values may potentially replace the operands to 'op'. + virtual bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, BlockAndValueMapping &valueMapping) const { return false; } @@ -154,10 +163,11 @@ // Analysis Hooks //===--------------------------------------------------------------------===// - virtual bool isLegalToInline(Operation *call, Operation *callable) const; - virtual bool isLegalToInline(Region *dest, Region *src, + virtual bool isLegalToInline(Operation *call, Operation *callable, + bool wouldBeCloned) const; + virtual bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, BlockAndValueMapping &valueMapping) const; - virtual bool isLegalToInline(Operation *op, Region *dest, + virtual bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, BlockAndValueMapping &valueMapping) const; virtual bool shouldAnalyzeRecursively(Operation *op) const; diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -41,7 +41,7 @@ /// Returns true if the given region
'src' can be inlined into the region /// 'dest' that is attached to an operation registered to the current dialect. - bool isLegalToInline(Region *dest, Region *src, + bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, BlockAndValueMapping &valueMapping) const final { // Conservatively don't allow inlining into affine structures. return false; @@ -49,7 +49,7 @@ /// Returns true if the given operation 'op', that is registered to this /// dialect, can be inlined into the given region, false otherwise. - bool isLegalToInline(Operation *op, Region *region, + bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned, BlockAndValueMapping &valueMapping) const final { // Always allow inlining affine operations into the top-level region of a // function. There are some edge cases when inlining *into* affine diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp @@ -36,12 +36,12 @@ // We don't have any special restrictions on what can be inlined into // destination regions (e.g. while/conditional bodies). Always allow it. - bool isLegalToInline(Region *dest, Region *src, + bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, BlockAndValueMapping &valueMapping) const final { return true; } // Operations in Linalg dialect are always legal to inline. - bool isLegalToInline(Operation *, Region *, + bool isLegalToInline(Operation *, Region *, bool, BlockAndValueMapping &) const final { return true; } diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp --- a/mlir/lib/Dialect/SCF/SCF.cpp +++ b/mlir/lib/Dialect/SCF/SCF.cpp @@ -24,13 +24,13 @@ using DialectInlinerInterface::DialectInlinerInterface; // We don't have any special restrictions on what can be inlined into // destination regions (e.g. while/conditional bodies). Always allow it. 
- bool isLegalToInline(Region *dest, Region *src, + bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, BlockAndValueMapping &valueMapping) const final { return true; } // Operations in scf dialect are always legal to inline since they are // pure. - bool isLegalToInline(Operation *, Region *, + bool isLegalToInline(Operation *, Region *, bool, BlockAndValueMapping &) const final { return true; } diff --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp @@ -57,13 +57,14 @@ using DialectInlinerInterface::DialectInlinerInterface; /// All call operations within SPIRV can be inlined. - bool isLegalToInline(Operation *call, Operation *callable) const final { + bool isLegalToInline(Operation *call, Operation *callable, + bool wouldBeCloned) const final { return true; } /// Returns true if the given region 'src' can be inlined into the region /// 'dest' that is attached to an operation registered to the current dialect. - bool isLegalToInline(Region *dest, Region *src, + bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, BlockAndValueMapping &) const final { // Return true here when inlining into spv.func, spv.selection, and // spv.loop operations. @@ -74,7 +75,7 @@ /// Returns true if the given operation 'op', that is registered to this /// dialect, can be inlined into the region 'dest' that is attached to an /// operation registered to the current dialect. - bool isLegalToInline(Operation *op, Region *dest, + bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, BlockAndValueMapping &) const final { // TODO: Enable inlining structured control flows with return. 
if ((isa<spirv::SelectionOp, spirv::LoopOp>(op)) && diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -72,7 +72,7 @@ // Returns true if the given region 'src' can be inlined into the region // 'dest' that is attached to an operation registered to the current dialect. - bool isLegalToInline(Region *dest, Region *src, + bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, BlockAndValueMapping &) const final { return true; } @@ -80,7 +80,7 @@ // Returns true if the given operation 'op', that is registered to this // dialect, can be inlined into the region 'dest' that is attached to an // operation registered to the current dialect. - bool isLegalToInline(Operation *op, Region *dest, + bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, BlockAndValueMapping &) const final { return true; } diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -47,12 +47,13 @@ //===--------------------------------------------------------------------===// /// All call operations within standard ops can be inlined. - bool isLegalToInline(Operation *call, Operation *callable) const final { + bool isLegalToInline(Operation *call, Operation *callable, + bool wouldBeCloned) const final { return true; } /// All operations within standard ops can be inlined.
- bool isLegalToInline(Operation *, Region *, + bool isLegalToInline(Operation *, Region *, bool, BlockAndValueMapping &) const final { return true; } diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp --- a/mlir/lib/Transforms/Utils/InliningUtils.cpp +++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp @@ -57,26 +57,31 @@ // InlinerInterface //===----------------------------------------------------------------------===// -bool InlinerInterface::isLegalToInline(Operation *call, - Operation *callable) const { - auto *handler = getInterfaceFor(call); - return handler ? handler->isLegalToInline(call, callable) : false; +bool InlinerInterface::isLegalToInline(Operation *call, Operation *callable, + bool wouldBeCloned) const { + if (auto *handler = getInterfaceFor(call)) + return handler->isLegalToInline(call, callable, wouldBeCloned); + return false; } bool InlinerInterface::isLegalToInline( - Region *dest, Region *src, BlockAndValueMapping &valueMapping) const { + Region *dest, Region *src, bool wouldBeCloned, + BlockAndValueMapping &valueMapping) const { // Regions can always be inlined into functions. if (isa<FuncOp>(dest->getParentOp())) return true; - auto *handler = getInterfaceFor(dest->getParentOp()); - return handler ? handler->isLegalToInline(dest, src, valueMapping) : false; + if (auto *handler = getInterfaceFor(dest->getParentOp())) + return handler->isLegalToInline(dest, src, wouldBeCloned, valueMapping); + return false; } bool InlinerInterface::isLegalToInline( - Operation *op, Region *dest, BlockAndValueMapping &valueMapping) const { - auto *handler = getInterfaceFor(op); - return handler ?
handler->isLegalToInline(op, dest, valueMapping) : false; + Operation *op, Region *dest, bool wouldBeCloned, + BlockAndValueMapping &valueMapping) const { + if (auto *handler = getInterfaceFor(op)) + return handler->isLegalToInline(op, dest, wouldBeCloned, valueMapping); + return false; } bool InlinerInterface::shouldAnalyzeRecursively(Operation *op) const { @@ -103,12 +108,13 @@ /// Utility to check that all of the operations within 'src' can be inlined. static bool isLegalToInline(InlinerInterface &interface, Region *src, - Region *insertRegion, + Region *insertRegion, bool shouldCloneInlinedRegion, BlockAndValueMapping &valueMapping) { for (auto &block : *src) { for (auto &op : block) { // Check this operation. - if (!interface.isLegalToInline(&op, insertRegion, valueMapping)) { + if (!interface.isLegalToInline(&op, insertRegion, + shouldCloneInlinedRegion, valueMapping)) { LLVM_DEBUG({ llvm::dbgs() << "* Illegal to inline because of op: "; op.dump(); @@ -119,7 +125,7 @@ if (interface.shouldAnalyzeRecursively(&op) && llvm::any_of(op.getRegions(), [&](Region &region) { return !isLegalToInline(interface, &region, insertRegion, - valueMapping); + shouldCloneInlinedRegion, valueMapping); })) return false; } @@ -156,8 +162,10 @@ Region *insertRegion = insertBlock->getParent(); // Check that the operations within the source region are valid to inline. - if (!interface.isLegalToInline(insertRegion, src, mapper) || - !isLegalToInline(interface, src, insertRegion, mapper)) + if (!interface.isLegalToInline(insertRegion, src, shouldCloneInlinedRegion, + mapper) || + !isLegalToInline(interface, src, insertRegion, shouldCloneInlinedRegion, + mapper)) return failure(); // Split the insertion block. @@ -359,7 +367,7 @@ } // Check that it is legal to inline the callable into the call. - if (!interface.isLegalToInline(call, callable)) + if (!interface.isLegalToInline(call, callable, shouldCloneInlinedRegion)) return cleanupState(); // Attempt to inline the call.
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -77,15 +77,17 @@ // Analysis Hooks //===--------------------------------------------------------------------===// - bool isLegalToInline(Operation *call, Operation *callable) const final { + bool isLegalToInline(Operation *call, Operation *callable, + bool wouldBeCloned) const final { // Don't allow inlining calls that are marked `noinline`. return !call->hasAttr("noinline"); } - bool isLegalToInline(Region *, Region *, BlockAndValueMapping &) const final { + bool isLegalToInline(Region *, Region *, bool, + BlockAndValueMapping &) const final { // Inlining into test dialect regions is legal. return true; } - bool isLegalToInline(Operation *, Region *, + bool isLegalToInline(Operation *, Region *, bool, BlockAndValueMapping &) const final { return true; }