diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_LINALG_TRANSFORMOPS_LINALGTRANSFORMOPS_H
 #define MLIR_DIALECT_LINALG_TRANSFORMOPS_LINALGTRANSFORMOPS_H
 
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1706,4 +1706,47 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// HoistRedundantVectorTransfersOp
+//===----------------------------------------------------------------------===//
+
+// NOTE(review): the `Op<...>` template arguments below were reconstructed: the
+// mnemonic matches the op name used by the test in this patch, but the trait
+// list should be confirmed against the intended definition.
+def HoistRedundantVectorTransfersOp : Op<Transform_Dialect,
+    "structured.hoist_redundant_vector_transfers",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformEachOpTrait, TransformOpInterface]> {
+  let description = [{
+    Hoist vector.transfer_read / vector.transfer_write pairs out of immediately
+    enclosing scf::ForOp iteratively, if the following conditions are true:
+      1. The 2 ops access the same memref with the same indices.
+      2. All operands are invariant under the enclosing scf::ForOp.
+      3. No uses of the memref either dominate the transfer_read or are
+      dominated by the transfer_write (i.e. no aliasing between the write and
+      the read across the loop).
+
+    #### Return modes:
+
+    The operation always returns the handle to the transformed function op.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+
+  let assemblyFormat = "$target attr-dict `:` functional-type(operands, results) ";
+
+  // NOTE(review): no out-of-line definition for this builder is added by this
+  // patch; confirm one exists before relying on it.
+  let builders = [
+    OpBuilder<(ins "Value":$target)>,
+  ];
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::func::FuncOp target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 #endif // LINALG_TRANSFORM_OPS
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
@@ -10,6 +10,7 @@
   LINK_LIBS PUBLIC
   MLIRAffineDialect
   MLIRArithDialect
+  MLIRFuncDialect
   MLIRIR
   MLIRLinalgDialect
   MLIRLinalgTransforms
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/PDL/IR/PDL.h"
@@ -3058,6 +3059,21 @@
   return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);
 }
 
+//===----------------------------------------------------------------------===//
+// HoistRedundantVectorTransfersOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::HoistRedundantVectorTransfersOp::applyToOne(
+    func::FuncOp target, transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  // Run the hoisting utility on the payload function, then hand that same
+  // function op back as the result handle.
+  linalg::hoistRedundantVectorTransfers(target);
+  results.push_back(target);
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/transform-op-hoisting.mlir b/mlir/test/Dialect/Linalg/transform-op-hoisting.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-hoisting.mlir
@@ -0,0 +1,84 @@
+// RUN: mlir-opt -test-transform-dialect-interpreter --split-input-file --allow-unregistered-dialect %s | FileCheck %s
+
+// CHECK-LABEL: func @hoist_vector_transfer_pairs(
+//  CHECK-SAME:   %[[MEMREF0:[a-zA-Z0-9]*]]: memref<?x?xf32>,
+//  CHECK-SAME:   %[[MEMREF1:[a-zA-Z0-9]*]]: memref<?x?xf32>,
+//  CHECK-SAME:   %[[MEMREF2:[a-zA-Z0-9]*]]: memref<?x?xf32>,
+//  CHECK-SAME:   %[[MEMREF3:[a-zA-Z0-9]*]]: memref<?x?xf32>,
+//  CHECK-SAME:   %[[MEMREF4:[a-zA-Z0-9]*]]: memref<?x?xf32>,
+//  CHECK-SAME:   %[[MEMREF5:[a-zA-Z0-9]*]]: memref<?x?xf32>,
+//  CHECK-SAME:   %[[VAL:[a-zA-Z0-9]*]]: index,
+//  CHECK-SAME:   %[[LB:[a-zA-Z0-9]*]]: index,
+//  CHECK-SAME:   %[[UB:[a-zA-Z0-9]*]]: index,
+//  CHECK-SAME:   %[[STEP:[a-zA-Z0-9]*]]: index,
+//  CHECK-SAME:   %[[CMP:[a-zA-Z0-9]*]]: i1
+func.func @hoist_vector_transfer_pairs(
+    %memref0: memref<?x?xf32>, %memref1: memref<?x?xf32>, %memref2: memref<?x?xf32>,
+    %memref3: memref<?x?xf32>, %memref4: memref<?x?xf32>, %memref5: memref<?x?xf32>,
+    %val: index, %lb : index, %ub : index, %step: index, %cmp: i1) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.0 : f32
+
+// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<1xf32>
+// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>) {
+// CHECK:   vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<2xf32>
+// CHECK:   scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>, vector<2xf32>) {
+// CHECK:     vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<3xf32>
+// CHECK:     vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<4xf32>
+// CHECK:     "some_crippling_use"(%[[MEMREF4]]) : (memref<?x?xf32>) -> ()
+// CHECK:     vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<5xf32>
+// CHECK:     "some_use"(%{{.*}}) : (vector<1xf32>) -> vector<1xf32>
+// CHECK:     "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
+// CHECK:     "some_use"(%[[MEMREF2]]) : (memref<?x?xf32>) -> vector<3xf32>
+// CHECK:     "some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
+// CHECK:     "some_use"(%{{.*}}) : (vector<5xf32>) -> vector<5xf32>
+// CHECK:     vector.transfer_write %{{.*}} : vector<3xf32>, memref<?x?xf32>
+// CHECK:     vector.transfer_write %{{.*}} : vector<4xf32>, memref<?x?xf32>
+// CHECK:     vector.transfer_write %{{.*}} : vector<5xf32>, memref<?x?xf32>
+// CHECK:     "some_crippling_use"(%[[MEMREF3]]) : (memref<?x?xf32>) -> ()
+// CHECK:     scf.yield {{.*}} : vector<1xf32>, vector<2xf32>
+// CHECK:   }
+// CHECK:   vector.transfer_write %{{.*}} : vector<2xf32>, memref<?x?xf32>
+// CHECK:   "unrelated_use"(%[[MEMREF0]]) : (memref<?x?xf32>) -> ()
+// CHECK:   scf.yield {{.*}} : vector<1xf32>
+// CHECK: }
+// CHECK: vector.transfer_write %{{.*}} : vector<1xf32>, memref<?x?xf32>
+// CHECK: "unrelated_use"(%[[MEMREF1]]) : (memref<?x?xf32>) -> ()
+  scf.for %i = %lb to %ub step %step {
+    scf.for %j = %lb to %ub step %step {
+      %r0 = vector.transfer_read %memref1[%c0, %c0], %cst: memref<?x?xf32>, vector<1xf32>
+      %r1 = vector.transfer_read %memref0[%i, %i], %cst: memref<?x?xf32>, vector<2xf32>
+      %r2 = vector.transfer_read %memref2[%c0, %c0], %cst: memref<?x?xf32>, vector<3xf32>
+      %r3 = vector.transfer_read %memref3[%c0, %c0], %cst: memref<?x?xf32>, vector<4xf32>
+      "some_crippling_use"(%memref4) : (memref<?x?xf32>) -> ()
+      %r4 = vector.transfer_read %memref4[%c0, %c0], %cst: memref<?x?xf32>, vector<5xf32>
+      %r5 = vector.transfer_read %memref5[%c0, %c0], %cst: memref<?x?xf32>, vector<6xf32>
+      "some_crippling_use"(%memref5) : (memref<?x?xf32>) -> ()
+      %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+      %u1 = "some_use"(%r1) : (vector<2xf32>) -> vector<2xf32>
+      %u2 = "some_use"(%memref2) : (memref<?x?xf32>) -> vector<3xf32>
+      %u3 = "some_use"(%r3) : (vector<4xf32>) -> vector<4xf32>
+      %u4 = "some_use"(%r4) : (vector<5xf32>) -> vector<5xf32>
+      %u5 = "some_use"(%r5) : (vector<6xf32>) -> vector<6xf32>
+      vector.transfer_write %u0, %memref1[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+      vector.transfer_write %u1, %memref0[%i, %i] : vector<2xf32>, memref<?x?xf32>
+      vector.transfer_write %u2, %memref2[%c0, %c0] : vector<3xf32>, memref<?x?xf32>
+      vector.transfer_write %u3, %memref3[%c0, %c0] : vector<4xf32>, memref<?x?xf32>
+      vector.transfer_write %u4, %memref4[%c0, %c0] : vector<5xf32>, memref<?x?xf32>
+      vector.transfer_write %u5, %memref5[%c0, %c0] : vector<6xf32>, memref<?x?xf32>
+      "some_crippling_use"(%memref3) : (memref<?x?xf32>) -> ()
+    }
+    "unrelated_use"(%memref0) : (memref<?x?xf32>) -> ()
+  }
+  "unrelated_use"(%memref1) : (memref<?x?xf32>) -> ()
+  return
+}
+
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["func.func"]} in %arg1
+    : (!pdl.operation) -> !pdl.operation
+  transform.structured.hoist_redundant_vector_transfers %0
+    : (!pdl.operation) -> !pdl.operation
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -8350,6 +8350,7 @@
         ":AsmParser",
         ":ControlFlowDialect",
         ":DialectUtils",
+        ":FuncDialect",
         ":GPUDialect",
         ":IR",
         ":LinalgDialect",