diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -24,6 +24,7 @@
 #include "mlir/Parser.h"
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/StringSet.h"
@@ -416,6 +417,22 @@
                      SideEffects::DefaultResource::get());
 }
 
+/// Gets the given `attrOrValue` as an index value by creating constant ops
+/// for attributes.
+static Value getAsIndexValue(OpFoldResult attrOrValue, OpBuilder &builder,
+                             Location loc) {
+  IntegerAttr attr;
+  if (Value val = attrOrValue.dyn_cast<Value>()) {
+    // Index-typed SSA values can be used directly.
+    if (val.getType().isIndex())
+      return val;
+    // Otherwise extract the constant integer behind the value.
+    matchPattern(val, m_Constant(&attr));
+  } else {
+    attr = attrOrValue.get<Attribute>().cast<IntegerAttr>();
+  }
+  return builder.createOrFold<arith::ConstantIndexOp>(
+      loc, attr.getValue().getSExtValue());
+}
+
 namespace {
 
 /// Fold linalg.fill -> tensor.expand/collapse_shape chain.
@@ -441,12 +458,68 @@
   }
 };
 
+/// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
+/// filling value are the same.
+struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
+  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    auto fillOp = padOp.source().getDefiningOp<FillOp>();
+    if (!fillOp)
+      return failure();
+
+    // We can only fold if the padding value is the same as the original
+    // filling value.
+    Value padValue = padOp.getConstantPaddingValue();
+    if (!padValue || fillOp.value() != padValue)
+      return failure();
+
+    auto sourceType = fillOp.output().getType().cast<RankedTensorType>();
+    auto resultType = padOp.getResultType();
+
+    // Collect the sizes of the dynamic source dimensions first; padded
+    // dynamic result sizes are derived from them below.
+    SmallVector<Value> dynamicDims;
+    dynamicDims.resize(sourceType.getRank());
+    for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
+      if (resultType.isDynamicDim(i))
+        dynamicDims[i] = rewriter.createOrFold<tensor::DimOp>(
+            fillOp.output().getLoc(), fillOp.output(), i);
+    }
+
+    Location loc = padOp.getLoc();
+    SmallVector<OpFoldResult> lowPads = padOp.getMixedLowPad();
+    SmallVector<OpFoldResult> highPads = padOp.getMixedHighPad();
+
+    // Result size = low pad + source size + high pad for each dynamic dim.
+    AffineExpr sym0, sym1, sym2;
+    bindSymbols(getContext(), sym0, sym1, sym2);
+    auto addMap = AffineMap::get(0, 3, {sym0 + sym1 + sym2}, getContext());
+    for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
+      if (resultType.isDynamicDim(i)) {
+        Value lowPad = getAsIndexValue(lowPads[i], rewriter, loc);
+        Value highPad = getAsIndexValue(highPads[i], rewriter, loc);
+        dynamicDims[i] = rewriter.create<AffineApplyOp>(
+            loc, addMap, ValueRange{lowPad, dynamicDims[i], highPad});
+      }
+    }
+
+    // init_tensor only takes sizes for the dynamic dimensions.
+    auto dims = llvm::to_vector<4>(
+        llvm::make_filter_range(dynamicDims, [](Value v) { return v; }));
+    auto initOp = rewriter.create<InitTensorOp>(
+        loc, dims, padOp.getResultType().getShape(),
+        padOp.getResultType().getElementType());
+
+    rewriter.replaceOpWithNewOp<FillOp>(padOp, padValue, initOp);
+    return success();
+  }
+};
+
 } // namespace
 
 void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
