diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -90,6 +90,34 @@
   }
 };
 
+/// Lowers `shape.get_extent`. When the shape operand is produced by
+/// `shape.shape_of`, the extent is queried from the `shape_of` argument with
+/// `dim`; otherwise it is extracted from the converted shape operand.
+class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
+public:
+  using OpConversionPattern<GetExtentOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    GetExtentOp::Adaptor transformed(operands);
+
+    // Derive the shape extent directly from the shape origin if possible.
+    // This circumvents the necessity to materialize the shape in memory.
+    if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
+      rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(),
+                                         transformed.dim());
+      return success();
+    }
+
+    // Otherwise, extract the extent from the converted shape operand.
+    rewriter.replaceOpWithNewOp<ExtractElementOp>(
+        op, rewriter.getIndexType(), transformed.shape(),
+        ValueRange{transformed.dim()});
+    return success();
+  }
+};
+
 class RankOpConverter : public OpConversionPattern<RankOp> {
 public:
   using OpConversionPattern<RankOp>::OpConversionPattern;
@@ -161,6 +189,7 @@
       BinaryOpConversion<AddOp, AddIOp>,
       BinaryOpConversion<MulOp, MulIOp>,
       ConstSizeOpConverter,
+      GetExtentOpConverter,
       RankOpConverter,
       ShapeOfOpConversion>(ctx);
   // clang-format on