diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h
@@ -15,6 +15,10 @@
 /// Creates an instance of `tensor` dialect bufferization pass.
 std::unique_ptr<Pass> createTensorBufferizePass();
 
+/// Creates a pass that wraps tensor.pad ops in scf.if ops so that the
+/// padding-elided and padding-needed cases can be handled separately.
+std::unique_ptr<Pass> createTensorSplitPaddingPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td
@@ -16,4 +16,14 @@
   let constructor = "mlir::createTensorBufferizePass()";
 }
 
+def TensorSplitPadding : Pass<"tensor-split-padding", "FuncOp"> {
+  let summary = "Split `tensor.pad` ops into padding-unnecessary and "
+                "padding-needed cases";
+  let description = [{
+    This pass wraps each tensor.pad op in an scf.if op so that the
+    padding-elided and padding-needed cases can be handled separately.
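+
+    For example, a one-dimensional pad with a dynamic low padding amount is
+    rewritten roughly as follows (an illustrative sketch with placeholder
+    values %input, %low, and %f0, not verbatim pass output):
+
+    ```mlir
+    %c0 = arith.constant 0 : index
+    %eq = arith.cmpi eq, %low, %c0 : index
+    %0 = scf.if %eq -> (tensor<?xf32>) {
+      scf.yield %input : tensor<?xf32>
+    } else {
+      %pad = tensor.pad %input low[%low] high[0] {
+      ^bb0(%idx: index):
+        tensor.yield %f0 : f32
+      } : tensor<?xf32> to tensor<?xf32>
+      scf.yield %pad : tensor<?xf32>
+    }
+    ```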
+  }];
+  let constructor = "mlir::createTensorSplitPaddingPass()";
+}
+
 #endif // MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -0,0 +1,26 @@
+//===- Transforms.h - Tensor Transformation Patterns -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMS_H
+#define MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMS_H
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+namespace tensor {
+
+/// Populates `patterns` with patterns that split tensor.pad ops by wrapping
+/// them in scf.if ops so that the padding-unnecessary and padding-needed
+/// cases can be handled separately.
+void populateSplitPaddingPatterns(RewritePatternSet &patterns,
+                                  PatternBenefit baseBenefit = 1);
+
+} // namespace tensor
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMS_H
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRTensorTransforms
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
+  SplitPadding.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/Transforms
diff --git a/mlir/lib/Dialect/Tensor/Transforms/SplitPadding.cpp b/mlir/lib/Dialect/Tensor/Transforms/SplitPadding.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/Transforms/SplitPadding.cpp
@@ -0,0 +1,120 @@
+//===- SplitPadding.cpp - Splitting tensor.pad Op -------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns and passes that wrap tensor.pad ops in
+// scf.if ops so that the padding-elided and padding-needed cases can be
+// handled separately.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Passes.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "mlir-tensor-split-padding"
+
+using namespace mlir;
+
+/// Returns true if the given `attrOrValue` is a constant zero.
+static bool isZero(OpFoldResult attrOrValue) {
+  if (Optional<int64_t> val = getConstantIntValue(attrOrValue))
+    return val.getValue() == 0;
+  return false;
+}
+
+/// Gets the given `attrOrValue` as a Value by creating constant ops for
+/// attributes.
+static Value getAsValue(OpFoldResult attrOrValue, OpBuilder &builder,
+                        Location loc) {
+  if (Value val = attrOrValue.dyn_cast<Value>())
+    return val;
+  auto attr = attrOrValue.get<Attribute>().cast<IntegerAttr>();
+  return builder.create<arith::ConstantIndexOp>(loc, attr.getInt());
+}
+
+namespace {
+
+/// Splits a tensor.pad op by wrapping it in an scf.if op that handles the
+/// padding-unnecessary and padding-needed cases separately.
+struct SplitPadding final : public OpRewritePattern<tensor::PadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    // Avoid infinitely applying this pattern: the pad op cloned into the
+    // "else" branch below would otherwise be matched again.
+    if (padOp->getParentOfType<scf::IfOp>())
+      return failure();
+
+    // If all padding sizes are statically zero, we don't need to do anything.
+    SmallVector<OpFoldResult> lowPads = padOp.getMixedLowPad();
+    SmallVector<OpFoldResult> highPads = padOp.getMixedHighPad();
+    if (llvm::all_of(lowPads, isZero) && llvm::all_of(highPads, isZero))
+      return failure();
+
+    // Build the condition for the scf.if op: all pad sizes are zero. Only
+    // pad sizes that are not statically known to be zero need a runtime
+    // comparison.
+    Location loc = padOp.getLoc();
+    Value cstZero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    SmallVector<Value> eqZeroCmpVals;
+    for (OpFoldResult pad : llvm::concat<OpFoldResult>(lowPads, highPads)) {
+      if (!isZero(pad)) {
+        eqZeroCmpVals.push_back(rewriter.create<arith::CmpIOp>(
+            loc, arith::CmpIPredicate::eq, getAsValue(pad, rewriter, loc),
+            cstZero));
+      }
+    }
+    Value ifCond = eqZeroCmpVals.front();
+    for (Value cmp : llvm::makeArrayRef(eqZeroCmpVals).drop_front()) {
+      ifCond = rewriter.create<arith::AndIOp>(loc, ifCond, cmp);
+    }
+
+    // Build the scf.if op itself. In the "then" branch we can elide the
+    // padding and yield the source tensor directly; in the "else" branch we
+    // keep a clone of the original pad op.
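+    // Note that yielding the source directly in the "then" branch assumes
+    // the pad op's source and result tensor types match, e.g. when every
+    // non-zero pad amount is dynamic so that the corresponding padded
+    // result dimensions are dynamic as well.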
+    auto thenBuilder = [&padOp](OpBuilder &builder, Location loc) {
+      builder.create<scf::YieldOp>(loc, padOp.source());
+    };
+    auto elseBuilder = [&padOp](OpBuilder &builder, Location loc) {
+      Operation *newOp = builder.clone(*padOp);
+      builder.create<scf::YieldOp>(loc, newOp->getResults());
+    };
+    rewriter.replaceOpWithNewOp<scf::IfOp>(padOp, padOp.getType(), ifCond,
+                                           thenBuilder, elseBuilder);
+    return success();
+  }
+};
+
+struct TensorSplitPaddingPass final
+    : public TensorSplitPaddingBase<TensorSplitPaddingPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arith::ArithmeticDialect, scf::SCFDialect>();
+  }
+
+  void runOnOperation() override {
+    FuncOp fn = getOperation();
+    RewritePatternSet patterns(&getContext());
+    tensor::populateSplitPaddingPatterns(patterns);
+    if (failed(applyPatternsAndFoldGreedily(fn, std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+
+} // namespace
+
+void tensor::populateSplitPaddingPatterns(RewritePatternSet &patterns,
+                                          PatternBenefit baseBenefit) {
+  patterns.add<SplitPadding>(patterns.getContext(), baseBenefit);
+}
+
+std::unique_ptr<Pass> mlir::createTensorSplitPaddingPass() {
+  return std::make_unique<TensorSplitPaddingPass>();
+}
diff --git a/mlir/test/Dialect/Tensor/split-padding.mlir b/mlir/test/Dialect/Tensor/split-padding.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/split-padding.mlir
@@ -0,0 +1,44 @@
+// RUN: mlir-opt -split-input-file -tensor-split-padding %s | FileCheck %s
+
+// CHECK-LABEL: func @pad_all_zero_sizes
+func @pad_all_zero_sizes(%input: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %f0 = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  %0 = tensor.pad %input low[0, %c0, 0] high[%c0, 0, 0] {
+  ^bb0(%dim0: index, %dim1: index, %dim2: index):
+    tensor.yield %f0 : f32
+  } : tensor<?x?x?xf32> to tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+
+// CHECK-NOT: scf.if
+// CHECK: tensor.pad
+
+// -----
+
+// CHECK-LABEL: func @pad_non_zero_sizes
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<?x?x?xf32>, %[[LOW0:.+]]: index, %[[HIGH1:.+]]: index)
+func @pad_non_zero_sizes(%input: tensor<?x?x?xf32>, %low0: index, %high1: index) -> tensor<?x?x?xf32> {
+  %f0 = arith.constant 0.0 : f32
+  %0 = tensor.pad %input low[%low0, 0, 0] high[0, %high1, 0] {
+  ^bb0(%dim0: index, %dim1: index, %dim2: index):
+    tensor.yield %f0 : f32
+  } : tensor<?x?x?xf32> to tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+
+// CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[EQ0:.+]] = arith.cmpi eq, %[[LOW0]], %[[C0]] : index
+// CHECK: %[[EQ1:.+]] = arith.cmpi eq, %[[HIGH1]], %[[C0]] : index
+// CHECK: %[[AND:.+]] = arith.andi %[[EQ0]], %[[EQ1]] : i1
+// CHECK: %[[IF:.+]] = scf.if %[[AND]] -> (tensor<?x?x?xf32>) {
+// CHECK:   scf.yield %[[INPUT]] : tensor<?x?x?xf32>
+// CHECK: } else {
+// CHECK:   %[[PAD:.+]] = tensor.pad %[[INPUT]] low[%[[LOW0]], 0, 0] high[0, %[[HIGH1]], 0] {
+// CHECK:   ^bb0(%{{.+}}: index, %{{.+}}: index, %{{.+}}: index):
+// CHECK:     tensor.yield %[[F0]] : f32
+// CHECK:   } : tensor<?x?x?xf32> to tensor<?x?x?xf32>
+// CHECK:   scf.yield %[[PAD]] : tensor<?x?x?xf32>
+// CHECK: }
+// CHECK: return %[[IF]] : tensor<?x?x?xf32>
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -4528,6 +4528,7 @@
     hdrs = [
         "include/mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h",
         "include/mlir/Dialect/Tensor/Transforms/Passes.h",
+        "include/mlir/Dialect/Tensor/Transforms/Transforms.h",
     ],
     includes = ["include"],
     deps = [
@@ -4535,6 +4536,7 @@
         ":Async",
        ":BufferizationDialect",
        ":BufferizationTransforms",
+        ":DialectUtils",
        ":IR",
        ":MemRefDialect",
        ":ParallelLoopMapperAttrGen",