diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -0,0 +1,29 @@
+//===- Transforms.h - Tensor Transformation Patterns ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMS_H
+#define MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMS_H
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+namespace tensor {
+
+/// Populates `patterns` with patterns to wrap a tensor.pad op with an scf.if op
+/// to separate the cases where we don't need padding (all pad sizes are
+/// actually zeros) and where we indeed need padding. The patterns create scf
+/// and arith ops, so those dialects must be loaded in the context (e.g. as
+/// dependent dialects of the calling pass). The patterns are registered with
+/// benefit `baseBenefit`.
+void populateSplitPaddingPatterns(RewritePatternSet &patterns,
+                                  PatternBenefit baseBenefit = 1);
+
+} // namespace tensor
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMS_H
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRTensorTransforms
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
+  SplitPadding.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/Transforms
diff --git a/mlir/lib/Dialect/Tensor/Transforms/SplitPadding.cpp b/mlir/lib/Dialect/Tensor/Transforms/SplitPadding.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/Transforms/SplitPadding.cpp
@@ -0,0 +1,125 @@
+//===- SplitPadding.cpp - Splitting tensor.pad Op -------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns to wrap a tensor.pad op with an scf.if op
+// to separate the cases where we don't need padding (all pad sizes are
+// actually zeros) and where we indeed need padding.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "mlir-tensor-split-padding"
+
+using namespace mlir;
+
+/// Returns true if the given `attrOrValue` is a constant zero.
+static bool isZero(OpFoldResult attrOrValue) {
+  if (Optional<int64_t> val = getConstantIntValue(attrOrValue))
+    return val.getValue() == 0;
+  return false;
+}
+
+/// Gets the given `attrOrValue` as a Value. An attribute is materialized as an
+/// index-typed arith.constant and must be an IntegerAttr (the `cast` asserts
+/// otherwise).
+static Value getAsValue(OpFoldResult attrOrValue, OpBuilder &builder,
+                        Location loc) {
+  if (Value val = attrOrValue.dyn_cast<Value>())
+    return val;
+  auto attr = attrOrValue.get<Attribute>().cast<IntegerAttr>();
+  return builder.create<arith::ConstantIndexOp>(loc, attr.getInt());
+}
+
+namespace {
+
+/// Rewrites a tensor.pad op whose pad sizes are not all statically zero into
+///
+///   scf.if <all pad sizes are zero> {
+///     scf.yield <source, tensor.cast to the result type if needed>
+///   } else {
+///     scf.yield <clone of the original tensor.pad>
+///   }
+///
+/// so that the no-padding case can be handled separately.
+struct SplitPadding final : public OpRewritePattern<tensor::PadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    // Avoid infinitely applying this pattern: the cloned pad op ends up in the
+    // "else" branch of an scf.if. Note that this also skips pad ops nested at
+    // any depth inside an unrelated scf.if.
+    if (padOp->getParentOfType<scf::IfOp>())
+      return failure();
+
+    // If all padding sizes are zero, we don't need to do anything.
+    SmallVector<OpFoldResult> lowPads = padOp.getMixedLowPad();
+    SmallVector<OpFoldResult> highPads = padOp.getMixedHighPad();
+    if (llvm::all_of(lowPads, isZero) && llvm::all_of(highPads, isZero))
+      return failure();
+
+    // The "then" branch yields the source in place of the padded result. That
+    // is only well-typed if the source and result shapes are compatible (each
+    // dimension equal or dynamic on one side). Otherwise the pads can never
+    // all be zero at runtime and splitting would only create an invalid cast.
+    Value source = padOp.source();
+    Type resultType = padOp.getType();
+    if (failed(verifyCompatibleShape(source.getType(), resultType)))
+      return failure();
+
+    // Build the condition for the scf.if op: all pad sizes are zero. Pads that
+    // are statically zero need no runtime check; at least one pad is not, so
+    // `eqZeroCmpVals` is never empty.
+    Location loc = padOp.getLoc();
+    Value cstZero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    SmallVector<Value> eqZeroCmpVals;
+    for (OpFoldResult pad : llvm::concat<OpFoldResult>(lowPads, highPads)) {
+      if (!isZero(pad))
+        eqZeroCmpVals.push_back(rewriter.create<arith::CmpIOp>(
+            loc, arith::CmpIPredicate::eq, getAsValue(pad, rewriter, loc),
+            cstZero));
+    }
+    Value ifCond = eqZeroCmpVals.front();
+    for (Value cmp : llvm::makeArrayRef(eqZeroCmpVals).drop_front())
+      ifCond = rewriter.create<arith::AndIOp>(loc, ifCond, cmp);
+
+    // Build the scf.if op itself. For the "then" branch, we can elide the
+    // padding; the source type may carry different static/dynamic dimension
+    // information than the result type (e.g. a static source padded by a
+    // dynamic amount), so cast it when the types differ to keep the yielded
+    // type equal to the scf.if result type. For the "else" branch, we clone
+    // the original pad op.
+    auto thenBuilder = [source, resultType](OpBuilder &builder, Location loc) {
+      Value yielded = source;
+      if (yielded.getType() != resultType)
+        yielded = builder.create<tensor::CastOp>(loc, resultType, yielded);
+      builder.create<scf::YieldOp>(loc, yielded);
+    };
+    auto elseBuilder = [&padOp](OpBuilder &builder, Location loc) {
+      Operation *newOp = builder.clone(*padOp);
+      builder.create<scf::YieldOp>(loc, newOp->getResults());
+    };
+    rewriter.replaceOpWithNewOp<scf::IfOp>(padOp, resultType, ifCond,
+                                           thenBuilder, elseBuilder);
+    return success();
+  }
+};
+
+} // namespace
+
+void tensor::populateSplitPaddingPatterns(RewritePatternSet &patterns,
+                                          PatternBenefit baseBenefit) {
+  patterns.add<SplitPadding>(patterns.getContext(), baseBenefit);
+}
diff --git a/mlir/test/Dialect/Tensor/split-padding.mlir b/mlir/test/Dialect/Tensor/split-padding.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/split-padding.mlir
@@ -0,0 +1,68 @@
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-split-padding-patterns %s | FileCheck %s
+
+// CHECK-LABEL: func @pad_all_zero_sizes
+func @pad_all_zero_sizes(%input: tensor<?x?x8xf32>) -> tensor<?x?x8xf32> {
+  %f0 = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  %0 = tensor.pad %input low[0, %c0, 0] high[%c0, 0, 0] {
+  ^bb0(%dim0: index, %dim1: index, %dim2: index):
+    tensor.yield %f0 : f32
+  } : tensor<?x?x8xf32> to tensor<?x?x8xf32>
+  return %0 : tensor<?x?x8xf32>
+}
+
+// CHECK-NOT: scf.if
+// CHECK: tensor.pad
+
+// -----
+
+// CHECK-LABEL: func @pad_static_source_dynamic_result
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<8xf32>, %[[LOW:.+]]: index)
+func @pad_static_source_dynamic_result(%input: tensor<8xf32>, %low: index) -> tensor<?xf32> {
+  %f0 = arith.constant 0.0 : f32
+  %0 = tensor.pad %input low[%low] high[0] {
+  ^bb0(%dim0: index):
+    tensor.yield %f0 : f32
+  } : tensor<8xf32> to tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// The source type is more static than the result type, so the "then" branch
+// casts it before yielding.
+// CHECK: %[[IF:.+]] = scf.if %{{.+}} -> (tensor<?xf32>) {
+// CHECK: %[[CAST:.+]] = tensor.cast %[[INPUT]] : tensor<8xf32> to tensor<?xf32>
+// CHECK: scf.yield %[[CAST]] : tensor<?xf32>
+// CHECK: } else {
+// CHECK: %[[PAD:.+]] = tensor.pad %[[INPUT]]
+// CHECK: scf.yield %[[PAD]] : tensor<?xf32>
+// CHECK: }
+// CHECK: return %[[IF]] : tensor<?xf32>
+
+// -----
+
+// CHECK-LABEL: func @pad_non_zero_sizes
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<?x?x8xf32>, %[[LOW0:.+]]: index, %[[HIGH1:.+]]: index)
+func @pad_non_zero_sizes(%input: tensor<?x?x8xf32>, %low0: index, %high1: index) -> tensor<?x?x8xf32> {
+  %f0 = arith.constant 0.0 : f32
+  %0 = tensor.pad %input low[%low0, 0, 0] high[0, %high1, 0] {
+  ^bb0(%dim0: index, %dim1: index, %dim2: index):
+    tensor.yield %f0 : f32
+  } : tensor<?x?x8xf32> to tensor<?x?x8xf32>
+  return %0 : tensor<?x?x8xf32>
+}
+
+// CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[EQ0:.+]] = arith.cmpi eq, %[[LOW0]], %[[C0]] : index
+// CHECK: %[[EQ1:.+]] = arith.cmpi eq, %[[HIGH1]], %[[C0]] : index
+// CHECK: %[[AND:.+]] = arith.andi
%[[EQ0]], %[[EQ1]] : i1
+// CHECK: %[[IF:.+]] = scf.if %[[AND]] -> (tensor<?x?x8xf32>) {
+// CHECK: scf.yield %[[INPUT]] : tensor<?x?x8xf32>
+// CHECK: } else {
+// CHECK: %[[PAD:.+]] = tensor.pad %[[INPUT]] low[%[[LOW0]], 0, 0] high[0, %[[HIGH1]], 0] {
+// CHECK: ^bb0(%{{.+}}: index, %{{.+}}: index, %{{.+}}: index):
+// CHECK: tensor.yield %[[F0]] : f32
+// CHECK: } : tensor<?x?x8xf32> to tensor<?x?x8xf32>
+// CHECK: scf.yield %[[PAD]] : tensor<?x?x8xf32>
+// CHECK: }
+// CHECK: return %[[IF]] : tensor<?x?x8xf32>
diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt
--- a/mlir/test/lib/Dialect/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/CMakeLists.txt
@@ -8,6 +8,7 @@
 add_subdirectory(Shape)
 add_subdirectory(SPIRV)
 add_subdirectory(StandardOps)
