diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1027,9 +1027,15 @@
 }
 
 void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) {
-  Type type = arg.getType().isa<ShapedType>()
-                  ? (Type)getExtentTensorType(builder.getContext())
-                  : (Type)builder.getType<ShapeType>();
+  Type type;
+  Type indexTy = builder.getIndexType();
+  if (auto shapedTy = arg.getType().dyn_cast<ShapedType>()) {
+    type = RankedTensorType::get(
+        {shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize},
+        indexTy);
+  } else {
+    type = builder.getType<ShapeType>();
+  }
   return ShapeOfOp::build(builder, result, type, arg);
 }
 
@@ -1048,11 +1054,40 @@
     return success();
   }
 };
+
+// Canonicalize
+// ```
+// %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
+// %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
+// ```
+// to
+// ```
+// %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
+// ```
+struct ShapeOfCastedToDynamicExtentTensor
+    : public OpRewritePattern<tensor::CastOp> {
+  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::CastOp op,
+                                PatternRewriter &rewriter) const override {
+    Type extentTensorTy = getExtentTensorType(op.getContext());
+    if (op.getType() != extentTensorTy)
+      return failure();
+
+    auto shapeOfOp = op.source().getDefiningOp<ShapeOfOp>();
+    if (!shapeOfOp)
+      return failure();
+
+    rewriter.replaceOpWithNewOp<ShapeOfOp>(op, extentTensorTy, shapeOfOp.arg());
+    return success();
+  }
+};
 } // namespace
 
 void ShapeOfOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
                                             MLIRContext *context) {
-  patterns.insert<ShapeOfWithTensor>(context);
+  patterns.insert<ShapeOfCastedToDynamicExtentTensor, ShapeOfWithTensor>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1169,3 +1169,16 @@
   return %results#0, %results#1, %results#2
       : tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>
 }
+
+// -----
+
+// CHECK-LABEL: @casted_dynamic_extent_tensor
+// CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>) -> tensor<?xindex>
+func @casted_dynamic_extent_tensor(%arg : tensor<?x?x?xf32>)
+    -> tensor<?xindex> {
+  // CHECK: %[[RESULT:.*]] = shape.shape_of %[[ARG]] : tensor<?x?x?xf32> -> tensor<?xindex>
+  // CHECK: return %[[RESULT]] : tensor<?xindex>
+  %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
+  %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
+  return %1 : tensor<?xindex>
+}
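
Usage note (not part of the patch): with the builder change, shape.shape_of infers a ranked extent tensor result whenever the operand's rank is known, and the new ShapeOfCastedToDynamicExtentTensor pattern folds a later relaxing tensor.cast back into the shape.shape_of. Below is a minimal illustrative IR sample; the function name and element types are made up, and running it through `mlir-opt -canonicalize` to observe the fold is an assumption about the usual workflow rather than something the patch prescribes.

// Illustrative input, not part of the patch; element types are arbitrary.
func @fold_relaxing_cast_of_shape(%arg : tensor<4x?xf32>) -> tensor<?xindex> {
  // The operand rank is known, so the inferred shape type is tensor<2xindex>.
  %shape = shape.shape_of %arg : tensor<4x?xf32> -> tensor<2xindex>
  // Relaxing the result back to the dynamic extent tensor type ...
  %relaxed = tensor.cast %shape : tensor<2xindex> to tensor<?xindex>
  // ... is folded by the new pattern into a single shape.shape_of that
  // directly produces tensor<?xindex>.
  return %relaxed : tensor<?xindex>
}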