diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1063,11 +1063,29 @@
     return success();
   }
 };
+
+// Fold CastOp into PadTensorOp when adding static information.
+struct FoldSourceTensorCast : public OpRewritePattern<PadTensorOp> {
+  using OpRewritePattern<PadTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PadTensorOp padTensorOp,
+                                PatternRewriter &rewriter) const override {
+    auto castOp = padTensorOp.source().getDefiningOp<tensor::CastOp>();
+    if (!tensor::canFoldIntoConsumerOp(castOp))
+      return failure();
+
+    rewriter.updateRootInPlace(padTensorOp, [&]() {
+      padTensorOp.sourceMutable().assign(castOp.source());
+    });
+    return success();
+  }
+};
 } // namespace
 
 void PadTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  results.add<FoldStaticZeroPadding>(context);
+  results.add<FoldStaticZeroPadding, FoldSourceTensorCast>(context);
 }
 
 /// Return the padding value of the PadTensorOp if it constant. In this context,
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -772,6 +772,22 @@
 
 // -----
 
+// CHECK-LABEL: func @fold_pad_tensor_source_cast(
+//  CHECK-SAME:   %[[ARG0:.*]]: tensor<4x?xf32>
+//   CHECK-NOT:   tensor.cast
+//       CHECK:   %[[RESULT:.*]] = linalg.pad_tensor %[[ARG0]]
+func @fold_pad_tensor_source_cast(%arg0: tensor<4x?xf32>) -> tensor<4x4xf32> {
+  %cst = constant 0.0 : f32
+  %0 = tensor.cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
+  %1 = linalg.pad_tensor %0 low[0, 0] high[0, 1] {
+    ^bb0(%arg1: index, %arg2: index):  // no predecessors
+      linalg.yield %cst : f32
+  } : tensor<?x?xf32> to tensor<4x4xf32>
+  return %1 : tensor<4x4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @pad_static_zero_cast(
 //  CHECK-SAME:   %[[ARG0:.*]]: tensor
 //   CHECK-NOT:   linalg.pad_tensor
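
For reference, a minimal before/after sketch of the new canonicalization, mirroring the test case above (illustrative IR only; %cst and the block-argument names are placeholders, not part of the patch). Folding is legal because the tensor.cast only erases static size information, so the pad can read the more-static source directly.

  // Before: the pad consumes a cast that drops the static leading dimension.
  %0 = tensor.cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
  %1 = linalg.pad_tensor %0 low[0, 0] high[0, 1] {
    ^bb0(%i: index, %j: index):
      linalg.yield %cst : f32
  } : tensor<?x?xf32> to tensor<4x4xf32>

  // After: the pattern rewires the pad's source operand to the cast's input;
  // the now-dead tensor.cast can then be removed by existing cleanups.
  %1 = linalg.pad_tensor %arg0 low[0, 0] high[0, 1] {
    ^bb0(%i: index, %j: index):
      linalg.yield %cst : f32
  } : tensor<4x?xf32> to tensor<4x4xf32>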