diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -38,6 +38,48 @@
   }
 };
 
+/// Lowers `shape.shape_of` on a ranked tensor to a `tensor_from_elements` of
+/// its extents: static extents become `constant` indices and dynamic extents
+/// are read with `dim`. Other operand types are not matched by this pattern.
+class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
+public:
+  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    ShapeOfOp::Adaptor transformed(operands);
+    auto loc = op.getLoc();
+    auto tensorVal = transformed.arg();
+
+    // Only ranked tensor operands are lowered here; any other operand type
+    // fails to match. For unranked tensors `shape_of` lowers to `scf` and the
+    // pattern can be found in the corresponding pass.
+    auto rankedTensorTy = tensorVal.getType().dyn_cast<RankedTensorType>();
+    if (!rankedTensorTy)
+      return failure();
+
+    // Build values for individual dimensions.
+    SmallVector<Value, 8> dimValues;
+    int64_t rank = rankedTensorTy.getRank();
+    for (int64_t i = 0; i < rank; i++) {
+      if (rankedTensorTy.isDynamicDim(i)) {
+        auto dimVal = rewriter.create<DimOp>(loc, tensorVal, i);
+        dimValues.push_back(dimVal);
+      } else {
+        int64_t dim = rankedTensorTy.getDimSize(i);
+        auto dimVal = rewriter.create<ConstantIndexOp>(loc, dim);
+        dimValues.push_back(dimVal);
+      }
+    }
+
+    // Materialize shape as ranked tensor.
+    rewriter.replaceOpWithNewOp<TensorFromElementsOp>(op.getOperation(),
+                                                      dimValues);
+    return success();
+  }
+};
+
 class ConstSizeOpConverter : public OpConversionPattern<ConstSizeOp> {
 public:
   using OpConversionPattern<ConstSizeOp>::OpConversionPattern;
@@ -107,7 +149,8 @@
   patterns.insert<
       BinaryOpConversion<AddOp, AddIOp>,
       BinaryOpConversion<MulOp, MulIOp>,
-      ConstSizeOpConverter>(ctx);
+      ConstSizeOpConverter,
+      ShapeOfOpConversion>(ctx);
   // clang-format on
 }
 
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -86,3 +86,32 @@
 }
 // CHECK: %[[C1:.*]] = constant 1 : index
 // CHECK: return %[[C1]] : index
+
+// -----
+
+// Lower `shape_of` for statically shaped tensor.
+// CHECK-LABEL: @shape_of_stat
+// CHECK-SAME: (%[[ARG:.*]]: tensor<1x2x3xf32>)
+func @shape_of_stat(%arg : tensor<1x2x3xf32>) {
+  // CHECK-DAG: %[[C1:.*]] = constant 1 : index
+  // CHECK-DAG: %[[C2:.*]] = constant 2 : index
+  // CHECK-DAG: %[[C3:.*]] = constant 3 : index
+  // CHECK-DAG: %[[SHAPE:.*]] = tensor_from_elements(%[[C1]], %[[C2]], %[[C3]]) : tensor<3xindex>
+  %shape = shape.shape_of %arg : tensor<1x2x3xf32>
+  return
+}
+
+// -----
+
+// Lower `shape_of` for dynamically shaped tensor.
+// CHECK-LABEL: @shape_of_dyn
+// CHECK-SAME: (%[[ARG:.*]]: tensor<1x5x?xf32>)
+func @shape_of_dyn(%arg : tensor<1x5x?xf32>) {
+  // CHECK-DAG: %[[C1:.*]] = constant 1 : index
+  // CHECK-DAG: %[[C5:.*]] = constant 5 : index
+  // CHECK-DAG: %[[C2:.*]] = constant 2 : index
+  // CHECK-DAG: %[[DYN_DIM:.*]] = dim %[[ARG]], %[[C2]] : tensor<1x5x?xf32>
+  // CHECK-DAG: %[[SHAPE:.*]] = tensor_from_elements(%[[C1]], %[[C5]], %[[DYN_DIM]]) : tensor<3xindex>
+  %shape = shape.shape_of %arg : tensor<1x5x?xf32>
+  return
+}