diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -51,6 +51,23 @@
   }
 };
 
+/// Converts `shape.rank` into `dim %shape, 0`. After type conversion a shape
+/// is a 1-D tensor of extents (`tensor<?xindex>` in the accompanying test), so
+/// the size of its first and only dimension is the rank.
+class RankOpConverter : public OpConversionPattern<shape::RankOp> {
+public:
+  using OpConversionPattern<shape::RankOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    shape::RankOp::Adaptor transformed(operands);
+    rewriter.replaceOpWithNewOp<DimOp>(op.getOperation(), transformed.shape(),
+                                       0);
+    return success();
+  }
+};
+
 /// Type conversions.
 class ShapeTypeConverter : public TypeConverter {
 public:
@@ -107,7 +124,8 @@
   patterns.insert<
       BinaryOpConversion<shape::AddOp, AddIOp>,
       BinaryOpConversion<shape::MulOp, MulIOp>,
-      ConstSizeOpConverter>(ctx);
+      ConstSizeOpConverter,
+      RankOpConverter>(ctx);
   // clang-format on
 }
 
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -86,3 +86,17 @@
 }
 // CHECK: %[[C1:.*]] = constant 1 : index
 // CHECK: return %[[C1]] : index
+
+// -----
+
+// Convert `rank` to `dim` of the first dimension.
+// CHECK-LABEL: @rank
+// CHECK-SAME: (%[[SHAPE:.*]]: tensor<?xindex>) -> index
+func @rank(%shape : !shape.shape) -> !shape.size {
+  // CHECK-DAG: %[[C0:.*]] = constant 0 : index
+  // CHECK-DAG: %[[RESULT:.*]] = dim %[[SHAPE]], %[[C0]]
+  // CHECK: return %[[RESULT]] : index
+  %rank = shape.rank %shape
+  return %rank : !shape.size
+}
+