diff --git a/mlir/docs/Tutorials/Toy/Ch-4.md b/mlir/docs/Tutorials/Toy/Ch-4.md
--- a/mlir/docs/Tutorials/Toy/Ch-4.md
+++ b/mlir/docs/Tutorials/Toy/Ch-4.md
@@ -61,6 +61,13 @@
 struct ToyInlinerInterface : public DialectInlinerInterface {
   using DialectInlinerInterface::DialectInlinerInterface;
 
+  /// This hook checks to see if the given callable operation is legal to inline
+  /// into the given call. For Toy this hook can simply return true, as the Toy
+  /// Call operation is always inlinable.
+  bool isLegalToInline(Operation *call, Operation *callable) const final {
+    return true;
+  }
+
   /// This hook checks to see if the given operation is legal to inline into the
   /// given region. For Toy this hook can simply return true, as all Toy
   /// operations are inlinable.
diff --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp
--- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp
@@ -34,6 +34,11 @@
   // Analysis Hooks
   //===--------------------------------------------------------------------===//
 
+  /// All call operations within toy can be inlined.
+  bool isLegalToInline(Operation *call, Operation *callable) const final {
+    return true;
+  }
+
   /// All operations within toy can be inlined.
   bool isLegalToInline(Operation *, Region *,
                        BlockAndValueMapping &) const final {
diff --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp
--- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp
@@ -34,6 +34,11 @@
   // Analysis Hooks
   //===--------------------------------------------------------------------===//
 
+  /// All call operations within toy can be inlined.
+  bool isLegalToInline(Operation *call, Operation *callable) const final {
+    return true;
+  }
+
   /// All operations within toy can be inlined.
   bool isLegalToInline(Operation *, Region *,
                        BlockAndValueMapping &) const final {
diff --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp
--- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp
@@ -34,6 +34,11 @@
   // Analysis Hooks
   //===--------------------------------------------------------------------===//
 
+  /// All call operations within toy can be inlined.
+  bool isLegalToInline(Operation *call, Operation *callable) const final {
+    return true;
+  }
+
   /// All operations within toy can be inlined.
   bool isLegalToInline(Operation *, Region *,
                        BlockAndValueMapping &) const final {
diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
--- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
@@ -35,6 +35,11 @@
   // Analysis Hooks
   //===--------------------------------------------------------------------===//
 
+  /// All call operations within toy can be inlined.
+  bool isLegalToInline(Operation *call, Operation *callable) const final {
+    return true;
+  }
+
   /// All operations within toy can be inlined.
   bool isLegalToInline(Operation *, Region *,
                        BlockAndValueMapping &) const final {
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -323,11 +323,24 @@
   template <typename AttrClass> AttrClass getAttrOfType(Identifier name) {
     return getAttr(name).dyn_cast_or_null<AttrClass>();
   }
 
   template <typename AttrClass> AttrClass getAttrOfType(StringRef name) {
     return getAttr(name).dyn_cast_or_null<AttrClass>();
   }
 
+  /// Return true if the operation has an attribute with the provided name,
+  /// false otherwise.
+  bool hasAttr(Identifier name) { return static_cast<bool>(getAttr(name)); }
+  bool hasAttr(StringRef name) { return static_cast<bool>(getAttr(name)); }
+
+  /// Return true if the operation has an attribute with the provided name
+  /// and that attribute is of type 'AttrClass', false otherwise.
+  template <typename AttrClass, typename NameT>
+  bool hasAttrOfType(NameT &&name) {
+    return static_cast<bool>(
+        getAttrOfType<AttrClass>(std::forward<NameT>(name)));
+  }
+
   /// If the an attribute exists with the specified name, change it to the new
   /// value. Otherwise, add a new attribute with the specified name/value.
   void setAttr(Identifier name, Attribute value) { attrs.set(name, value); }
diff --git a/mlir/include/mlir/Transforms/InliningUtils.h b/mlir/include/mlir/Transforms/InliningUtils.h
--- a/mlir/include/mlir/Transforms/InliningUtils.h
+++ b/mlir/include/mlir/Transforms/InliningUtils.h
@@ -47,6 +47,14 @@
   // Analysis Hooks
   //===--------------------------------------------------------------------===//
 
+  /// Returns true if the given operation 'callable', that implements the
+  /// 'CallableOpInterface', can be inlined at the position of the given call
+  /// operation 'call', that is registered to the current dialect and implements
+  /// the `CallOpInterface`. By default, calls are not considered inlinable.
+  virtual bool isLegalToInline(Operation *call, Operation *callable) const {
+    return false;
+  }
+
   /// Returns true if the given region 'src' can be inlined into the region
   /// 'dest' that is attached to an operation registered to the current dialect.
   /// 'valueMapping' contains any remapped values from within the 'src' region.
@@ -146,6 +154,7 @@
   // Analysis Hooks
   //===--------------------------------------------------------------------===//
 
+  virtual bool isLegalToInline(Operation *call, Operation *callable) const;
   virtual bool isLegalToInline(Region *dest, Region *src,
                                BlockAndValueMapping &valueMapping) const;
   virtual bool isLegalToInline(Operation *op, Region *dest,
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
--- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
@@ -56,6 +56,11 @@
 struct SPIRVInlinerInterface : public DialectInlinerInterface {
   using DialectInlinerInterface::DialectInlinerInterface;
 
+  /// All call operations within SPIRV can be inlined.
+  bool isLegalToInline(Operation *call, Operation *callable) const final {
+    return true;
+  }
+
   /// Returns true if the given region 'src' can be inlined into the region
   /// 'dest' that is attached to an operation registered to the current dialect.
   bool isLegalToInline(Region *dest, Region *src,
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -46,6 +46,11 @@
   // Analysis Hooks
   //===--------------------------------------------------------------------===//
 
+  /// All call operations within standard ops can be inlined.
+  bool isLegalToInline(Operation *call, Operation *callable) const final {
+    return true;
+  }
+
   /// All operations within standard ops can be inlined.
   bool isLegalToInline(Operation *, Region *,
                        BlockAndValueMapping &) const final {
diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp
--- a/mlir/lib/Transforms/Utils/InliningUtils.cpp
+++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp
@@ -57,6 +57,13 @@
 // InlinerInterface
 //===----------------------------------------------------------------------===//
 
+bool InlinerInterface::isLegalToInline(Operation *call,
+                                       Operation *callable) const {
+  // Defer to the inliner interface of the call's dialect; none means illegal.
+  auto *handler = getInterfaceFor(call);
+  return handler && handler->isLegalToInline(call, callable);
+}
+
 bool InlinerInterface::isLegalToInline(
     Region *dest, Region *src, BlockAndValueMapping &valueMapping) const {
   // Regions can always be inlined into functions.
@@ -352,6 +359,10 @@
     castResult.getDefiningOp()->replaceUsesOfWith(castResult, callResult);
   }
 
+  // Check that it is legal to inline the callable into the call.
+  if (!interface.isLegalToInline(call, callable))
+    return cleanupState();
+
   // Attempt to inline the call.
   if (failed(inlineRegion(interface, src, call, mapper, callResults,
                           callableResultTypes, call.getLoc(),
diff --git a/mlir/test/Transforms/inlining.mlir b/mlir/test/Transforms/inlining.mlir
--- a/mlir/test/Transforms/inlining.mlir
+++ b/mlir/test/Transforms/inlining.mlir
@@ -183,3 +183,10 @@
   %res = call_indirect %fn() : () -> i32
   return %res : i32
 }
+
+// CHECK-LABEL: func @no_inline_invalid_call
+func @no_inline_invalid_call() -> i32 {
+  // CHECK-NEXT: "test.conversion_call_op"
+  %res = "test.conversion_call_op"() { callee=@convert_callee_fn_multiblock, noinline } : () -> (i32)
+  return %res : i32
+}
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -77,6 +77,11 @@
   // Analysis Hooks
   //===--------------------------------------------------------------------===//
 
+  /// Don't allow inlining calls that are marked `noinline`.
+  bool isLegalToInline(Operation *call, Operation *callable) const final {
+    return !call->hasAttr("noinline");
+  }
+
   bool isLegalToInline(Region *, Region *, BlockAndValueMapping &) const final {
     // Inlining into test dialect regions is legal.
     return true;