diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -968,11 +968,48 @@
     return success();
   }
 };
+
+// Canonicalize
+// ```
+// %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
+// %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
+// ```
+// to
+// ```
+// %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
+// ```
+//
+// The cast is folded only when its result is a rank-1 tensor whose extent does
+// not conflict with the rank of the ranked `shape_of` argument.
+struct ShapeOfCastedExtentTensor : public OpRewritePattern<tensor::CastOp> {
+  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::CastOp op,
+                                PatternRewriter &rewriter) const override {
+    // The cast result must be a ranked tensor of rank 1 (an extent tensor).
+    auto ty = op.getType().dyn_cast<RankedTensorType>();
+    if (!ty || ty.getRank() != 1)
+      return failure();
+
+    auto shapeOfOp = op.source().getDefiningOp<ShapeOfOp>();
+    if (!shapeOfOp)
+      return failure();
+
+    // The argument must be ranked and a static result extent must equal its
+    // rank; otherwise the rewritten shape_of would have a conflicting type.
+    auto argTy = shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
+    if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
+      return failure();
+
+    rewriter.replaceOpWithNewOp<ShapeOfOp>(op, op.getType(), shapeOfOp.arg());
+    return success();
+  }
+};
 } // namespace
 
 void ShapeOfOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
                                             MLIRContext *context) {
-  patterns.insert<ShapeOfWithTensor>(context);
+  patterns.insert<ShapeOfCastedExtentTensor, ShapeOfWithTensor>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1131,3 +1131,29 @@
   "use"(%0) : (tensor<?xindex>) -> ()
   return
 }
+
+// -----
+
+// CHECK-LABEL: @casted_extent_tensor
+// CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>) -> tensor<?xindex>
+func @casted_extent_tensor(%arg : tensor<?x?x?xf32>)
+    -> tensor<?xindex> {
+  // CHECK: %[[RESULT:.*]] = shape.shape_of %[[ARG]] : tensor<?x?x?xf32> -> tensor<?xindex>
+  // CHECK: return %[[RESULT]] : tensor<?xindex>
+  %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
+  %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
+  return %1 : tensor<?xindex>
+}
+
+// -----
+
+// CHECK-LABEL: @casted_extent_tensor
+// CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>) -> tensor<3xindex>
+func @casted_extent_tensor(%arg : tensor<?x?x?xf32>)
+    -> tensor<3xindex> {
+  // CHECK: %[[RESULT:.*]] = shape.shape_of %[[ARG]] : tensor<?x?x?xf32> -> tensor<3xindex>
+  // CHECK: return %[[RESULT]] : tensor<3xindex>
+  %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
+  %1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex>
+  return %1 : tensor<3xindex>
+}