diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -16,6 +16,7 @@
 #include "mlir/IR/StandardTypes.h"
+#include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/Support/raw_ostream.h"
 
 using namespace mlir;
 using namespace mlir::shape;
@@ -59,6 +60,32 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// InlinerInterface
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This class defines the interface for inlining within the shape dialect.
+struct ShapeInlinerInterface : public DialectInlinerInterface {
+  using DialectInlinerInterface::DialectInlinerInterface;
+
+  // Returns true if the given region 'src' can be inlined into the region
+  // 'dest' that is attached to an operation registered to the current dialect.
+  bool isLegalToInline(Region *dest, Region *src,
+                       BlockAndValueMapping &) const final {
+    return true;
+  }
+
+  // Returns true if the given operation 'op', that is registered to this
+  // dialect, can be inlined into the region 'dest' that is attached to an
+  // operation registered to the current dialect.
+  bool isLegalToInline(Operation *op, Region *dest,
+                       BlockAndValueMapping &) const final {
+    return true;
+  }
+};
+} // namespace
+
 void ShapeDialect::initialize() {
   addOperations<
 #define GET_OP_LIST
@@ -66,6 +93,7 @@
       >();
   addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType,
            WitnessType>();
+  addInterfaces<ShapeInlinerInterface>();
   // Allow unknown operations during prototyping and testing. As the dialect is
   // still evolving it makes it simple to start with an unregistered ops and
   // try different variants before actually defining the op.
@@ -640,11 +668,16 @@
         shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
     if (!rankedTensorType)
       return failure();
-    assert(op.getType().isa<IndexType>() &&
-           "expected `rank(shape_of( ... )]` based on a shaped argument to "
-           "yield an index type");
     int64_t rank = rankedTensorType.getRank();
-    rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank);
+    // Materialize the constant with an op whose result type matches the rank
+    // op's result type; any other result type is left unrewritten.
+    if (op.getType().isa<IndexType>()) {
+      rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank);
+    } else if (op.getType().isa<SizeType>()) {
+      rewriter.replaceOpWithNewOp<ConstSizeOp>(op.getOperation(), rank);
+    } else {
+      return failure();
+    }
     return success();
   }
 };