diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -56,6 +56,11 @@
     // Add default pass-through conversion.
     addConversion([&](Type type) { return type; });
     addConversion([ctx](shape::SizeType type) { return IndexType::get(ctx); });
+    // Lower `!shape.shape` to a 1-D index tensor of dynamic length.
+    addConversion([ctx](shape::ShapeType type) {
+      return RankedTensorType::get({ShapedType::kDynamicSize},
+                                   IndexType::get(ctx));
+    });
   }
 };
 
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -29,3 +29,13 @@
   %size = shape.index_to_size %index
   return %size : !shape.size
 }
+
+// -----
+
+// Convert `shape` to `tensor<?xindex>` type.
+// CHECK-LABEL: @shape_id
+// CHECK-SAME: (%[[SHAPE:.*]]: tensor<?xindex>)
+func @shape_id(%shape : !shape.shape) -> !shape.shape {
+  // CHECK: return %[[SHAPE]] : tensor<?xindex>
+  return %shape : !shape.shape
+}