diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1165,6 +1165,12 @@
                  const SmallVector<Value> &dynSizes) const;
 };
 
+/// Populates `patterns` with patterns that split linalg.pad_tensor ops by
+/// creating scf.if ops to wrap linalg.pad_tensor ops and handle
+/// padding-unnecessary and padding-needed cases separately.
+void populateSplitPaddingPatterns(RewritePatternSet &patterns,
+                                  PatternBenefit baseBenefit = 1);
+
 /// Populates `patterns` with patterns that vectorize linalg.pad_tensor.
 /// These patterns are meant to apply in a complementary fashion. Benefits
 /// are used to encode a certain ordering of pattern application. To avoid
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -18,6 +18,7 @@
   LinalgStrategyPasses.cpp
   NamedOpConversions.cpp
   Promotion.cpp
+  SplitPadding.cpp
   Tiling.cpp
   Transforms.cpp
   Vectorization.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitPadding.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/SplitPadding.cpp
@@ -0,0 +1,98 @@
+//===- SplitPadding.cpp - PadTensorOp Splitting ---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns and passes for creating scf.if ops to wrap
+// linalg.pad_tensor ops and handle padding-unnecessary and padding-needed
+// cases separately.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/PatternMatch.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "mlir-linalg-split-padding"
+
+using namespace mlir;
+
+/// Returns true if the given `attrOrValue` is a constant zero.
+static bool isZero(OpFoldResult attrOrValue) {
+  if (Optional<int64_t> val = getConstantIntValue(attrOrValue))
+    return val.getValue() == 0;
+  return false;
+}
+
+/// Gets the given `attrOrValue` as a Value by creating constant ops for
+/// attributes.
+static Value getAsValue(OpFoldResult attrOrValue, OpBuilder &builder,
+                        Location loc) {
+  if (Value val = attrOrValue.dyn_cast<Value>())
+    return val;
+  auto attr = attrOrValue.get<Attribute>().cast<IntegerAttr>();
+  return builder.create<arith::ConstantIndexOp>(loc, attr.getInt());
+}
+
+namespace {
+
+/// Splits a linalg.pad_tensor op by wrapping it in a scf.if op to handle
+/// padding-unnecessary and padding-needed cases.
+struct SplitPadding final : public OpRewritePattern<linalg::PadTensorOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::PadTensorOp padOp,
+                                PatternRewriter &rewriter) const override {
+    // Avoid infinitely applying this pattern.
+    if (padOp->getParentOfType<scf::IfOp>())
+      return failure();
+
+    // If all padding sizes are zero, we don't need to do anything.
+    SmallVector<OpFoldResult> lowPads = padOp.getMixedLowPad();
+    SmallVector<OpFoldResult> highPads = padOp.getMixedHighPad();
+    if (llvm::all_of(lowPads, isZero) && llvm::all_of(highPads, isZero))
+      return failure();
+
+    // Build the condition for the scf.if op: all pad sizes are zero.
+    Location loc = padOp.getLoc();
+    Value cstZero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    SmallVector<Value> eqZeroCmpVals;
+    for (OpFoldResult pad : llvm::concat<OpFoldResult>(lowPads, highPads)) {
+      if (!isZero(pad)) {
+        eqZeroCmpVals.push_back(rewriter.create<arith::CmpIOp>(
+            loc, arith::CmpIPredicate::eq, getAsValue(pad, rewriter, loc),
+            cstZero));
+      }
+    }
+    Value ifCond = eqZeroCmpVals.front();
+    for (Value cmp : llvm::makeArrayRef(eqZeroCmpVals).drop_front()) {
+      ifCond = rewriter.create<arith::AndIOp>(loc, ifCond, cmp);
+    }
+
+    // Build the scf.if op itself. For the "then" branch, we can elide the
+    // padding. For the "else" branch, we clone the original pad op.
+    auto thenBuilder = [&padOp](OpBuilder &builder, Location loc) {
+      builder.create<scf::YieldOp>(loc, padOp.source());
+    };
+    auto elseBuilder = [&padOp](OpBuilder &builder, Location loc) {
+      Operation *newOp = builder.clone(*padOp);
+      builder.create<scf::YieldOp>(loc, newOp->getResults());
+    };
+    rewriter.replaceOpWithNewOp<scf::IfOp>(padOp, padOp.getType(), ifCond,
+                                           thenBuilder, elseBuilder);
+    return success();
+  }
+};
+
+} // namespace
+
+void linalg::populateSplitPaddingPatterns(RewritePatternSet &patterns,
+                                          PatternBenefit baseBenefit) {
+  patterns.add<SplitPadding>(patterns.getContext(), baseBenefit);
+}
diff --git a/mlir/test/Dialect/Linalg/split-padding.mlir b/mlir/test/Dialect/Linalg/split-padding.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/split-padding.mlir
@@ -0,0 +1,44 @@
+// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-split-padding-patterns %s | FileCheck %s
+
+// CHECK-LABEL: func @pad_all_zero_sizes
+func @pad_all_zero_sizes(%input: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %f0 = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  %0 = linalg.pad_tensor %input low[0, %c0, 0] high[%c0, 0, 0] {
+  ^bb0(%dim0: index, %dim1: index, %dim2: index):
+    linalg.yield %f0 : f32
+  } : tensor<?x?x?xf32> to tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+
+// CHECK-NOT: scf.if
+// CHECK: linalg.pad_tensor
+
+// -----
+
+// CHECK-LABEL: func @pad_non_zero_sizes
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<?x?x?xf32>, %[[LOW0:.+]]: index, %[[HIGH1:.+]]: index)
+func @pad_non_zero_sizes(%input: tensor<?x?x?xf32>, %low0: index, %high1: index) -> tensor<?x?x?xf32> {
+  %f0 = arith.constant 0.0 : f32
+  %0 = linalg.pad_tensor %input low[%low0, 0, 0] high[0, %high1, 0] {
+  ^bb0(%dim0: index, %dim1: index, %dim2: index):
+    linalg.yield %f0 : f32
+  } : tensor<?x?x?xf32> to tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+
+// CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[EQ0:.+]] = arith.cmpi eq, %[[LOW0]], %[[C0]] : index
+// CHECK: %[[EQ1:.+]] = arith.cmpi eq, %[[HIGH1]], %[[C0]] : index
+// CHECK: %[[AND:.+]] = arith.andi %[[EQ0]], %[[EQ1]] : i1
+// CHECK: %[[IF:.+]] = scf.if %[[AND]] -> (tensor<?x?x?xf32>) {
+// CHECK:   scf.yield %[[INPUT]] : tensor<?x?x?xf32>
+// CHECK: } else {
+// CHECK:   %[[PAD:.+]] = linalg.pad_tensor %[[INPUT]] low[%[[LOW0]], 0, 0] high[0, %[[HIGH1]], 0] {
+// CHECK:   ^bb0(%{{.+}}: index, %{{.+}}: index, %{{.+}}: index):
+// CHECK:     linalg.yield %[[F0]] : f32
+// CHECK:   } : tensor<?x?x?xf32> to tensor<?x?x?xf32>
+// CHECK:   scf.yield %[[PAD]] : tensor<?x?x?xf32>
+// CHECK: }
+// CHECK: return %[[IF]] : tensor<?x?x?xf32>
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -128,6 +128,11 @@
       llvm::cl::desc("Specify the type of loops to generate: for, parallel or "
                      "tiled_loop"),
       llvm::cl::init("for")};
+  Option<bool> testSplitPaddingPattern{
+      *this, "test-split-padding-patterns",
+      llvm::cl::desc("Test a set of patterns to split linalg.pad_tensor ops "
+                     "and handle its cases separately"),
+      llvm::cl::init(false)};
 };
 } // namespace
 
@@ -561,6 +566,12 @@
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
+static void applySplitPaddingPatterns(FuncOp funcOp) {
+  RewritePatternSet patterns(funcOp.getContext());
+  populateSplitPaddingPatterns(patterns);
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
 static void applyGeneralizePadTensorPatterns(FuncOp funcOp) {
   RewritePatternSet patterns(funcOp.getContext());
   patterns.add<GeneralizePadTensorOpPattern>(funcOp.getContext());
@@ -720,6 +731,8 @@
   if (testTileScalarizeDynamicDims)
     return applyTilePattern(getFunction(), loopType, tileSizes,
                             /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
+  if (testSplitPaddingPattern)
+    return applySplitPaddingPatterns(getFunction());
 }
 
 namespace mlir {
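Usage note (not part of the patch): downstream code can wire these patterns into its own pass the same way the test driver above does. Below is a minimal sketch under stated assumptions; the pass class name `SplitPaddingPass` and the `split-padding` argument string are hypothetical, only `populateSplitPaddingPatterns` comes from this patch.

// A minimal sketch of a standalone pass applying the split-padding patterns,
// modeled on applySplitPaddingPatterns() in the test driver above. The pass
// name and argument string here are hypothetical, not part of this patch.
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace {
struct SplitPaddingPass
    : public mlir::PassWrapper<SplitPaddingPass,
                               mlir::OperationPass<mlir::FuncOp>> {
  llvm::StringRef getArgument() const final { return "split-padding"; }
  llvm::StringRef getDescription() const final {
    return "Wrap linalg.pad_tensor ops in scf.if to elide unnecessary padding";
  }
  void getDependentDialects(mlir::DialectRegistry &registry) const override {
    // The patterns create scf.if ops plus arith constant/compare/and ops.
    registry.insert<mlir::arith::ArithmeticDialect, mlir::scf::SCFDialect>();
  }
  void runOnOperation() override {
    mlir::FuncOp funcOp = getOperation();
    mlir::RewritePatternSet patterns(funcOp.getContext());
    mlir::linalg::populateSplitPaddingPatterns(patterns);
    // Ignore the convergence result; a non-converging greedy rewrite is not
    // an error for this use case.
    (void)mlir::applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
  }
};
} // namespace

Registering it once (e.g. with mlir::PassRegistration<SplitPaddingPass>()) would then expose it as -split-padding in mlir-opt-style tools.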