diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -77,6 +77,27 @@
   }
 };
 
+class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
+public:
+  using OpConversionPattern<ConstShapeOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ConstShapeOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    SmallVector<Value, 4> extentOperands;
+    for (auto extent : op.shape()) {
+      extentOperands.push_back(
+          rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue()));
+    }
+    Value tensor = rewriter.create<TensorFromElementsOp>(loc, extentOperands);
+    Type indexTy = rewriter.getIndexType();
+    Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
+    rewriter.replaceOpWithNewOp<TensorCastOp>(op, tensor, resultTy);
+    return success();
+  }
+};
+
 class ConstSizeOpConverter : public OpConversionPattern<ConstSizeOp> {
 public:
   using OpConversionPattern<ConstSizeOp>::OpConversionPattern;
@@ -160,6 +181,7 @@
   patterns.insert<
       BinaryOpConversion<AddOp, AddIOp>,
       BinaryOpConversion<MulOp, MulIOp>,
+      ConstShapeOpConverter,
       ConstSizeOpConverter,
       RankOpConverter,
       ShapeOfOpConversion>(ctx);
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -127,3 +127,19 @@
   %rank = shape.rank %shape
   return %rank : !shape.size
 }
+
+// -----
+
+// Lower `const_shape` to `tensor_from_elements`.
+// CHECK-LABEL: @const_shape
+// CHECK-SAME: () -> tensor<?xindex>
+func @const_shape() -> !shape.shape {
+  // CHECK: %[[C1:.*]] = constant 1 : index
+  // CHECK: %[[C2:.*]] = constant 2 : index
+  // CHECK: %[[C3:.*]] = constant 3 : index
+  // CHECK: %[[TENSOR3:.*]] = tensor_from_elements(%[[C1]], %[[C2]], %[[C3]]) : tensor<3xindex>
+  // CHECK: %[[RESULT:.*]] = tensor_cast %[[TENSOR3]] : tensor<3xindex> to tensor<?xindex>
+  // CHECK: return %[[RESULT]] : tensor<?xindex>
+  %shape = shape.const_shape [1, 2, 3]
+  return %shape : !shape.shape
+}
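
Illustrative note (not part of the patch): the new ConstShapeOpConverter materializes one `constant ... : index` per extent of the `shape.const_shape`, packs them with `tensor_from_elements` into a static `tensor<3xindex>`, and then `tensor_cast`s that to the dynamically shaped `tensor<?xindex>` used for `!shape.shape` results. Assuming the conversion is exposed under the `-convert-shape-to-std` pass flag (taken from the surrounding test setup, not shown in this diff), the added test case corresponds to a round trip roughly like the following sketch:

  // Input:
  func @const_shape() -> !shape.shape {
    %shape = shape.const_shape [1, 2, 3]
    return %shape : !shape.shape
  }

  // After `mlir-opt --convert-shape-to-std` (SSA names illustrative):
  func @const_shape() -> tensor<?xindex> {
    %c1 = constant 1 : index
    %c2 = constant 2 : index
    %c3 = constant 3 : index
    %0 = tensor_from_elements(%c1, %c2, %c3) : tensor<3xindex>
    %1 = tensor_cast %0 : tensor<3xindex> to tensor<?xindex>
    return %1 : tensor<?xindex>
  }

The final `tensor_cast` exists only to reconcile the statically known extent count with the `tensor<?xindex>` type that the shape-to-standard type conversion expects; consumers that know the rank can fold it away later.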