diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -19,16 +19,16 @@
 namespace {
 
 /// Conversion patterns.
-class SizeToIndexOpConversion
-    : public OpConversionPattern<shape::SizeToIndexOp> {
+class FromExtentTensorOpConversion
+    : public OpConversionPattern<shape::FromExtentTensorOp> {
 public:
-  using OpConversionPattern<shape::SizeToIndexOp>::OpConversionPattern;
+  using OpConversionPattern<shape::FromExtentTensorOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(shape::SizeToIndexOp op, ArrayRef<Value> operands,
+  matchAndRewrite(shape::FromExtentTensorOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    shape::SizeToIndexOpOperandAdaptor transformed(operands);
-    rewriter.replaceOp(op.getOperation(), transformed.arg());
+    shape::FromExtentTensorOpOperandAdaptor transformed(operands);
+    rewriter.replaceOp(op.getOperation(), transformed.input());
     return success();
   }
 };
@@ -47,6 +47,34 @@
   }
 };
 
+class SizeToIndexOpConversion
+    : public OpConversionPattern<shape::SizeToIndexOp> {
+public:
+  using OpConversionPattern<shape::SizeToIndexOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(shape::SizeToIndexOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    shape::SizeToIndexOpOperandAdaptor transformed(operands);
+    rewriter.replaceOp(op.getOperation(), transformed.arg());
+    return success();
+  }
+};
+
+class ToExtentTensorOpConversion
+    : public OpConversionPattern<shape::ToExtentTensorOp> {
+public:
+  using OpConversionPattern<shape::ToExtentTensorOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(shape::ToExtentTensorOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    shape::ToExtentTensorOpOperandAdaptor transformed(operands);
+    rewriter.replaceOp(op.getOperation(), transformed.input());
+    return success();
+  }
+};
+
 /// Type conversions.
 class ShapeTypeConverter : public TypeConverter {
 public:
@@ -55,6 +83,7 @@
   ShapeTypeConverter(MLIRContext *ctx) {
     // Add default pass-through conversion.
     addConversion([&](Type type) { return type; });
+    addConversion([ctx](shape::SizeType type) { return IndexType::get(ctx); });
 
     addConversion([ctx](shape::ShapeType type) {
       return RankedTensorType::get({ShapedType::kDynamicSize},
@@ -99,8 +128,10 @@
                                          OwningRewritePatternList &patterns,
                                          MLIRContext *ctx) {
   // clang-format off
   patterns.insert<
+      FromExtentTensorOpConversion,
       IndexToSizeOpConversion,
-      SizeToIndexOpConversion>(ctx);
+      SizeToIndexOpConversion,
+      ToExtentTensorOpConversion>(ctx);
   // clang-format on
 }
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -39,3 +39,26 @@
 // CHECK: return %[[SHAPE]] : tensor<?xindex>
 return %shape : !shape.shape
 }
+
+// -----
+
+// Lower `to_extent_tensor` operation to no-op.
+// CHECK-LABEL: @to_extent_tensor
+// CHECK-SAME: (%[[SHAPE:.*]]: tensor<?xindex>) -> tensor<?xindex>
+func @to_extent_tensor(%shape : !shape.shape) -> tensor<?xindex> {
+  // CHECK-NEXT: return %[[SHAPE]] : tensor<?xindex>
+  %tensor = "shape.to_extent_tensor"(%shape) : (!shape.shape) -> tensor<?xindex>
+  return %tensor : tensor<?xindex>
+}
+
+// -----
+
+// Lower `from_extent_tensor` operation to no-op.
+// CHECK-LABEL: @from_extent_tensor
+// CHECK-SAME: (%[[TENSOR:.*]]: tensor<?xindex>) -> tensor<?xindex>
+func @from_extent_tensor(%tensor : tensor<?xindex>) -> !shape.shape {
+  // CHECK-NEXT: return %[[TENSOR]] : tensor<?xindex>
+  %shape = "shape.from_extent_tensor"(%tensor)
+      : (tensor<?xindex>) -> !shape.shape
+  return %shape : !shape.shape
+}