diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_LINALG_TRANSFORMOPS_LINALGTRANSFORMOPS_H #define MLIR_DIALECT_LINALG_TRANSFORMOPS_LINALGTRANSFORMOPS_H +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -1706,4 +1706,44 @@ }]; } +//===----------------------------------------------------------------------===// +// HoistRedundantVectorTransfersOp +//===----------------------------------------------------------------------===// + +def HoistRedundantVectorTransfersOp : + Op<Transform_Dialect, "structured.hoist_redundant_vector_transfers", + [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface, + TransformEachOpTrait, TransformOpInterface]> { + let description = [{ + Hoist vector.transfer_read / vector.transfer_write pairs out of immediately + enclosing scf::ForOp iteratively, if the following conditions are true: + 1. The 2 ops access the same memref with the same indices. + 2. All operands are invariant under the enclosing scf::ForOp. + 3. No uses of the memref either dominate the transfer_read or are + dominated by the transfer_write (i.e. no aliasing between the write and + the read across the loop). + + #### Return modes: + + The `target` handle is expected to be associated with `func.func` ops. + The operation always succeeds on such ops and returns a handle to the + transformed function op. 
+ }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs TransformHandleTypeInterface:$transformed); + + let assemblyFormat = "$target attr-dict `:` functional-type(operands, results) "; + + let builders = [ + OpBuilder<(ins "Value":$target)>, + ]; + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::func::FuncOp target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + #endif // LINALG_TRANSFORM_OPS diff --git a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt @@ -10,6 +10,7 @@ LINK_LIBS PUBLIC MLIRAffineDialect MLIRArithDialect + MLIRFuncDialect MLIRIR MLIRLinalgDialect MLIRLinalgTransforms diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Hoisting.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/PDL/IR/PDL.h" @@ -3058,6 +3059,31 @@ return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b); } +//===----------------------------------------------------------------------===// +// HoistRedundantVectorTransfersOp +//===----------------------------------------------------------------------===// + +void transform::HoistRedundantVectorTransfersOp::build(OpBuilder &builder, + OperationState &result, + Value target) { + // The result handle reuses the type of the target handle. 
+ build(builder, result, target.getType(), target); +} + +DiagnosedSilenceableFailure +transform::HoistRedundantVectorTransfersOp::applyToOne( + func::FuncOp target, transform::ApplyToEachResultList &results, + transform::TransformState &state) { + // Run the memref-based hoisting first, then its tensor-based counterpart, + // both rooted at `target`.
linalg::hoistRedundantVectorTransfers(target); + linalg::hoistRedundantVectorTransfersOnTensor(target); + // `target` is rewritten in place, so the result handle maps to the same + // function op. + results.push_back(target); + return DiagnosedSilenceableFailure::success(); +} + //===----------------------------------------------------------------------===// // Transform op registration //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir --- a/mlir/test/Dialect/Linalg/hoisting.mlir +++ b/mlir/test/Dialect/Linalg/hoisting.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -test-linalg-hoisting=test-hoist-redundant-transfers -allow-unregistered-dialect -split-input-file | FileCheck %s +// RUN: mlir-opt -test-transform-dialect-interpreter --split-input-file --allow-unregistered-dialect %s | FileCheck %s // CHECK-LABEL: func @hoist_vector_transfer_pairs( // CHECK-SAME: %[[MEMREF0:[a-zA-Z0-9]*]]: memref, @@ -74,6 +74,14 @@ return } +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["func.func"]} in %arg1 + : (!pdl.operation) -> !pdl.operation + transform.structured.hoist_redundant_vector_transfers %0 + : (!pdl.operation) -> !pdl.operation +} + // ----- // CHECK-LABEL: func @hoist_vector_transfer_pairs_disjoint( @@ -155,6 +163,14 @@ return } +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["func.func"]} in %arg1 + : (!pdl.operation) -> !pdl.operation + transform.structured.hoist_redundant_vector_transfers %0 + : (!pdl.operation) -> !pdl.operation +} + // ----- // CHECK-LABEL: func @hoist_vector_transfer_pairs_tensor @@ -236,6 +252,14 @@ tensor, tensor } +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["func.func"]} in %arg1 + : (!pdl.operation) -> !pdl.operation + transform.structured.hoist_redundant_vector_transfers %0 + : (!pdl.operation) -> !pdl.operation +} + // ----- 
// CHECK-LABEL: func @hoist_vector_transfer_pairs_disjoint_tensor( @@ -323,6 +347,14 @@ return %0#0, %0#1, %0#2, %0#3 : tensor, tensor, tensor, tensor } +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["func.func"]} in %arg1 + : (!pdl.operation) -> !pdl.operation + transform.structured.hoist_redundant_vector_transfers %0 + : (!pdl.operation) -> !pdl.operation +} + // ----- // CHECK-LABEL: func @hoist_vector_transfer_pairs_tensor_and_slices @@ -432,6 +464,14 @@ return %0#0, %0#1, %0#2 : tensor, tensor, tensor } +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["func.func"]} in %arg1 + : (!pdl.operation) -> !pdl.operation + transform.structured.hoist_redundant_vector_transfers %0 + : (!pdl.operation) -> !pdl.operation +} + // ----- // CHECK-LABEL: func @hoist_vector_transfer_write_pairs_disjoint_tensor( @@ -469,6 +509,14 @@ return %1 : tensor } +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["func.func"]} in %arg1 + : (!pdl.operation) -> !pdl.operation + transform.structured.hoist_redundant_vector_transfers %0 + : (!pdl.operation) -> !pdl.operation +} + // ----- // CHECK-LABEL: func @hoist_vector_transfer_pairs_in_affine_loops( @@ -505,3 +553,11 @@ } return } + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["func.func"]} in %arg1 + : (!pdl.operation) -> !pdl.operation + transform.structured.hoist_redundant_vector_transfers %0 + : (!pdl.operation) -> !pdl.operation +} diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt --- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt @@ -4,7 +4,6 @@ TestLinalgDecomposeOps.cpp TestLinalgElementwiseFusion.cpp TestLinalgFusionTransforms.cpp - TestLinalgHoisting.cpp TestLinalgTransforms.cpp 
TestPadFusion.cpp diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgHoisting.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgHoisting.cpp deleted file mode 100644 --- a/mlir/test/lib/Dialect/Linalg/TestLinalgHoisting.cpp +++ /dev/null @@ -1,58 +0,0 @@ -//===- TestLinalgHoisting.cpp - Test Linalg hoisting functions ------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements logic for testing Linalg hoisting functions. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Transforms/Hoisting.h" -#include "mlir/Pass/Pass.h" - -using namespace mlir; -using namespace mlir::linalg; - -namespace { -struct TestLinalgHoisting - : public PassWrapper<TestLinalgHoisting, OperationPass<func::FuncOp>> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgHoisting) - - TestLinalgHoisting() = default; - TestLinalgHoisting(const TestLinalgHoisting &pass) : PassWrapper(pass) {} - void getDependentDialects(DialectRegistry &registry) const override { - registry.insert(); - } - StringRef getArgument() const final { return "test-linalg-hoisting"; } - StringRef getDescription() const final { - return "Test Linalg hoisting functions."; - } - - void runOnOperation() override; - - Option<bool> testHoistRedundantTransfers{ - *this, "test-hoist-redundant-transfers", - llvm::cl::desc("Test hoisting transfer_read/transfer_write pairs"), - llvm::cl::init(false)}; -}; -} // namespace - -void TestLinalgHoisting::runOnOperation() { - if (testHoistRedundantTransfers) { - hoistRedundantVectorTransfers(getOperation()); - hoistRedundantVectorTransfersOnTensor(getOperation()); - 
return; - } -} - -namespace mlir { -namespace test { -void registerTestLinalgHoisting() { PassRegistration<TestLinalgHoisting>(); } -} // namespace test -} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -95,7 +95,6 @@ void registerTestLinalgDecomposeOps(); void registerTestLinalgElementwiseFusion(); void registerTestLinalgGreedyFusion(); -void registerTestLinalgHoisting(); void registerTestLinalgTransforms(); void registerTestLivenessPass(); void registerTestLoopFusion(); @@ -205,7 +204,6 @@ mlir::test::registerTestLinalgDecomposeOps(); mlir::test::registerTestLinalgElementwiseFusion(); mlir::test::registerTestLinalgGreedyFusion(); - mlir::test::registerTestLinalgHoisting(); mlir::test::registerTestLinalgTransforms(); mlir::test::registerTestLivenessPass(); mlir::test::registerTestLoopFusion(); diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -8350,6 +8350,7 @@ ":AsmParser", ":ControlFlowDialect", ":DialectUtils", + ":FuncDialect", ":GPUDialect", ":IR", ":LinalgDialect",