-  results.add<FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
-              FoldFillWithTensorReshape<tensor::ExpandShapeOp>>(context);
+  results
+      .add<FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
+           FoldFillWithTensorReshape<tensor::ExpandShapeOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -585,3 +585,67 @@
   }
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @fold_static_pad_fill
+//      CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+//      CHECK: %[[INIT:.+]] = linalg.init_tensor [412, 276] : tensor<412x276xf32>
+//      CHECK: %[[FILL:.+]] = linalg.fill(%[[F0]], %[[INIT]])
+//      CHECK: return %[[FILL]]
+func @fold_static_pad_fill() -> tensor<412x276xf32> {
+  %f0 = arith.constant 0.0 : f32
+  %init = linalg.init_tensor [400, 273] : tensor<400x273xf32>
+  %fill = linalg.fill(%f0, %init) : f32, tensor<400x273xf32> -> tensor<400x273xf32>
+  %pad = tensor.pad %fill low[4, 1] high[8, 2] {
+  ^bb0(%arg1: index, %arg2: index):
+    tensor.yield %f0 : f32
+  } : tensor<400x273xf32> to tensor<412x276xf32>
+  return %pad : tensor<412x276xf32>
+}
+
+// -----
+
+// CHECK: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 + 9)>
+// CHECK: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 + 10)>
+// CHECK: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 + 23)>
+// CHECK: #[[MAP3:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 32)>
+
+//      CHECK: func @fold_dynamic_pad_fill
+// CHECK-SAME: %[[SRC:.+]]: tensor<8x?x16x32xf32>, %[[LOW0:.+]]: index, %[[LOW3:.+]]: index, %[[HIGH2:.+]]: index, %[[HIGH3:.+]]: index
+
+//  CHECK-DAG: %[[I1:.+]] = arith.constant 1 : index
+//  CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+//      CHECK: %[[DIM1:.+]] = tensor.dim %[[SRC]], %[[I1]] : tensor<8x?x16x32xf32>
+//      CHECK: %[[S0:.+]] = affine.apply #[[MAP0]]()[%[[LOW0]]]
+//      CHECK: %[[S1:.+]] = affine.apply #[[MAP1]]()[%[[DIM1]]]
+//      CHECK: %[[S2:.+]] = affine.apply #[[MAP2]]()[%[[HIGH2]]]
+//      CHECK: %[[S3:.+]] = affine.apply #[[MAP3]]()[%[[LOW3]], %[[HIGH3]]]
+//      CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[S0]], %[[S1]], %[[S2]], %[[S3]]] : tensor<?x?x?x?xf32>
+//      CHECK: %[[FILL:.+]] = linalg.fill(%[[F0]], %[[INIT]])
+//      CHECK: return %[[FILL]]
+func @fold_dynamic_pad_fill(%init: tensor<8x?x16x32xf32>, %low0: index, %low3: index, %high2: index, %high3: index) -> tensor<?x?x?x?xf32> {
+  %f0 = arith.constant 0.0 : f32
+  %fill = linalg.fill(%f0, %init) : f32, tensor<8x?x16x32xf32> -> tensor<8x?x16x32xf32>
+  %pad = tensor.pad %fill low[%low0, 8, 7, %low3] high[1, 2, %high2, %high3] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
+    tensor.yield %f0 : f32
+  } : tensor<8x?x16x32xf32> to tensor<?x?x?x?xf32>
+  return %pad : tensor<?x?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @no_fold_pad_fill_value_mismatch
+func @no_fold_pad_fill_value_mismatch() -> tensor<412x276xf32> {
+  %f0 = arith.constant 0.0 : f32
+  %f1 = arith.constant 1.0 : f32
+  %init = linalg.init_tensor [400, 273] : tensor<400x273xf32>
+  %fill = linalg.fill(%f0, %init) : f32, tensor<400x273xf32> -> tensor<400x273xf32>
+  // CHECK: tensor.pad
+  %pad = tensor.pad %fill low[4, 1] high[8, 2] {
+  ^bb0(%arg1: index, %arg2: index):
+    tensor.yield %f1 : f32
+  } : tensor<400x273xf32> to tensor<412x276xf32>
+  return %pad : tensor<412x276xf32>
+}