diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -446,6 +446,46 @@
                            *packInfo);
 }
 
+/// Folds pack(fill) into a single fill op if
+///   1. The pack op does not have padding value, or
+///   2. The filled value and padding value are the same.
+/// The new fill is created right before `packOp`, where the operands of both
+/// ops are known to dominate it. No IR is created when failure is returned.
+static FailureOr<linalg::FillOp>
+foldFillPackIntoFillOp(RewriterBase &rewriter, tensor::PackOp packOp,
+                       ControlPropagationFn controlFn) {
+  auto fillOp = packOp.getSource().getDefiningOp<linalg::FillOp>();
+  if (!fillOp)
+    return failure();
+
+  // User controlled propagation function.
+  if (!controlFn(fillOp))
+    return failure();
+
+  // A padded pack is only a plain fill if it pads with the filled value.
+  if (auto paddingValue = packOp.getPaddingValue())
+    if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
+      return failure();
+
+  // The pack dest becomes the fill dest, so nothing else may use it.
+  Value packOpDest = packOp.getDest();
+  if (!packOpDest.hasOneUse())
+    return failure();
+
+  // Insert at the pack op: at the fill op, dynamic tile sizes or the pack
+  // destination may not be defined yet.
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(packOp);
+  if (packOpDest.getDefiningOp<tensor::EmptyOp>())
+    packOpDest = tensor::PackOp::createDestinationTensor(
+        rewriter, fillOp.getLoc(), fillOp.getDpsInitOperand(0)->get(),
+        packOp.getMixedTiles(), packOp.getInnerDimsPos(),
+        packOp.getOuterDimsPerm());
+  return clone(rewriter, fillOp, packOpDest.getType(),
+               {fillOp.value(), packOpDest});
+}
+
 /// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method.
 struct BubbleUpPackOpThroughGenericOpPattern
     : public OpRewritePattern<tensor::PackOp> {
@@ -468,6 +508,25 @@
   ControlPropagationFn controlFn;
 };
 
+/// Wrapper pattern that applies foldFillPackIntoFillOp method.
+struct FoldFillPackIntoFillOpPattern : public OpRewritePattern<tensor::PackOp> {
+  FoldFillPackIntoFillOpPattern(MLIRContext *context, ControlPropagationFn fun)
+      : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}
+
+  /// Replaces `packOp` with the folded fill; fails if the fold does not apply.
+  LogicalResult matchAndRewrite(tensor::PackOp packOp,
+                                PatternRewriter &rewriter) const override {
+    auto packedFill = foldFillPackIntoFillOp(rewriter, packOp, controlFn);
+    if (failed(packedFill))
+      return failure();
+    rewriter.replaceOp(packOp, packedFill->result());
+    return success();
+  }
+
+private:
+  ControlPropagationFn controlFn;
+};
+
 // TODO: Relax this restriction. We should unpack an elementwise also
 // in the presence of multiple unpack ops as producers.
 /// Return the unpacked operand, if present, for the current generic op.
@@ -689,6 +748,7 @@
     RewritePatternSet &patterns,
     const ControlPropagationFn &controlPackUnPackPropagation) {
   patterns.insert<BubbleUpPackOpThroughGenericOpPattern,
+                  FoldFillPackIntoFillOpPattern,
                   PushDownUnPackOpThroughGenericOp>(
       patterns.getContext(), controlPackUnPackPropagation);
 }
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -837,3 +837,47 @@
 // CHECK-SAME: inner_dims_pos = [0] inner_tiles = [16]
 // CHECK-SAME: into %[[UNPACK_NEW_DEST]]
 // CHECK: return %[[UNPACK]] : tensor<16x540x960xi32>
+
+// -----
+
+func.func @fill_pack() -> tensor<24x32x16x16xf32> {
+  %dest = tensor.empty() : tensor<384x512xf32>
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<24x32x16x16xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%dest : tensor<384x512xf32>) -> tensor<384x512xf32>
+  %pack = tensor.pack %1 inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %0 : tensor<384x512xf32> -> tensor<24x32x16x16xf32>
+  return %pack : tensor<24x32x16x16xf32>
+}
+// CHECK-LABEL: func.func @fill_pack
+// CHECK: %[[PACKED_EMPTY:.+]] = tensor.empty() : tensor<24x32x16x16xf32>
+// CHECK:
%[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[PACKED_EMPTY]] +// CHECK: return %[[FILL]] + +// ----- + +#map = affine_map<()[s0] -> (s0 ceildiv 16)> +func.func @dynamic_fill_pack(%arg0: tensor) -> tensor { + %cst = arith.constant 0.000000e+00 : f32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor) -> tensor + %dim = tensor.dim %0, %c0 : tensor + %dim_0 = tensor.dim %0, %c1 : tensor + %1 = affine.apply #map()[%dim] + %2 = affine.apply #map()[%dim_0] + %3 = tensor.empty(%1, %2) : tensor + %pack = tensor.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %3 : tensor -> tensor + return %pack : tensor +} +// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)> +// CHECK: func.func @dynamic_fill_pack +// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK: %[[D0:.+]] = tensor.dim %[[DEST]], %[[C0]] +// CHECK: %[[D1:.+]] = tensor.dim %[[DEST]], %[[C1]] +// CHECK: %[[PACKED_D0:.+]] = affine.apply #[[MAP]]()[%[[D0]]] +// CHECK: %[[PACKED_D1:.+]] = affine.apply #[[MAP]]()[%[[D1]]] +// CHECK: %[[PACKED_EMPTY:.+]] = tensor.empty(%[[PACKED_D0]], %[[PACKED_D1]]) : tensor +// CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[PACKED_EMPTY]] +// CHECK: return %[[FILL]]