diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -14,6 +14,7 @@
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/StandardTypes.h"
+#include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/Support/raw_ostream.h"

@@ -59,6 +60,32 @@
   return success();
 }

+//===----------------------------------------------------------------------===//
+// InlinerInterface
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This class defines the interface for inlining shape dialect ops.
+struct ShapeInlinerInterface : public DialectInlinerInterface {
+  using DialectInlinerInterface::DialectInlinerInterface;
+
+  // Returns true if the given region 'src' can be inlined into the region
+  // 'dest' that is attached to an operation registered to the current dialect.
+  bool isLegalToInline(Region *dest, Region *src,
+                       BlockAndValueMapping &) const final {
+    return true;
+  }
+
+  // Returns true if the given operation 'op', that is registered to this
+  // dialect, can be inlined into the region 'dest' that is attached to an
+  // operation registered to the current dialect.
+  bool isLegalToInline(Operation *op, Region *dest,
+                       BlockAndValueMapping &) const final {
+    return true;
+  }
+};
+} // namespace
+
 void ShapeDialect::initialize() {
   addOperations<
 #define GET_OP_LIST
@@ -66,6 +93,7 @@
       >();
   addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType,
           WitnessType>();
+  addInterfaces<ShapeInlinerInterface>();
   // Allow unknown operations during prototyping and testing. As the dialect is
   // still evolving it makes it simple to start with an unregistered ops and
   // try different variants before actually defining the op.
@@ -640,11 +668,14 @@
         shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
     if (!rankedTensorType)
       return failure();
-    assert(op.getType().isa<IndexType>() &&
-           "expected `rank(shape_of( ... )]` based on a shaped argument to "
-           "yield an index type");
     int64_t rank = rankedTensorType.getRank();
-    rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank);
+    if (op.getType().isa<IndexType>()) {
+      rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank);
+    } else if (op.getType().isa<SizeType>()) {
+      rewriter.replaceOpWithNewOp<ConstSizeOp>(op.getOperation(), rank);
+    } else {
+      return failure();
+    }
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -624,6 +624,18 @@

 // -----

+// Canonicalize `rank` when shape is derived from ranked tensor.
+// CHECK-LABEL: @canonicalize_rank
+func @canonicalize_rank_size(%arg : tensor<1x2x?xf32>) -> !shape.size {
+  // CHECK: %[[RESULT:.*]] = shape.const_size 3
+  // CHECK: return %[[RESULT]] : !shape.size
+  %shape = shape.shape_of %arg : tensor<1x2x?xf32> -> !shape.shape
+  %rank = shape.rank %shape : !shape.shape -> !shape.size
+  return %rank : !shape.size
+}
+
+// -----
+
 // Do not canonicalize `rank` when shape is derived from unranked tensor.
 // CHECK-LABEL: @dont_canonicalize_rank
 // CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> index
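
A minimal before/after sketch of the new `!shape.size` path of the rank canonicalization, based on the `@canonicalize_rank_size` test above; the assumption here is that the pattern is exercised via the canonicalizer (e.g. `mlir-opt -canonicalize`), and the existing `index`-typed path via `ConstantIndexOp` is unchanged:

  // Input: rank of a shape derived from a ranked tensor.
  %shape = shape.shape_of %arg : tensor<1x2x?xf32> -> !shape.shape
  %rank = shape.rank %shape : !shape.shape -> !shape.size

  // After canonicalization: the statically known rank folds to a constant.
  %rank = shape.const_size 3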