+add_subdirectory(Tensor)
 add_subdirectory(Test)
 add_subdirectory(Tosa)
 add_subdirectory(Vector)
diff --git a/mlir/test/lib/Dialect/Tensor/CMakeLists.txt b/mlir/test/lib/Dialect/Tensor/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Dialect/Tensor/CMakeLists.txt
@@ -0,0 +1,13 @@
+# Exclude tests from libMLIR.so
+add_mlir_library(MLIRTensorTestPasses
+  TestTensorTransforms.cpp
+
+  EXCLUDE_FROM_LIBMLIR
+
+  LINK_LIBS PUBLIC
+  MLIRArithmetic
+  MLIRTensorTransforms
+  MLIRPass
+  MLIRSCF
+  MLIRTransformUtils
+  )
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -0,0 +1,70 @@
+//===- TestTensorTransforms.cpp - Test Tensor transformation patterns -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements logic for testing Tensor transformations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+/// Test pass that greedily applies the tensor transformation patterns selected
+/// by its options to each function.
+struct TestTensorTransforms
+    : public PassWrapper<TestTensorTransforms, OperationPass<FuncOp>> {
+  TestTensorTransforms() = default;
+  TestTensorTransforms(const TestTensorTransforms &pass) : PassWrapper(pass) {}
+
+  // The split-padding patterns create arith and scf ops.
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arith::ArithmeticDialect, scf::SCFDialect>();
+  }
+
+  StringRef getArgument() const final {
+    return "test-tensor-transform-patterns";
+  }
+  StringRef getDescription() const final {
+    return "Test Tensor transformation patterns by applying them greedily.";
+  }
+
+  void runOnOperation() override;
+
+  Option<bool> testSplitPaddingPatterns{
+      *this, "test-split-padding-patterns",
+      llvm::cl::desc("Test patterns to split tensor.pad ops"),
+      llvm::cl::init(false)};
+};
+} // namespace
+
+/// Greedily applies the tensor.pad splitting patterns to `funcOp`.
+static void applySplitPaddingPatterns(FuncOp funcOp) {
+  RewritePatternSet patterns(funcOp.getContext());
+  tensor::populateSplitPaddingPatterns(patterns);
+  // The convergence status is intentionally discarded.
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
+void TestTensorTransforms::runOnOperation() {
+  FuncOp func = getOperation();
+  if (testSplitPaddingPatterns)
+    applySplitPaddingPatterns(func);
+}
+
+namespace mlir {
+namespace test {
+void registerTestTensorTransforms() {
+  PassRegistration<TestTensorTransforms>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -23,7 +23,7 @@
     MLIRShapeTestPasses
     MLIRSPIRVTestPasses
     MLIRStandardOpsTestPasses
-    MLIRVectorTestPasses
+    MLIRTensorTestPasses
     MLIRTestAnalysis
     MLIRTestDialect
     MLIRTestIR
@@ -31,6 +31,7 @@
     MLIRTestReducer
     MLIRTestRewrite
     MLIRTestTransforms
+    MLIRVectorTestPasses
     )
 endif()
 
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -106,6 +106,7 @@
 void registerTestRecursiveTypesPass();
 void registerTestSCFUtilsPass();
 void registerTestSliceAnalysisPass();
+void registerTestTensorTransforms();
 void registerTestVectorLowerings();
 } // namespace test
 } // namespace mlir
@@ -194,6 +195,7 @@
   mlir::test::registerTestRecursiveTypesPass();
   mlir::test::registerTestSCFUtilsPass();
   mlir::test::registerTestSliceAnalysisPass();
+  mlir::test::registerTestTensorTransforms();
   mlir::test::registerTestVectorLowerings();
 }
 #endif
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -4645,6 +4645,7 @@
     hdrs = [
         "include/mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h",
         "include/mlir/Dialect/Tensor/Transforms/Passes.h",
+        "include/mlir/Dialect/Tensor/Transforms/Transforms.h",
     ],
     includes = ["include"],
     deps = [
@@ -4652,6 +4653,7 @@
         ":Async",
         ":BufferizationDialect",
         ":BufferizationTransforms",
+        ":DialectUtils",
         ":IR",
         ":MemRefDialect",
         ":ParallelLoopMapperAttrGen",