diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -928,6 +928,36 @@
   }
 };
 
+// Fold CastOp into CollapseShapeOp when adding static information.
+struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
+  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
+                                PatternRewriter &rewriter) const override {
+    auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
+    if (!tensor::canFoldIntoConsumerOp(castOp))
+      return failure();
+
+    RankedTensorType srcType =
+        castOp.getSource().getType().cast<RankedTensorType>();
+    RankedTensorType newResultType = computeTensorReshapeCollapsedType(
+        srcType, collapseShapeOp.getReassociationMaps());
+
+    if (newResultType == collapseShapeOp.getResultType()) {
+      rewriter.updateRootInPlace(collapseShapeOp, [&]() {
+        collapseShapeOp.getSrcMutable().assign(castOp.getSource());
+      });
+    } else {
+      auto newOp = rewriter.create<CollapseShapeOp>(
+          collapseShapeOp.getLoc(), newResultType, castOp.getSource(),
+          collapseShapeOp.getReassociation());
+      rewriter.replaceOpWithNewOp<tensor::CastOp>(
+          collapseShapeOp, collapseShapeOp.getResultType(), newOp);
+    }
+    return success();
+  }
+};
+
 } // namespace
 
 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -940,10 +970,12 @@
 
 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  results.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
-              ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
-              FoldReshapeWithConstant<CollapseShapeOp>,
-              FoldReshapeWithFromElements<CollapseShapeOp>>(context);
+  results
+      .add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
+           ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
+           FoldReshapeWithConstant<CollapseShapeOp>,
+           FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
+          context);
 }
 
 OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -673,6 
+673,20 @@
 
 // -----
 
+// CHECK-LABEL: func.func @collapse_of_cast(
+// CHECK-SAME: %[[IN:.*]]: tensor<8x12x32xf32>) -> tensor<?x32xf32> {
+// CHECK-NEXT: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[IN]] {{\[}}[0, 1], [2]] : tensor<8x12x32xf32> into tensor<96x32xf32>
+// CHECK-NEXT %[[CAST:.*]] = tensor.cast %[[COLLAPSE]] : tensor<96x32xf32> to tensor<?x32xf32>
+// CHECK-NEXT return %[[CAST]] : tensor<?x32xf32>
+func.func @collapse_of_cast(%t: tensor<8x12x32xf32>) -> tensor<?x32xf32> {
+  %0 = tensor.cast %t : tensor<8x12x32xf32> to tensor<?x?x?xf32>
+  %1 = tensor.collapse_shape %0 [[0, 1], [2]] : tensor<?x?x?xf32> into tensor<?x?xf32>
+  %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<?x32xf32>
+  return %2 : tensor<?x32xf32>
+}
+
+// -----
+
 func.func @fold_collapse_of_expand(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1], [2]] : tensor<12x4xf32> into tensor<3x4x4xf32>