diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -90,6 +90,23 @@
   }
 };
 
+/// Lowers `shape.rank` to a `dim` op on dimension 0. After type conversion a
+/// shape is a 1-D extent tensor (`tensor<?xindex>`), so the rank it describes
+/// is the number of extents, i.e. the size of that tensor's only dimension.
+class RankOpConverter : public OpConversionPattern<shape::RankOp> {
+public:
+  using OpConversionPattern<shape::RankOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    shape::RankOp::Adaptor transformed(operands);
+    rewriter.replaceOpWithNewOp<DimOp>(op.getOperation(), transformed.shape(),
+                                       0);
+    return success();
+  }
+};
+
 /// Type conversions.
 class ShapeTypeConverter : public TypeConverter {
 public:
@@ -147,6 +164,7 @@
       BinaryOpConversion<AddOp, AddIOp>,
       BinaryOpConversion<MulOp, MulIOp>,
       ConstSizeOpConverter,
+      RankOpConverter,
       ShapeOfOpConversion>(ctx);
   // clang-format on
 }
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -115,3 +115,16 @@
   %shape = shape.shape_of %arg : tensor<1x5x?xf32>
   return
 }
+
+// -----
+
+// Lower `rank` to `dim` of the extent tensor's only dimension.
+// CHECK-LABEL: @rank
+// CHECK-SAME: (%[[SHAPE:.*]]: tensor<?xindex>) -> index
+func @rank(%shape : !shape.shape) -> !shape.size {
+  // CHECK-DAG: %[[C0:.*]] = constant 0 : index
+  // CHECK-DAG: %[[RESULT:.*]] = dim %[[SHAPE]], %[[C0]]
+  // CHECK: return %[[RESULT]] : index
+  %rank = shape.rank %shape
+  return %rank : !shape.size
+}