diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -441,12 +441,52 @@
   }
 };
+/// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
+/// filling value are the same.
+struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    auto fillOp = padOp.source().getDefiningOp<FillOp>();
+    if (!fillOp)
+      return failure();
+
+    // We can only fold if the padding value is the same as the original
+    // filling value.
+    Value padValue = padOp.getConstantPaddingValue();
+    if (!padValue || fillOp.value() != padValue)
+      return failure();
+
+    ReifiedRankedShapedTypeDims reifiedShape;
+    ReifyRankedShapedTypeOpInterface interface =
+        cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation());
+    if (failed(interface.reifyResultShapes(rewriter, reifiedShape)))
+      return rewriter.notifyMatchFailure(
+          padOp, "failed to reify tensor.pad op result shape");
+
+    auto oldResultType = padOp.getResultType();
+    SmallVector<int64_t> staticShape(oldResultType.getRank(),
+                                     ShapedType::kDynamicSize);
+    auto newInitOp = rewriter.create<InitTensorOp>(
+        padOp.getLoc(), reifiedShape.front(), staticShape,
+        oldResultType.getElementType());
+    auto newFillOp =
+        rewriter.create<FillOp>(fillOp.getLoc(), padValue, newInitOp);
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(padOp, oldResultType,
+                                                newFillOp.result());
+
+    return success();
+  }
+};
+
 } // namespace
 
 void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
-  results.add<FoldFillWithTensorCast,
-              FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
-              FoldFillWithTensorReshape<tensor::ExpandShapeOp>>(context);
+  results
+      .add<FoldFillWithPad, FoldFillWithTensorCast,
+           FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
+           FoldFillWithTensorReshape<tensor::ExpandShapeOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -585,3 +585,68 @@
   }
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @fold_static_pad_fill
+// CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [412, 276] : tensor<412x276xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill(%[[F0]], %[[INIT]])
+// CHECK: return %[[FILL]]
+func @fold_static_pad_fill() -> tensor<412x276xf32> {
+  %f0 = arith.constant 0.0 : f32
+  %init = linalg.init_tensor [400, 273] : tensor<400x273xf32>
+  %fill = linalg.fill(%f0, %init) : f32, tensor<400x273xf32> -> tensor<400x273xf32>
+  %pad = tensor.pad %fill low[4, 1] high[8, 2] {
+  ^bb0(%arg1: index, %arg2: index):
+    tensor.yield %f0 : f32
+  } : tensor<400x273xf32> to tensor<412x276xf32>
+  return %pad : tensor<412x276xf32>
+}
+
+// -----
+
+// CHECK: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 + 9)>
+// CHECK: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 + 10)>
+// CHECK: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 + 23)>
+// CHECK: #[[MAP3:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 32)>
+
+// CHECK: func @fold_dynamic_pad_fill
+// CHECK-SAME: %[[SRC:.+]]: tensor<8x?x16x32xf32>, %[[LOW0:.+]]: index, %[[LOW3:.+]]: index, %[[HIGH2:.+]]: index, %[[HIGH3:.+]]: index
+
+// CHECK-DAG: %[[I1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[OF:.+]] = linalg.fill(%[[F0]], %[[SRC]]) : f32, tensor<8x?x16x32xf32>
+// CHECK: %[[S0:.+]] = affine.apply #[[MAP0]]()[%[[LOW0]]]
+// CHECK: %[[DIM1:.+]] = tensor.dim %[[OF]], %[[I1]] : tensor<8x?x16x32xf32>
+// CHECK: %[[S1:.+]] = affine.apply #[[MAP1]]()[%[[DIM1]]]
+// CHECK: %[[S2:.+]] = affine.apply #[[MAP2]]()[%[[HIGH2]]]
+// CHECK: %[[S3:.+]] = affine.apply #[[MAP3]]()[%[[LOW3]], %[[HIGH3]]]
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[S0]], %[[S1]], %[[S2]], %[[S3]]] : tensor<?x?x?x?xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill(%[[F0]], %[[INIT]])
+// CHECK: return %[[FILL]]
+func @fold_dynamic_pad_fill(%init: tensor<8x?x16x32xf32>, %low0: index, %low3: index, %high2: index, %high3: index) -> tensor<?x?x?x?xf32> {
+  %f0 = arith.constant 0.0 : f32
+  %fill = linalg.fill(%f0, %init) : f32, tensor<8x?x16x32xf32> -> tensor<8x?x16x32xf32>
+  %pad = tensor.pad %fill low[%low0, 8, 7, %low3] high[1, 2, %high2, %high3] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
+    tensor.yield %f0 : f32
+  } : tensor<8x?x16x32xf32> to tensor<?x?x?x?xf32>
+  return %pad : tensor<?x?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @no_fold_pad_fill_value_mismatch
+func @no_fold_pad_fill_value_mismatch() -> tensor<412x276xf32> {
+  %f0 = arith.constant 0.0 : f32
+  %f1 = arith.constant 1.0 : f32
+  %init = linalg.init_tensor [400, 273] : tensor<400x273xf32>
+  %fill = linalg.fill(%f0, %init) : f32, tensor<400x273xf32> -> tensor<400x273xf32>
+  // CHECK: tensor.pad
+  %pad = tensor.pad %fill low[4, 1] high[8, 2] {
+  ^bb0(%arg1: index, %arg2: index):
+    tensor.yield %f1 : f32
+  } : tensor<400x273xf32> to tensor<412x276xf32>
+  return %pad : tensor<412x276xf32>
+}