diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -57,6 +57,9 @@
 std::unique_ptr<OperationPass<ModuleOp>>
 createConvertLinalgOnTensorsToBuffersPass();
 
+/// Create a pass that removes unnecessary Linalg Copy operations.
+std::unique_ptr<Pass> createLinalgCopyRemovalPass();
+
 /// Patterns for fusing linalg operation on tensors.
 void populateLinalgTensorOpsFusionPatterns(MLIRContext *context,
                                            OwningRewritePatternList &patterns);
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -11,6 +11,11 @@
 
 include "mlir/Pass/PassBase.td"
 
+def LinalgCopyRemoval : FunctionPass<"linalg-copy-removal"> {
+  let summary = "Remove the redundant copies from the input IR";
+  let constructor = "mlir::createLinalgCopyRemovalPass()";
+}
+
 def LinalgFoldUnitExtentDims : FunctionPass<"linalg-fold-unit-extent-dims"> {
   let summary = "Remove unit-extent dimension in Linalg ops on tensors";
   let constructor = "mlir::createLinalgFoldUnitExtentDimsPass()";
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRLinalgTransforms
+  CopyRemoval.cpp
   DropUnitDims.cpp
   Fusion.cpp
   Hoisting.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CopyRemoval.cpp b/mlir/lib/Dialect/Linalg/Transforms/CopyRemoval.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/CopyRemoval.cpp
@@ -0,0 +1,164 @@
+//===- CopyRemoval.cpp - Removing the redundant Linalg copies -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the pass for removing redundant Linalg.CopyOp.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+using namespace MemoryEffects;
+
+//===----------------------------------------------------------------------===//
+// CopyRemovalPass
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This pass removes a redundant `to` buffer together with the copy operation
+/// that fills it, and replaces all uses of `to` with the `from` value.
+///
+/// Input:
+///   %from = ...
+///   ...
+///   %to = alloc ...
+///   ...
+///   use(%from)
+///   ...
+///   copy(%from, %to)
+///   ...
+///   use(%to)
+///   ...
+///   dealloc(%from)
+///   ...
+///   use(%to)
+///
+/// Output:
+///   %from = ...
+///   ...
+///   use(%from)
+///   ...
+///   use(%from)
+///   ...
+///   use(%from)
+///
+/// Limitations:
+/// 1) The defining operations of `from` and `to`, copy(%from, %to), and
+///    dealloc(%from) must all exist and be in the same block, so that the
+///    removal is safe without considering the CFG.
+/// 2) `to` must be produced by an allocating operation and have the same type
+///    as `from`, and the copy must not carry a permutation.
+/// 3) The transformation is not applied if `to` has any user between its
+///    defining operation and the copy operation.
+/// 4) The transformation is not applied if `from` has any user between the
+///    copy operation and the deallocation of `from`.
+// NOTE(review): the pass anchor is restored as OperationPass<> (any operation);
+// confirm it against the FunctionPass declaration in Passes.td.
+class CopyRemovalPass : public PassWrapper<CopyRemovalPass, OperationPass<>> {
+public:
+  void runOnOperation() override {
+    // Collect the copies up front: rewriting erases operations, which must not
+    // happen underneath the walk.
+    SmallVector<linalg::CopyOp, 8> copies;
+    getOperation()->walk(
+        [&](linalg::CopyOp copyOp) { copies.push_back(copyOp); });
+    for (linalg::CopyOp copyOp : copies)
+      removeCopyIfRedundant(copyOp);
+  }
+
+private:
+  /// Returns true if `val` has at least one user strictly between `start` and
+  /// `end`, which must be in the same block. A user nested in a region is
+  /// ordered through its ancestor in that block; a user that cannot be ordered
+  /// is conservatively reported as being in between.
+  static bool hasUsersBetween(Value val, Operation *start, Operation *end) {
+    Block *block = start->getBlock();
+    return llvm::any_of(val.getUsers(), [&](Operation *user) {
+      Operation *ancestor = block->findAncestorOpInBlock(*user);
+      if (!ancestor)
+        return true;
+      return start->isBeforeInBlock(ancestor) &&
+             ancestor->isBeforeInBlock(end);
+    });
+  }
+
+  /// Removes `copyOp`, the allocation of its target and the deallocation of
+  /// its source when the target can be replaced by the source.
+  void removeCopyIfRedundant(linalg::CopyOp copyOp) {
+    Operation *copy = copyOp.getOperation();
+    Value from = copyOp.input();
+    Value to = copyOp.output();
+
+    // Only a plain copy between identically typed buffers can be replaced by
+    // its source.
+    // NOTE(review): attribute names assumed from the CopyOp definition;
+    // confirm.
+    if (from.getType() != to.getType() ||
+        copy->getAttr("inputPermutation") ||
+        copy->getAttr("outputPermutation"))
+      return;
+
+    // Both values must be operation results.
+    Operation *fromDefiningOp = from.getDefiningOp();
+    Operation *toDefiningOp = to.getDefiningOp();
+    if (!fromDefiningOp || !toDefiningOp)
+      return;
+
+    // `to` must be a fresh allocation: its producer gets erased, which would
+    // change the program for any other kind of producer.
+    auto toEffects = dyn_cast<MemoryEffectOpInterface>(toDefiningOp);
+    if (!toEffects || !toEffects.hasEffect<Allocate>())
+      return;
+
+    // Find the deallocation operation of the `from` value.
+    auto fromUsers = from.getUsers();
+    auto it = llvm::find_if(fromUsers, [](Operation *user) {
+      auto effects = dyn_cast<MemoryEffectOpInterface>(user);
+      return effects && effects.hasEffect<Free>();
+    });
+    if (it == fromUsers.end())
+      return;
+    Operation *fromFreeingOp = *it;
+
+    // Check that all involved operations are in the same block and that the
+    // deallocation follows the copy.
+    Block *block = copy->getBlock();
+    if (fromDefiningOp->getBlock() != block ||
+        toDefiningOp->getBlock() != block ||
+        fromFreeingOp->getBlock() != block ||
+        !copy->isBeforeInBlock(fromFreeingOp))
+      return;
+
+    // It is illegal to remove the copy if `to` has a user between its
+    // defining operation and the copy, or if `from` has a user between the
+    // copy and its deallocation.
+    if (hasUsersBetween(to, toDefiningOp, copy) ||
+        hasUsersBetween(from, copy, fromFreeingOp))
+      return;
+
+    // The copy can be safely removed and the uses of `to` can be replaced
+    // with `from`.
+    to.replaceAllUsesWith(from);
+    copy->erase();
+    fromFreeingOp->erase();
+    if (toDefiningOp->use_empty())
+      toDefiningOp->erase();
+  }
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// CopyRemovalPass construction
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<Pass> mlir::createLinalgCopyRemovalPass() {
+  return std::make_unique<CopyRemovalPass>();
+}
diff --git a/mlir/test/Dialect/Linalg/copy-removal.mlir b/mlir/test/Dialect/Linalg/copy-removal.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/copy-removal.mlir
@@ -0,0 +1,224 @@
+// RUN: mlir-opt -linalg-copy-removal -split-input-file %s | FileCheck %s
+
+// All linalg copies except the linalg.copy(%1, %9) must be removed since the
+// defining operation of %1 and its DeallocOp have been defined in another block.
+
+// CHECK-LABEL: func @nested_region_control_flow_div_nested
+func @nested_region_control_flow_div_nested(%arg0: index, %arg1: index) -> memref<?x?xf32> {
+  %0 = cmpi "eq", %arg0, %arg1 : index
+  %1 = alloc(%arg0, %arg0) : memref<?x?xf32>
+  // CHECK: %{{.*}} = scf.if
+  %2 = scf.if %0 -> (memref<?x?xf32>) {
+    // CHECK: %[[PERCENT3:.*]] = scf.if
+    %3 = scf.if %0 -> (memref<?x?xf32>) {
+      %c0_0 = constant 0 : index
+      %7 = dim %1, %c0_0 : memref<?x?xf32>
+      %c1_1 = constant 1 : index
+      %8 = dim %1, %c1_1 : memref<?x?xf32>
+      %9 = alloc(%7, %8) : memref<?x?xf32>
+      // CHECK: linalg.copy({{.*}}, %[[PERCENT9:.*]])
+      linalg.copy(%1, %9) : memref<?x?xf32>, memref<?x?xf32>
+      // CHECK: scf.yield %[[PERCENT9]]
+      scf.yield %9 : memref<?x?xf32>
+    } else {
+      // CHECK: %[[PERCENT7:.*]] = alloc
+      %7 = alloc(%arg0, %arg1) : memref<?x?xf32>
+      %c0_0 = constant 0 : index
+      %8 = dim %7, %c0_0 : memref<?x?xf32>
+      %c1_1 = constant 1 : index
+      %9 = dim %7, %c1_1 : memref<?x?xf32>
+      // CHECK-NOT: %{{.*}} = alloc
+      // CHECK-NOT: linalg.copy(%[[PERCENT7]], %{{.*}})
+      // CHECK-NOT: dealloc %[[PERCENT7]]
+      %10 = alloc(%8, %9) : memref<?x?xf32>
+      linalg.copy(%7, %10) : memref<?x?xf32>, memref<?x?xf32>
+      dealloc %7 : memref<?x?xf32>
+      // CHECK: scf.yield %[[PERCENT7]]
+      scf.yield %10 : memref<?x?xf32>
+    }
+    %c0 = constant 0 : index
+    %4 = dim %3, %c0 : memref<?x?xf32>
+    %c1 = constant 1 : index
+    %5 = dim %3, %c1 : memref<?x?xf32>
+    // CHECK-NOT: %{{.*}} = alloc
+    // CHECK-NOT: linalg.copy(%[[PERCENT3]], %{{.*}})
+    // CHECK-NOT: dealloc %[[PERCENT3]]
+    %6 = alloc(%4, %5) : memref<?x?xf32>
+    linalg.copy(%3, %6) : memref<?x?xf32>, memref<?x?xf32>
+    dealloc %3 : memref<?x?xf32>
+    // CHECK: scf.yield %[[PERCENT3]]
+    scf.yield %6 : memref<?x?xf32>
+  } else {
+    // CHECK: %[[PERCENT3:.*]] = alloc
+    %3 = alloc(%arg1, %arg1) : memref<?x?xf32>
+    %c0 = constant 0 : index
+    %4 = dim %3, %c0 : memref<?x?xf32>
+    %c1 = constant 1 : index
+    %5 = dim %3, %c1 : memref<?x?xf32>
+    // CHECK-NOT: %{{.*}} = alloc
+    // CHECK-NOT: linalg.copy(%[[PERCENT3]], %{{.*}})
+    // CHECK-NOT: dealloc %[[PERCENT3]]
+    %6 = alloc(%4, %5) : memref<?x?xf32>
+    linalg.copy(%3, %6) : memref<?x?xf32>, memref<?x?xf32>
+    dealloc %3 : memref<?x?xf32>
+    // CHECK: scf.yield %[[PERCENT3]]
+    scf.yield %6 : memref<?x?xf32>
+  }
+  dealloc %1 : memref<?x?xf32>
+  return %2 : memref<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @simple_test
+func @simple_test() -> memref<5xf32> {
+  %temp = alloc() : memref<5xf32>
+  %ret = alloc() : memref<5xf32>
+  linalg.copy(%ret, %temp) : memref<5xf32>, memref<5xf32>
+  dealloc %ret : memref<5xf32>
+  return %temp : memref<5xf32>
+}
+// CHECK: () -> memref<5xf32>
+// CHECK: %[[ret:.*]] = alloc()
+// CHECK: return %[[ret]]
+
+// -----
+
+// It is legal to remove the copy operation when %ret has a usage before the copy
+// operation. The allocation of %temp and the deallocation of %ret should also be
+// removed.
+
+// CHECK-LABEL: func @test_with_ret_usage_before_copy
+func @test_with_ret_usage_before_copy() -> memref<5xf32> {
+  %ret = alloc() : memref<5xf32>
+  %temp = alloc() : memref<5xf32>
+  %c0 = constant 0 : index
+  %dimension = dim %ret, %c0 : memref<5xf32>
+  linalg.copy(%ret, %temp) : memref<5xf32>, memref<5xf32>
+  dealloc %ret : memref<5xf32>
+  return %temp : memref<5xf32>
+}
+// CHECK: %[[ret:.*]] = alloc()
+// CHECK: %{{.*}} = constant
+// CHECK: %[[DIM:.*]] = dim %[[ret]]
+// CHECK: return %[[ret]]
+
+// -----
+
+// It is illegal to remove a copy operation when %ret has a usage after the copy
+// operation.
+
+// CHECK-LABEL: func @test_with_ret_usage_after_copy
+func @test_with_ret_usage_after_copy() -> memref<5xf32> {
+  %ret = alloc() : memref<5xf32>
+  %temp = alloc() : memref<5xf32>
+  // CHECK: linalg.copy
+  linalg.copy(%ret, %temp) : memref<5xf32>, memref<5xf32>
+  %c0 = constant 0 : index
+  %dimension = dim %ret, %c0 : memref<5xf32>
+  dealloc %ret : memref<5xf32>
+  return %temp : memref<5xf32>
+}
+
+// -----
+
+// It is illegal to remove a copy operation when %temp has a usage before the
+// copy operation.
+
+// CHECK-LABEL: func @test_with_temp_usage_before_copy
+func @test_with_temp_usage_before_copy() -> memref<5xf32> {
+  %ret = alloc() : memref<5xf32>
+  %temp = alloc() : memref<5xf32>
+  %c0 = constant 0 : index
+  %dimension = dim %temp, %c0 : memref<5xf32>
+  // CHECK: linalg.copy
+  linalg.copy(%ret, %temp) : memref<5xf32>, memref<5xf32>
+  dealloc %ret : memref<5xf32>
+  return %temp : memref<5xf32>
+}
+
+// -----
+
+// It is legal to remove the copy operation when %temp has a usage after the copy
+// operation. The allocation of %temp and the deallocation of %ret should also be
+// removed.
+
+// CHECK-LABEL: func @test_with_temp_usage_after_copy
+func @test_with_temp_usage_after_copy() -> memref<5xf32> {
+  %ret = alloc() : memref<5xf32>
+  %temp = alloc() : memref<5xf32>
+  linalg.copy(%ret, %temp) : memref<5xf32>, memref<5xf32>
+  %c0 = constant 0 : index
+  %dimension = dim %temp, %c0 : memref<5xf32>
+  dealloc %ret : memref<5xf32>
+  return %temp : memref<5xf32>
+}
+// CHECK: %[[ret:.*]] = alloc()
+// CHECK-NOT: %{{.*}} = alloc()
+// CHECK-NOT: linalg.copy
+// CHECK: %{{.*}} = constant
+// CHECK: %[[DIM:.*]] = dim %[[ret]]
+// CHECK-NOT: dealloc %[[ret]]
+// CHECK: return %[[ret]]
+
+// -----
+
+// CHECK-LABEL: func @make_allocation
+func @make_allocation() -> memref<5xf32> {
+  %mem = alloc() : memref<5xf32>
+  return %mem : memref<5xf32>
+}
+
+// CHECK-LABEL: func @test_with_function_call
+func @test_with_function_call() -> memref<5xf32> {
+  // CHECK: %[[ret:.*]] = call @make_allocation() : () -> memref<5xf32>
+  %ret = call @make_allocation() : () -> (memref<5xf32>)
+  // CHECK-NOT: %{{.*}} = alloc
+  // CHECK-NOT: linalg.copy(%[[ret]], %{{.*}})
+  // CHECK-NOT: dealloc %[[ret]]
+  %temp = alloc() : memref<5xf32>
+  linalg.copy(%ret, %temp) : memref<5xf32>, memref<5xf32>
+  dealloc %ret : memref<5xf32>
+  // CHECK: return %[[ret]]
+  return %temp : memref<5xf32>
+}
+
+// -----
+
+// It is illegal to remove a copy operation when %ret has a usage nested in a
+// region between the copy operation and the deallocation of %ret.
+
+// CHECK-LABEL: func @test_with_nested_ret_usage_after_copy
+func @test_with_nested_ret_usage_after_copy(%cond: i1) -> memref<5xf32> {
+  %ret = alloc() : memref<5xf32>
+  %temp = alloc() : memref<5xf32>
+  // CHECK: linalg.copy
+  linalg.copy(%ret, %temp) : memref<5xf32>, memref<5xf32>
+  scf.if %cond {
+    %c0 = constant 0 : index
+    %dimension = dim %ret, %c0 : memref<5xf32>
+  }
+  dealloc %ret : memref<5xf32>
+  return %temp : memref<5xf32>
+}
+
+// -----
+
+// Chained copies must all be removed: once %b is replaced by %a, the second
+// copy reads from %a and is redundant as well.
+
+// CHECK-LABEL: func @test_with_chained_copies
+func @test_with_chained_copies() -> memref<5xf32> {
+  %a = alloc() : memref<5xf32>
+  %b = alloc() : memref<5xf32>
+  %c = alloc() : memref<5xf32>
+  linalg.copy(%a, %b) : memref<5xf32>, memref<5xf32>
+  dealloc %a : memref<5xf32>
+  linalg.copy(%b, %c) : memref<5xf32>, memref<5xf32>
+  dealloc %b : memref<5xf32>
+  return %c : memref<5xf32>
+}
+// CHECK: %[[a:.*]] = alloc()
+// CHECK-NOT: linalg.copy
+// CHECK-NOT: dealloc
+// CHECK: return %[[a]]