diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -90,6 +90,38 @@
   }
 };
 
+/// Lowers `GetExtentOp` to `DimOp` when the shape is produced by a
+/// `ShapeOfOp` of a shaped value, and to `ExtractElementOp` on the
+/// type-converted shape operand otherwise.
+class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
+public:
+  using OpConversionPattern<GetExtentOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    GetExtentOp::Adaptor transformed(operands);
+
+    // Derive shape extent directly from shape origin if possible.
+    // This circumvents the necessity to materialize the shape in memory.
+    // `DimOp` requires a shaped operand, so only take this shortcut when the
+    // `ShapeOfOp` argument is a `ShapedType`.
+    if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
+      if (shapeOfOp.arg().getType().isa<ShapedType>()) {
+        rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(),
+                                           transformed.dim());
+        return success();
+      }
+    }
+
+    // Otherwise, extract the extent from the type-converted shape operand.
+    rewriter.replaceOpWithNewOp<ExtractElementOp>(
+        op, rewriter.getIndexType(), transformed.shape(),
+        ValueRange{transformed.dim()});
+    return success();
+  }
+};
+
 class RankOpConverter : public OpConversionPattern<shape::RankOp> {
 public:
   using OpConversionPattern<shape::RankOp>::OpConversionPattern;
@@ -161,6 +193,7 @@
       BinaryOpConversion<AddOp, AddIOp>,
       BinaryOpConversion<MulOp, MulIOp>,
       ConstSizeOpConverter,
+      GetExtentOpConverter,
       RankOpConverter,
       ShapeOfOpConversion>(ctx);
   // clang-format on
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td
@@ -19,20 +19,3 @@
     (Shape_SizeToIndexOp $arg),
     (replaceWithValue $arg)>;
 
-// Derive shape extent directly from shape origin if possible.
-// This circumvents the necessity to materialize the shape in memory.
-def GetExtentShapeOfConversion : Pat<
-    (Shape_GetExtentOp (Shape_ShapeOfOp $arg), $idx),
-    (Shape_IndexToSizeOp (DimOp $arg, (Shape_SizeToIndexOp $idx))),
-    [],
-    (addBenefit 10)>;
-def GetExtentFromExtentTensorConversion : Pattern<
-    (Shape_GetExtentOp (Shape_FromExtentTensorOp $extents), $idx),
-    [
-      (Shape_SizeToIndexOp:$std_idx $idx),
-      (ExtractElementOp:$std_result $extents, (NativeCodeCall<"ValueRange({$0})"> $std_idx)),
-      (Shape_IndexToSizeOp $std_result)
-    ],
-    [],
-    (addBenefit 10)>;
-