diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -19,6 +19,7 @@ #include "mlir/Transforms/ViewOpGraph.h" #include "llvm/Support/Debug.h" #include +#include namespace mlir { @@ -105,6 +106,9 @@ createInlinerPass(llvm::StringMap opPipelines, std::function defaultPipelineBuilder); +/// Creates an optimization pass to remove dead values. +std::unique_ptr createRemoveDeadValuesPass(); + /// Creates a pass which performs sparse conditional constant propagation over /// nested operations. std::unique_ptr createSCCPPass(); diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -85,6 +85,163 @@ ]; } +def RemoveDeadValues : Pass<"remove-dead-values"> { + let summary = "Remove dead values"; + let description = [{ + The goal of this pass is optimization (reducing runtime) by removing + unnecessary instructions. Unlike other passes that rely on local information + gathered from patterns to accomplish optimization, this pass uses a full + analysis of the IR, specifically, liveness analysis, and is thus more + powerful. +  + Currently, this pass performs the following optimizations: + (A) Removes function arguments that are not live, + (B) Removes function return values that are not live across all callers of + the function, + (C) Removes unnecessary operands, results, region arguments, and region + terminator operands of region branch ops, and, + (D) Removes simple and region branch ops that have all non-live results and + don't affect memory in any way, +  + iff +  + the IR doesn't have any non-function symbol ops, non-call symbol user ops + and branch ops. +  + Here, a "simple op" refers to an op that isn't a symbol op, symbol-user op, + region branch op, branch op, region branch terminator op, or return-like.
+  + It is noteworthy that we do not refer to non-live values as "dead" in this + file to avoid confusing it with dead code analysis's "dead", which refers to + unreachable code (code that never executes on hardware) while "non-live" + refers to code that executes on hardware but is unnecessary. Thus, while the + removal of dead code helps little in reducing runtime, removing non-live + values should theoretically have significant impact (depending on the amount + removed). +  + It is also important to note that unlike other passes (like `canonicalize`) + that apply op-specific optimizations through patterns, this pass uses + different interfaces to handle various types of ops and tries to cover all + existing ops through these interfaces. +  + It is this reliance on (a) liveness analysis and (b) interfaces that makes + the pass powerful enough to optimize ops that don't have a canonicalizer + and, even when an op does have a canonicalizer, to perform more aggressive + optimizations, as observed in the test files associated with this pass. +  + Example of optimization (A):- +  + ``` + int add_2_to_y(int x, int y) { + return 2 + y + } +  + print(add_2_to_y(3, 4)) + print(add_2_to_y(5, 6)) + ``` +  + becomes +  + ``` + int add_2_to_y(int y) { + return 2 + y + } +  + print(add_2_to_y(4)) + print(add_2_to_y(6)) + ``` +  + Example of optimization (B):- +  + ``` + int, int get_incremented_values(int y) { + store y somewhere in memory + return y + 1, y + 2 + } +  + y1, y2 = get_incremented_values(4) + y3, y4 = get_incremented_values(6) + print(y2) + ``` +  + becomes +  + ``` + int get_incremented_values(int y) { + store y somewhere in memory + return y + 2 + } +  + y2 = get_incremented_values(4) + y4 = get_incremented_values(6) + print(y2) + ``` +  + Example of optimization (C):- +  + Assume only `%result1` is live here.
Then, + + ``` + %result1, %result2, %result3 = scf.while (%arg1 = %operand1, %arg2 = %operand2) { + %terminator_operand2 = add %arg2, %arg2 + %terminator_operand3 = mul %arg2, %arg2 + %terminator_operand4 = add %arg1, %arg1 + scf.condition(%terminator_operand1) %terminator_operand2, %terminator_operand3, %terminator_operand4 + } do { + ^bb0(%arg3, %arg4, %arg5): + %terminator_operand6 = add %arg4, %arg4 + %terminator_operand5 = add %arg5, %arg5 + scf.yield %terminator_operand5, %terminator_operand6 + } + ``` + + becomes + + ``` + %result1, %result2 = scf.while (%arg2 = %operand2) { + %terminator_operand2 = add %arg2, %arg2 + %terminator_operand3 = mul %arg2, %arg2 + scf.condition(%terminator_operand1) %terminator_operand2, %terminator_operand3 + } do { + ^bb0(%arg3, %arg4): + %terminator_operand6 = add %arg4, %arg4 + scf.yield %terminator_operand6 + } + ``` + + It is interesting to see that `%result2` won't be removed even though it is + not live because `%terminator_operand3` forwards to it and cannot be + removed. And, that is because it also forwards to `%arg4`, which is live. 
+ + Example of optimization (D):- + + ``` + int square_and_double_of_y(int y) { + square = y ^ 2 + double = y * 2 + return square, double + } + + sq, do = square_and_double_of_y(5) + print(do) + ``` + + becomes + + ``` + int square_and_double_of_y(int y) { + double = y * 2 + return double + } + + do = square_and_double_of_y(5) + print(do) + ``` + }]; + let constructor = "mlir::createRemoveDeadValuesPass()"; +} + def PrintIRPass : Pass<"print-ir"> { let summary = "Print IR on the debug stream"; let description = [{ diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt --- a/mlir/lib/Transforms/CMakeLists.txt +++ b/mlir/lib/Transforms/CMakeLists.txt @@ -11,6 +11,7 @@ Mem2Reg.cpp OpStats.cpp PrintIR.cpp + RemoveDeadValues.cpp SCCP.cpp SROA.cpp StripDebugInfo.cpp diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Transforms/RemoveDeadValues.cpp @@ -0,0 +1,619 @@ +//===- RemoveDeadValues.cpp - Remove Dead Values --------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// The goal of this pass is optimization (reducing runtime) by removing +// unnecessary instructions. Unlike other passes that rely on local information +// gathered from patterns to accomplish optimization, this pass uses a full +// analysis of the IR, specifically, liveness analysis, and is thus more +// powerful. 
+// +// Currently, this pass performs the following optimizations: +// (A) Removes function arguments that are not live, +// (B) Removes function return values that are not live across all callers of +// the function, +// (C) Removes unnecessary operands, results, region arguments, and region +// terminator operands of region branch ops, and, +// (D) Removes simple and region branch ops that have all non-live results and +// don't affect memory in any way, +// +// iff +// +// the IR doesn't have any non-function symbol ops, non-call symbol user ops and +// branch ops. +// +// Here, a "simple op" refers to an op that isn't a symbol op, symbol-user op, +// region branch op, branch op, region branch terminator op, or return-like. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" +#include "mlir/Analysis/DataFlow/LivenessAnalysis.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/FunctionInterfaces.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/FoldUtils.h" +#include "mlir/Transforms/Passes.h" +#include "llvm/ADT/STLExtras.h" +#include +#include +#include +#include +#include + +namespace mlir { +#define GEN_PASS_DEF_REMOVEDEADVALUES +#include "mlir/Transforms/Passes.h.inc" +} // namespace mlir + +using namespace mlir; +using namespace mlir::dataflow; + +//===----------------------------------------------------------------------===// +// RemoveDeadValues Pass
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Some helper functions...
+
+/// Return true iff at least one value in `values` is live, given the liveness
+/// information in `la`.
+///
+/// A null value is skipped: it implies that it was dropped during the
+/// execution of this pass, implying that it was non-live. A value for which
+/// `la` has no liveness information is conservatively considered live.
+static bool hasLive(ValueRange values, RunLivenessAnalysis &la) {
+  return llvm::any_of(values, [&](Value value) {
+    if (!value)
+      return false;
+    const Liveness *liveness = la.getLiveness(value);
+    return !liveness || liveness->isLive;
+  });
+}
+
+/// Return a BitVector of size `values.size()` where its i-th bit is 1 iff the
+/// i-th value in `values` is live, given the liveness information in `la`.
+///
+/// Null values are reported as non-live; values without liveness information
+/// are reported as live (see the comment in the loop).
+static BitVector markLives(ValueRange values, RunLivenessAnalysis &la) {
+  BitVector lives(values.size(), true);
+
+  for (auto [index, value] : llvm::enumerate(values)) {
+    if (!value) {
+      lives.reset(index);
+      continue;
+    }
+
+    const Liveness *liveness = la.getLiveness(value);
+    // It is important to note that when `liveness` is null, we can't tell if
+    // `value` is live or not. So, the safe option is to consider it live. Also,
+    // the execution of this pass might create new SSA values when erasing some
+    // of the results of an op and we know that these new values are live
+    // (because they weren't erased) and also their liveness is null because
+    // liveness analysis ran before their creation.
+    if (liveness && !liveness->isLive)
+      lives.reset(index);
+  }
+
+  return lives;
+}
+
+/// Drop the uses of the i-th result of `op` and then erase it iff `toErase[i]`
+/// is 1.
+///
+/// Results cannot be erased in place, so `op` is replaced by a new op inserted
+/// right after it, with the same location, name, operands, and attributes but
+/// only the kept results. The bodies of the regions of `op` are moved (not
+/// cloned) into the new op so that the values defined inside them keep their
+/// identity, and hence the liveness information computed for them. Uses of the
+/// kept results are redirected to the new op; `op` is then erased.
+///
+/// If no bit of `toErase` is set, `op` is left untouched.
+static void dropUsesAndEraseResults(Operation *op, const BitVector &toErase) {
+  assert(op->getNumResults() == toErase.size() &&
+         "expected the number of results in `op` and the size of `toErase` to "
+         "be the same");
+
+  // Nothing to erase: rebuilding the op would only churn the IR.
+  if (toErase.none())
+    return;
+
+  SmallVector<Type> newResultTypes;
+  for (OpResult result : op->getResults())
+    if (!toErase[result.getResultNumber()])
+      newResultTypes.push_back(result.getType());
+
+  OpBuilder builder(op);
+  builder.setInsertionPointAfter(op);
+  OperationState state(op->getLoc(), op->getName().getStringRef(),
+                       op->getOperands(), newResultTypes, op->getAttrs());
+  for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i)
+    state.addRegion();
+  Operation *newOp = builder.create(state);
+
+  // Move (rather than clone) each region body into the new op.
+  for (auto [index, region] : llvm::enumerate(op->getRegions()))
+    newOp->getRegion(index).takeBody(region);
+
+  unsigned nextNewResult = 0;
+  for (auto [index, result] : llvm::enumerate(op->getResults())) {
+    assert(result && "expected result to be non-null");
+    if (toErase[index])
+      result.dropAllUses();
+    else
+      result.replaceAllUsesWith(newOp->getResult(nextNewResult++));
+  }
+  op->erase();
+}
+
+/// Return pointers to the `OpOperand`s that back the values in `operands`, in
+/// the same order.
+static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
+  OpOperand *base = operands.getBase();
+  SmallVector<OpOperand *> opOperands;
+  opOperands.reserve(operands.size());
+  for (unsigned i = 0, e = operands.size(); i < e; ++i)
+    opOperands.push_back(&base[i]);
+  return opOperands;
+}
+
+/// Clean a simple op `op`, given the liveness analysis information in `la`.
+/// Here, cleaning means:
+///   (1) Dropping all its uses, AND
+///   (2) Erasing it
+/// iff it has no memory effects and none of its results are live.
+///
+/// It is assumed that `op` is simple. Here, a simple op is one which isn't a
+/// symbol op, a symbol-user op, a region branch op, a branch op, a region
+/// branch terminator op, or return-like.
+static void cleanSimpleOp(Operation *op, RunLivenessAnalysis &la) {
+  if (!isMemoryEffectFree(op) || hasLive(op->getResults(), la))
+    return;
+
+  op->dropAllUses();
+  op->erase();
+}
+
+/// Clean a function-like op `funcOp`, given the liveness information in `la`
+/// and the IR in `module`. Here, cleaning means:
+///   (1) Dropping the uses of its unnecessary (non-live) arguments,
+///   (2) Erasing these arguments,
+///   (3) Erasing their corresponding operands from its callers,
+///   (4) Erasing its unnecessary terminator operands (return values that are
+///   non-live across all callers),
+///   (5) Dropping the uses of these return values from its callers, AND
+///   (6) Erasing these return values
+/// iff it is not public, it has a body, and all of its users in `module` are
+/// known call ops.
+static void cleanFuncOp(FunctionOpInterface funcOp, Operation *module,
+                        RunLivenessAnalysis &la) {
+  // Public functions may be called from outside of `module`. External
+  // (body-less) functions have neither block arguments nor terminators to
+  // update, and `funcOp.back()` below would be invalid on their empty body.
+  if (funcOp.isPublic() || funcOp.isExternal())
+    return;
+
+  // Collect the users of `funcOp` before mutating anything. If they cannot be
+  // determined, or if one of them is not a call, the users could not be kept
+  // in sync with a modified signature, so leave the function untouched.
+  std::optional<SymbolTable::UseRange> maybeUses =
+      funcOp.getSymbolUses(module);
+  if (!maybeUses)
+    return;
+  SymbolTable::UseRange uses = *maybeUses;
+  for (SymbolTable::SymbolUse use : uses)
+    if (!isa<CallOpInterface>(use.getUser()))
+      return;
+
+  // Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
+  SmallVector<Value> arguments(funcOp.getArguments());
+  BitVector nonLiveArgs = markLives(arguments, la);
+  nonLiveArgs.flip();
+
+  // Do (1).
+  for (auto [index, arg] : llvm::enumerate(arguments))
+    if (arg && nonLiveArgs[index])
+      arg.dropAllUses();
+
+  // Do (2).
+  funcOp.eraseArguments(nonLiveArgs);
+
+  // Do (3).
+  for (SymbolTable::SymbolUse use : uses) {
+    auto callOp = cast<CallOpInterface>(use.getUser());
+    // The number of operands in the call op may not match the number of
+    // arguments in the func op, so map each non-live argument index to the
+    // operand number of the matching argument operand of the call.
+    BitVector nonLiveCallOperands(callOp->getNumOperands(), false);
+    SmallVector<OpOperand *> callOpOperands =
+        operandsToOpOperands(callOp.getArgOperands());
+    for (unsigned index : nonLiveArgs.set_bits())
+      nonLiveCallOperands.set(callOpOperands[index]->getOperandNumber());
+    callOp->eraseOperands(nonLiveCallOperands);
+  }
+
+  // Get the list of unnecessary terminator operands (return values that are
+  // non-live across all callers) in `nonLiveRets`. There is a very important
+  // subtlety here. Unnecessary terminator operands are NOT the operands of the
+  // terminator that are non-live. Instead, these are the return values of the
+  // callers such that a given return value is non-live across all callers. Such
+  // corresponding operands in the terminator could be live. An example to
+  // demonstrate this:
+  //   func.func private @f(%arg0: memref<i32>) -> (i32, i32) {
+  //     %c0_i32 = arith.constant 0 : i32
+  //     %0 = arith.addi %c0_i32, %c0_i32 : i32
+  //     memref.store %0, %arg0[] : memref<i32>
+  //     return %c0_i32, %0 : i32, i32
+  //   }
+  //   func.func @main(%arg0: i32, %arg1: memref<i32>) -> (i32) {
+  //     %1:2 = call @f(%arg1) : (memref<i32>) -> (i32, i32)
+  //     return %1#0 : i32
+  //   }
+  // Here, we can see that %1#1 is never used. It is non-live. Thus, @f doesn't
+  // need to return %0. But, %0 is live. And, still, we want to stop it from
+  // being returned, in order to optimize our IR. So, this demonstrates how we
+  // can make our optimization strong by even removing a live return value (%0),
+  // since it forwards only to non-live value(s) (%1#1).
+  Operation *lastReturnOp = funcOp.back().getTerminator();
+  size_t numReturns = lastReturnOp->getNumOperands();
+  BitVector nonLiveRets(numReturns, true);
+  for (SymbolTable::SymbolUse use : uses) {
+    BitVector liveCallRets = markLives(use.getUser()->getResults(), la);
+    nonLiveRets &= liveCallRets.flip();
+  }
+
+  // Do (4).
+  // Note that in the absence of control flow ops forcing the control to go from
+  // the entry (first) block to the other blocks, the control never reaches any
+  // block other than the entry block, because every block has a terminator.
+  for (Block &block : funcOp.getBlocks()) {
+    Operation *returnOp = block.getTerminator();
+    if (returnOp && returnOp->getNumOperands() == numReturns)
+      returnOp->eraseOperands(nonLiveRets);
+  }
+  funcOp.eraseResults(nonLiveRets);
+
+  // Do (5) and (6).
+  for (SymbolTable::SymbolUse use : uses)
+    dropUsesAndEraseResults(use.getUser(), nonLiveRets);
+}
+
+/// Clean a region branch op `regionBranchOp`, given the liveness information in
+/// `la`. Here, cleaning means:
+///   (1') Dropping all its uses, AND
+///   (2') Erasing it
+/// if it has no memory effects and none of its results are live, AND
+///   (1) Erasing its unnecessary operands (operands that are forwarded to
+///   unnecessary results and arguments),
+///   (2) Cleaning each of its regions,
+///   (3) Dropping the uses of its unnecessary results (results that are
+///   forwarded from unnecessary operands and terminator operands), AND
+///   (4) Erasing these results
+/// otherwise.
+/// Note that here, cleaning a region means:
+///   (2.a) Dropping the uses of its unnecessary arguments (arguments that are
+///   forwarded from unnecessary operands and terminator operands),
+///   (2.b) Erasing these arguments, AND
+///   (2.c) Erasing its unnecessary terminator operands (terminator operands
+///   that are forwarded to unnecessary results and arguments).
+/// It is important to note that values in this op flow from operands and
+/// terminator operands (successor operands) to arguments and results (successor
+/// inputs).
+///
+/// Empty regions (e.g. the absent `else` region of an `scf.if`) have neither
+/// arguments nor a terminator and are skipped throughout.
+static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp,
+                                RunLivenessAnalysis &la) {
+  // Do (1') and (2'). This is the only case where the entire `regionBranchOp`
+  // is removed. It will not happen in any other scenario. Note that in this
+  // case, a non-forwarded operand of `regionBranchOp` could be live/non-live.
+  // It could never be live because of this op but its liveness could have been
+  // attributed to something else.
+  if (isMemoryEffectFree(regionBranchOp.getOperation()) &&
+      !hasLive(regionBranchOp->getResults(), la)) {
+    regionBranchOp->dropAllUses();
+    regionBranchOp->erase();
+    return;
+  }
+
+  // At this point, we know that every non-forwarded operand of `regionBranchOp`
+  // is live.
+
+  // Mark live results of `regionBranchOp` in `liveResults`.
+  auto markLiveResults = [&](BitVector &liveResults) {
+    liveResults = markLives(regionBranchOp->getResults(), la);
+  };
+
+  // Mark live arguments in the non-empty regions of `regionBranchOp` in
+  // `liveArgs`.
+  auto markLiveArgs = [&](DenseMap<Region *, BitVector> &liveArgs) {
+    for (Region &region : regionBranchOp->getRegions()) {
+      if (region.empty())
+        continue;
+      SmallVector<Value> arguments(region.front().getArguments());
+      liveArgs[&region] = markLives(arguments, la);
+    }
+  };
+
+  // Return the successors of `region` if the latter is not null. Else return
+  // the successors of `regionBranchOp`.
+  auto getSuccessors = [&](Region *region = nullptr) {
+    std::optional<unsigned> index =
+        region ? std::optional<unsigned>(region->getRegionNumber())
+               : std::nullopt;
+    SmallVector<Attribute> operandAttributes(regionBranchOp->getNumOperands(),
+                                             nullptr);
+    SmallVector<RegionSuccessor> successors;
+    if (!index)
+      regionBranchOp.getEntrySuccessorRegions(operandAttributes, successors);
+    else
+      regionBranchOp.getSuccessorRegions(index, successors);
+    return successors;
+  };
+
+  // Return the operands of `terminator` that are forwarded to `successor` if
+  // the former is not null. Else return the operands of `regionBranchOp`
+  // forwarded to `successor`.
+  auto getForwardedOpOperands = [&](const RegionSuccessor &successor,
+                                    Operation *terminator = nullptr) {
+    Region *successorRegion = successor.getSuccessor();
+    std::optional<unsigned> index =
+        successorRegion
+            ? std::optional<unsigned>(successorRegion->getRegionNumber())
+            : std::nullopt;
+    OperandRange operands =
+        terminator ? cast<RegionBranchTerminatorOpInterface>(terminator)
+                         .getSuccessorOperands(index)
+                   : regionBranchOp.getEntrySuccessorOperands(index);
+    return operandsToOpOperands(operands);
+  };
+
+  // Mark the non-forwarded operands of `regionBranchOp` in
+  // `nonForwardedOperands`.
+  auto markNonForwardedOperands = [&](BitVector &nonForwardedOperands) {
+    nonForwardedOperands.resize(regionBranchOp->getNumOperands(), true);
+    for (const RegionSuccessor &successor : getSuccessors()) {
+      for (OpOperand *opOperand : getForwardedOpOperands(successor))
+        nonForwardedOperands.reset(opOperand->getOperandNumber());
+    }
+  };
+
+  // Mark the non-forwarded terminator operands of the various non-empty regions
+  // of `regionBranchOp` in `nonForwardedRets`.
+  auto markNonForwardedReturnValues =
+      [&](DenseMap<Operation *, BitVector> &nonForwardedRets) {
+        for (Region &region : regionBranchOp->getRegions()) {
+          if (region.empty())
+            continue;
+          Operation *terminator = region.front().getTerminator();
+          nonForwardedRets[terminator] =
+              BitVector(terminator->getNumOperands(), true);
+          for (const RegionSuccessor &successor : getSuccessors(&region)) {
+            for (OpOperand *opOperand :
+                 getForwardedOpOperands(successor, terminator))
+              nonForwardedRets[terminator].reset(opOperand->getOperandNumber());
+          }
+        }
+      };
+
+  // Update `valuesToKeep` (which is expected to correspond to operands or
+  // terminator operands) based on `resultsToKeep` and `argsToKeep`, given
+  // `region`. When `valuesToKeep` correspond to operands, `region` is null.
+  // Else, `region` is the (non-empty) parent region of the terminator.
+  auto updateOperandsOrTerminatorOperandsToKeep =
+      [&](BitVector &valuesToKeep, BitVector &resultsToKeep,
+          DenseMap<Region *, BitVector> &argsToKeep, Region *region = nullptr) {
+        Operation *terminator =
+            region ? region->front().getTerminator() : nullptr;
+
+        for (const RegionSuccessor &successor : getSuccessors(region)) {
+          Region *successorRegion = successor.getSuccessor();
+          for (auto [opOperand, input] :
+               llvm::zip(getForwardedOpOperands(successor, terminator),
+                         successor.getSuccessorInputs())) {
+            bool inputIsKept =
+                successorRegion
+                    ? argsToKeep[successorRegion]
+                                [cast<BlockArgument>(input).getArgNumber()]
+                    : resultsToKeep[cast<OpResult>(input).getResultNumber()];
+            if (inputIsKept)
+              valuesToKeep.set(opOperand->getOperandNumber());
+          }
+        }
+      };
+
+  // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep` and
+  // `terminatorOperandsToKeep`. Store true in `resultsOrArgsToKeepChanged` if a
+  // value is modified, else, false.
+  auto recomputeResultsAndArgsToKeep =
+      [&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep,
+          BitVector &operandsToKeep,
+          DenseMap<Operation *, BitVector> &terminatorOperandsToKeep,
+          bool &resultsOrArgsToKeepChanged) {
+        resultsOrArgsToKeepChanged = false;
+
+        // Mark every successor input of `successor` that receives a value
+        // marked in `forwardedToKeep` (indexed by operand number of
+        // `terminator`, or of `regionBranchOp` when `terminator` is null).
+        auto keepInputsFedByKeptValues = [&](const RegionSuccessor &successor,
+                                             Operation *terminator,
+                                             const BitVector &forwardedToKeep) {
+          Region *successorRegion = successor.getSuccessor();
+          for (auto [opOperand, input] :
+               llvm::zip(getForwardedOpOperands(successor, terminator),
+                         successor.getSuccessorInputs())) {
+            if (!forwardedToKeep[opOperand->getOperandNumber()])
+              continue;
+            BitVector &inputsToKeep =
+                successorRegion ? argsToKeep[successorRegion] : resultsToKeep;
+            unsigned inputNumber =
+                successorRegion ? cast<BlockArgument>(input).getArgNumber()
+                                : cast<OpResult>(input).getResultNumber();
+            if (!inputsToKeep[inputNumber]) {
+              inputsToKeep.set(inputNumber);
+              resultsOrArgsToKeepChanged = true;
+            }
+          }
+        };
+
+        // Recompute based on `operandsToKeep`.
+        for (const RegionSuccessor &successor : getSuccessors())
+          keepInputsFedByKeptValues(successor, /*terminator=*/nullptr,
+                                    operandsToKeep);
+
+        // Recompute based on `terminatorOperandsToKeep`.
+        for (Region &region : regionBranchOp->getRegions()) {
+          if (region.empty())
+            continue;
+          Operation *terminator = region.front().getTerminator();
+          for (const RegionSuccessor &successor : getSuccessors(&region))
+            keepInputsFedByKeptValues(successor, terminator,
+                                      terminatorOperandsToKeep[terminator]);
+        }
+      };
+
+  // Mark the values that we want to keep in `resultsToKeep`, `argsToKeep`,
+  // `operandsToKeep`, and `terminatorOperandsToKeep`.
+  auto markValuesToKeep =
+      [&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep,
+          BitVector &operandsToKeep,
+          DenseMap<Operation *, BitVector> &terminatorOperandsToKeep) {
+        bool resultsOrArgsToKeepChanged = true;
+        // We keep updating and recomputing the values until we reach a point
+        // where they stop changing. This terminates because bits are only
+        // ever set, never cleared.
+        while (resultsOrArgsToKeepChanged) {
+          // Update the operands that need to be kept.
+          updateOperandsOrTerminatorOperandsToKeep(operandsToKeep,
+                                                   resultsToKeep, argsToKeep);
+
+          // Update the terminator operands that need to be kept.
+          for (Region &region : regionBranchOp->getRegions()) {
+            if (region.empty())
+              continue;
+            updateOperandsOrTerminatorOperandsToKeep(
+                terminatorOperandsToKeep[region.front().getTerminator()],
+                resultsToKeep, argsToKeep, &region);
+          }
+
+          // Recompute the results and arguments that need to be kept.
+          recomputeResultsAndArgsToKeep(
+              resultsToKeep, argsToKeep, operandsToKeep,
+              terminatorOperandsToKeep, resultsOrArgsToKeepChanged);
+        }
+      };
+
+  // Stores the results of `regionBranchOp` that we want to keep.
+  BitVector resultsToKeep;
+  // Stores the mapping from regions of `regionBranchOp` to their arguments that
+  // we want to keep.
+  DenseMap<Region *, BitVector> argsToKeep;
+  // Stores the operands of `regionBranchOp` that we want to keep.
+  BitVector operandsToKeep;
+  // Stores the mapping from region terminators in `regionBranchOp` to their
+  // operands that we want to keep.
+  DenseMap<Operation *, BitVector> terminatorOperandsToKeep;
+
+  // Initializing the above variables...
+
+  // The live results of `regionBranchOp` definitely need to be kept.
+  markLiveResults(resultsToKeep);
+  // Similarly, the live arguments of the regions in `regionBranchOp` definitely
+  // need to be kept.
+  markLiveArgs(argsToKeep);
+  // The non-forwarded operands of `regionBranchOp` definitely need to be kept.
+  // A live forwarded operand can be removed but no non-forwarded operand can be
+  // removed since it "controls" the flow of data in this control flow op.
+  markNonForwardedOperands(operandsToKeep);
+  // Similarly, the non-forwarded terminator operands of the regions in
+  // `regionBranchOp` definitely need to be kept.
+  markNonForwardedReturnValues(terminatorOperandsToKeep);
+
+  // Mark the values (results, arguments, operands, and terminator operands)
+  // that we want to keep.
+  markValuesToKeep(resultsToKeep, argsToKeep, operandsToKeep,
+                   terminatorOperandsToKeep);
+
+  // Do (1).
+  regionBranchOp->eraseOperands(operandsToKeep.flip());
+
+  // Do (2.a) and (2.b).
+  for (Region &region : regionBranchOp->getRegions()) {
+    if (region.empty())
+      continue;
+    BitVector &regionArgsToKeep = argsToKeep[&region];
+    for (auto [index, arg] : llvm::enumerate(region.front().getArguments()))
+      if (!regionArgsToKeep[index] && arg)
+        arg.dropAllUses();
+    region.front().eraseArguments(regionArgsToKeep.flip());
+  }
+
+  // Do (2.c).
+  for (Region &region : regionBranchOp->getRegions()) {
+    if (region.empty())
+      continue;
+    Operation *terminator = region.front().getTerminator();
+    terminator->eraseOperands(terminatorOperandsToKeep[terminator].flip());
+  }
+
+  // Do (3) and (4).
+  dropUsesAndEraseResults(regionBranchOp.getOperation(), resultsToKeep.flip());
+}
+
+struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void RemoveDeadValues::runOnOperation() {
+  auto &la = getAnalysis<RunLivenessAnalysis>();
+  Operation *module = getOperation();
+
+  // The removal of non-live values is performed iff there are no branch ops,
+  // all symbol ops present in the IR are function-like, and all symbol user ops
+  // present in the IR are call-like.
+  WalkResult acceptableIR = module->walk([&](Operation *op) {
+    if (isa<BranchOpInterface>(op) ||
+        (isa<SymbolOpInterface>(op) && !isa<FunctionOpInterface>(op)) ||
+        (isa<SymbolUserOpInterface>(op) && !isa<CallOpInterface>(op))) {
+      op->emitError() << "cannot optimize an IR with non-function symbol ops, "
+                         "non-call symbol user ops or branch ops\n";
+      return WalkResult::interrupt();
+    }
+    return WalkResult::advance();
+  });
+
+  if (acceptableIR.wasInterrupted())
+    return;
+
+  module->walk([&](Operation *op) {
+    // Never clean the root op of the pass itself: its users (e.g. the callers
+    // of a root function) are outside of the IR visible here, and erasing the
+    // root op would be invalid.
+    if (op == module)
+      return;
+
+    if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
+      cleanFuncOp(funcOp, module, la);
+    } else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
+      cleanRegionBranchOp(regionBranchOp, la);
+    } else if (op->hasTrait<::mlir::OpTrait::ReturnLike>()) {
+      // Nothing to do because this terminator is associated with either a
+      // function op or a region branch op and gets cleaned when these ops are
+      // cleaned.
+    } else if (isa<RegionBranchTerminatorOpInterface>(op)) {
+      // Nothing to do because this terminator is associated with a region
+      // branch op and gets cleaned when the latter is cleaned.
+    } else if (isa<CallOpInterface>(op)) {
+      // Nothing to do because this op is associated with a function op and gets
+      // cleaned when the latter is cleaned.
+    } else {
+      cleanSimpleOp(op, la);
+    }
+  });
+}
+
+std::unique_ptr<Pass> mlir::createRemoveDeadValuesPass() {
+  return std::make_unique<RemoveDeadValues>();
+}
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -0,0 +1,330 @@
+// RUN: mlir-opt %s -remove-dead-values -split-input-file -verify-diagnostics | FileCheck %s
+
+// The IR remains untouched because of the presence of a non-function-like
+// symbol op (module @dont_touch_unacceptable_ir).
+//
+// The diagnostic is attached to the `module` op: the walk in the pass visits
+// the nested func.func first (a function symbol, which is accepted) and then
+// the named module, which is a non-function symbol op.
+//
+// expected-error @+1 {{cannot optimize an IR with non-function symbol ops, non-call symbol user ops or branch ops}}
+module @dont_touch_unacceptable_ir {
+  func.func @has_cleanable_simple_op(%arg0 : i32) {
+    %non_live = arith.addi %arg0, %arg0 : i32
+    return
+  }
+}
+
+// -----
+
+// The IR remains untouched because of the presence of a branch op `cf.cond_br`.
+//
+func.func @dont_touch_unacceptable_ir_has_cleanable_simple_op_with_branch_op(%arg0: i1) {
+  %non_live = arith.constant 0 : i32
+  // The pass stops at the first unacceptable op it finds, so only the
+  // `cf.cond_br` below is diagnosed (not the later `cf.br` ops).
+  // expected-error @+1 {{cannot optimize an IR with non-function symbol ops, non-call symbol user ops or branch ops}}
+  cf.cond_br %arg0, ^bb1(%non_live : i32), ^bb2(%non_live : i32)
+^bb1(%non_live_0 : i32):
+  cf.br ^bb3
+^bb2(%non_live_1 : i32):
+  cf.br ^bb3
+^bb3:
+  return
+}
+
+// -----
+
+// The only argument of the private function is merely returned, and the only
+// caller never uses the returned value. Both the argument and the return value
+// are therefore removed, from the function and from its call site.
+//
+// Note that this cleanup cannot be done by the `canonicalize` pass.
+//
+// CHECK-LABEL: func.func private @clean_func_op_remove_argument_and_return_value() {
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
+// CHECK:       func.func @main(%[[arg0:.*]]: i32) {
+// CHECK-NEXT:    call @clean_func_op_remove_argument_and_return_value() : () -> ()
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
+func.func private @clean_func_op_remove_argument_and_return_value(%arg0: i32) -> (i32) {
+  return %arg0 : i32
+}
+func.func @main(%arg0 : i32) {
+  %non_live = func.call @clean_func_op_remove_argument_and_return_value(%arg0) : (i32) -> (i32)
+  return
+}
+
+// -----
+
+// %arg0 is not live because it is never used. %arg1 is not live because its
+// user `arith.addi` doesn't have any uses and the value that it is forwarded to
+// (%non_live_0) also doesn't have any uses.
+//
+// Note that this cleanup cannot be done by the `canonicalize` pass.
+//
+// `test.call_on_device` has an extra trailing operand (%device) that is not
+// forwarded to the callee; only the call's argument operands matching removed
+// function arguments are erased, so %device stays.
+//
+// CHECK-LABEL: func.func private @clean_func_op_remove_arguments() -> i32 {
+// CHECK-NEXT:    %[[c0:.*]] = arith.constant 0
+// CHECK-NEXT:    return %[[c0]]
+// CHECK-NEXT:  }
+// CHECK:       func.func @main(%[[arg2:.*]]: memref, %[[arg3:.*]]: i32, %[[DEVICE:.*]]: i32) -> (i32, memref) {
+// CHECK-NEXT:    %[[live:.*]] = test.call_on_device @clean_func_op_remove_arguments(), %[[DEVICE]] : (i32) -> i32
+// CHECK-NEXT:    return %[[live]], %[[arg2]]
+// CHECK-NEXT:  }
+func.func private @clean_func_op_remove_arguments(%arg0 : memref, %arg1 : i32) -> (i32, i32) {
+  %c0 = arith.constant 0 : i32
+  %non_live = arith.addi %arg1, %arg1 : i32
+  return %c0, %arg1 : i32, i32
+}
+func.func @main(%arg2 : memref, %arg3 : i32, %device : i32) -> (i32, memref) {
+  %live, %non_live_0 = test.call_on_device @clean_func_op_remove_arguments(%arg2, %arg3), %device : (memref, i32, i32) -> (i32, i32)
+  return %live, %arg2 : i32, memref
+}
+
+// -----
+
+// Even though %non_live_0 is not live, the first return value of
+// @clean_func_op_remove_return_values isn't removed because %live is live
+// (liveness is checked across all callers).
+//
+// Also, the second return value of @clean_func_op_remove_return_values is
+// removed despite %c0 being live because neither %non_live nor %non_live_1 were
+// live (removal doesn't depend on the liveness of the operand itself but on the
+// liveness of where it is forwarded).
+//
+// Note that this cleanup cannot be done by the `canonicalize` pass.
+//
+// CHECK: func.func private @clean_func_op_remove_return_values(%[[arg0:.*]]: memref) -> i32 {
+// CHECK-NEXT: %[[c0:.*]] = arith.constant 0
+// CHECK-NEXT: memref.store %[[c0]], %[[arg0]][]
+// CHECK-NEXT: return %[[c0]]
+// CHECK-NEXT: }
+// CHECK: func.func @main(%[[arg1:.*]]: memref) -> i32 {
+// CHECK-NEXT: %[[live:.*]] = call @clean_func_op_remove_return_values(%[[arg1]]) : (memref) -> i32
+// CHECK-NEXT: %[[non_live_0:.*]] = call @clean_func_op_remove_return_values(%[[arg1]]) : (memref) -> i32
+// CHECK-NEXT: return %[[live]] : i32
+// CHECK-NEXT: }
+func.func private @clean_func_op_remove_return_values(%arg0 : memref) -> (i32, i32) {
+ %c0 = arith.constant 0 : i32
+ memref.store %c0, %arg0[] : memref
+ return %c0, %c0 : i32, i32
+}
+func.func @main(%arg1 : memref) -> (i32) {
+ %live, %non_live = func.call @clean_func_op_remove_return_values(%arg1) : (memref) -> (i32, i32)
+ %non_live_0, %non_live_1 = func.call @clean_func_op_remove_return_values(%arg1) : (memref) -> (i32, i32)
+ return %live : i32
+}
+
+// -----
+
+// None of the return values of @clean_func_op_dont_remove_return_values can be
+// removed because the first one is forwarded to a live value %live and the
+// second one is forwarded to a live value %live_0.
+//
+// CHECK-LABEL: func.func private @clean_func_op_dont_remove_return_values() -> (i32, i32) {
+// CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : i32
+// CHECK-NEXT: return %[[c0]], %[[c0]] : i32, i32
+// CHECK-NEXT: }
+// CHECK-LABEL: func.func @main() -> (i32, i32) {
+// CHECK-NEXT: %[[live_and_non_live:.*]]:2 = call @clean_func_op_dont_remove_return_values() : () -> (i32, i32)
+// CHECK-NEXT: %[[non_live_0_and_live_0:.*]]:2 = call @clean_func_op_dont_remove_return_values() : () -> (i32, i32)
+// CHECK-NEXT: return %[[live_and_non_live]]#0, %[[non_live_0_and_live_0]]#1 : i32, i32
+// CHECK-NEXT: }
+func.func private @clean_func_op_dont_remove_return_values() -> (i32, i32) {
+ %c0 = arith.constant 0 : i32
+ return %c0, %c0 : i32, i32
+}
+func.func @main() -> (i32, i32) {
+ %live, %non_live = func.call @clean_func_op_dont_remove_return_values() : () -> (i32, i32)
+ %non_live_0, %live_0 = func.call @clean_func_op_dont_remove_return_values() : () -> (i32, i32)
+ return %live, %live_0 : i32, i32
+}
+
+// -----
+
+// Values kept:
+// (1) %non_live is not live. Yet, it is kept because %arg4 in `scf.condition`
+// forwards to it, which has to be kept. %arg4 in `scf.condition` has to be
+// kept because it forwards to %arg6, which is live.
+//
+// (2) %arg5 is not live. Yet, it is kept because %live_0 forwards to it, which
+// also forwards to %live, which is live.
+//
+// Values not kept:
+// (1) %arg1 is not kept as an operand of `scf.while` because it only forwards
+// to %arg3, which is not kept. %arg3 is not kept because %arg3 is not live and
+// only %arg1 and %arg7 forward to it, and neither of them forwards
+// anywhere else. Thus, %arg7 is also not kept in the `scf.yield` op.
+//
+// Note that this cleanup cannot be done by the `canonicalize` pass.
+//
+// CHECK: func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_operand(%[[arg0:.*]]: i1, %[[arg1:.*]]: i32, %[[arg2:.*]]: i32) -> i32 {
+// CHECK-NEXT: %[[live_and_non_live:.*]]:2 = scf.while (%[[arg4:.*]] = %[[arg2]]) : (i32) -> (i32, i32) {
+// CHECK-NEXT: %[[live_0:.*]] = arith.addi %[[arg4]], %[[arg4]]
+// CHECK-NEXT: scf.condition(%[[arg0]]) %[[live_0]], %[[arg4]] : i32, i32
+// CHECK-NEXT: } do {
+// CHECK-NEXT: ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32):
+// CHECK-NEXT: %[[live_1:.*]] = arith.addi %[[arg6]], %[[arg6]]
+// CHECK-NEXT: scf.yield %[[live_1]] : i32
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[live_and_non_live]]#0
+// CHECK-NEXT: }
+func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_operand(%arg0: i1, %arg1: i32, %arg2: i32) -> (i32) {
+ %live, %non_live, %non_live_0 = scf.while (%arg3 = %arg1, %arg4 = %arg2) : (i32, i32) -> (i32, i32, i32) {
+ %live_0 = arith.addi %arg4, %arg4 : i32
+ %non_live_1 = arith.addi %arg3, %arg3 : i32
+ scf.condition(%arg0) %live_0, %arg4, %non_live_1 : i32, i32, i32
+ } do {
+ ^bb0(%arg5: i32, %arg6: i32, %arg7: i32):
+ %live_1 = arith.addi %arg6, %arg6 : i32
+ scf.yield %arg7, %live_1 : i32, i32
+ }
+ return %live : i32
+}
+
+// -----
+
+// Values kept:
+// (1) %live is kept because it is live.
+//
+// (2) %non_live is not live. Yet, it is kept because %arg3 in `scf.condition`
+// forwards to it and this %arg3 has to be kept. This %arg3 in `scf.condition`
+// has to be kept because it forwards to %arg6, which forwards to %arg4, which
+// forwards to %live, which is live.
+//
+// Values not kept:
+// (1) %non_live_0 is not kept because %non_live_2 in `scf.condition` forwards
+// to it, which forwards to only %non_live_0 and %arg7, both of which are
+// not live and have no other value forwarding to them.
+//
+// (2) %non_live_1 is not kept because %non_live_3 in `scf.condition` forwards
+// to it, which forwards to only %non_live_1 and %arg8, both of which are
+// not live and have no other value forwarding to them.
+//
+// (3) %c2 is not kept because it only forwards to %arg10, which is not kept.
+//
+// (4) %arg10 is not kept because only %c2 and %non_live_4 forward to it, none
+// of them forwards anywhere else, and %arg10 is not live.
+//
+// (5) %arg7 and %arg8 are not kept because they are not live, %non_live_2 and
+// %non_live_3 forward to them, and both only otherwise forward to %non_live_0
+// and %non_live_1, which are not live and have no other predecessors.
+//
+// Note that this cleanup cannot be done by the `canonicalize` pass.
+//
+// CHECK: func.func @clean_region_branch_op_remove_last_2_results_last_2_arguments_and_last_operand(%[[arg2:.*]]: i1) -> i32 {
+// CHECK-NEXT: %[[c0:.*]] = arith.constant 0
+// CHECK-NEXT: %[[c1:.*]] = arith.constant 1
+// CHECK-NEXT: %[[live_and_non_live:.*]]:2 = scf.while (%[[arg3:.*]] = %[[c0]], %[[arg4:.*]] = %[[c1]]) : (i32, i32) -> (i32, i32) {
+// CHECK-NEXT: scf.condition(%[[arg2]]) %[[arg4]], %[[arg3]] : i32, i32
+// CHECK-NEXT: } do {
+// CHECK-NEXT: ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32):
+// CHECK-NEXT: scf.yield %[[arg5]], %[[arg6]] : i32, i32
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[live_and_non_live]]#0 : i32
+// CHECK-NEXT: }
+func.func @clean_region_branch_op_remove_last_2_results_last_2_arguments_and_last_operand(%arg2: i1) -> (i32) {
+ %c0 = arith.constant 0 : i32
+ %c1 = arith.constant 1 : i32
+ %c2 = arith.constant 2 : i32
+ %live, %non_live, %non_live_0, %non_live_1 = scf.while (%arg3 = %c0, %arg4 = %c1, %arg10 = %c2) : (i32, i32, i32) -> (i32, i32, i32, i32) {
+ %non_live_2 = arith.addi %arg10, %arg10 : i32
+ %non_live_3 = arith.muli %arg10, %arg10 : i32
+ scf.condition(%arg2) %arg4, %arg3, %non_live_2, %non_live_3 : i32, i32, i32, i32
+ } do {
+ ^bb0(%arg5: i32, %arg6: i32, %arg7: i32, %arg8: 
i32):
+ %non_live_4 = arith.addi %arg7, %arg8 :i32
+ scf.yield %arg5, %arg6, %non_live_4 : i32, i32, i32
+ }
+ return %live : i32
+}
+
+// -----
+
+// The op isn't erased because it has memory effects, but its unnecessary result
+// is removed.
+//
+// Note that this cleanup cannot be done by the `canonicalize` pass.
+//
+// CHECK: func.func @clean_region_branch_op_remove_result(%[[arg0:.*]]: index, %[[arg1:.*]]: memref) {
+// CHECK-NEXT: scf.index_switch %[[arg0]]
+// CHECK-NEXT: case 1 {
+// CHECK-NEXT: %[[c10:.*]] = arith.constant 10
+// CHECK-NEXT: memref.store %[[c10]], %[[arg1]][]
+// CHECK-NEXT: scf.yield
+// CHECK-NEXT: }
+// CHECK-NEXT: default {
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+func.func @clean_region_branch_op_remove_result(%arg0 : index, %arg1 : memref) {
+ %non_live = scf.index_switch %arg0 -> i32
+ case 1 {
+ %c10 = arith.constant 10 : i32
+ memref.store %c10, %arg1[] : memref
+ scf.yield %c10 : i32
+ }
+ default {
+ %c11 = arith.constant 11 : i32
+ scf.yield %c11 : i32
+ }
+ return
+}
+
+// -----
+
+// Simple ops that have neither memory effects nor live results get removed.
+// %arg5 doesn't get removed from @main even though it isn't live because
+// the signature of a public function is always left untouched.
+//
+// Note that this cleanup cannot be done by the `canonicalize` pass.
+//
+// CHECK: func.func private @clean_simple_ops(%[[arg0:.*]]: i32, %[[arg1:.*]]: memref)
+// CHECK-NEXT: %[[live_0:.*]] = arith.addi %[[arg0]], %[[arg0]]
+// CHECK-NEXT: %[[c2:.*]] = arith.constant 2
+// CHECK-NEXT: %[[live_1:.*]] = arith.muli %[[live_0]], %[[c2]]
+// CHECK-NEXT: %[[c3:.*]] = arith.constant 3
+// CHECK-NEXT: %[[live_2:.*]] = arith.addi %[[arg0]], %[[c3]]
+// CHECK-NEXT: memref.store %[[live_2]], %[[arg1]][]
+// CHECK-NEXT: return %[[live_1]]
+// CHECK-NEXT: }
+// CHECK: func.func @main(%[[arg3:.*]]: i32, %[[arg4:.*]]: memref, %[[arg5:.*]]
+// CHECK-NEXT: %[[live:.*]] = call @clean_simple_ops(%[[arg3]], %[[arg4]])
+// CHECK-NEXT: return %[[live]]
+// CHECK-NEXT: }
+func.func private @clean_simple_ops(%arg0 : i32, %arg1 : memref, %arg2 : i32) -> (i32, i32, i32, i32) {
+ %live_0 = arith.addi %arg0, %arg0 : i32
+ %c2 = arith.constant 2 : i32
+ %live_1 = arith.muli %live_0, %c2 : i32
+ %non_live_1 = arith.addi %live_1, %live_0 : i32
+ %non_live_2 = arith.constant 7 : i32
+ %non_live_3 = arith.subi %arg0, %non_live_1 : i32
+ %c3 = arith.constant 3 : i32
+ %live_2 = arith.addi %arg0, %c3 : i32
+ memref.store %live_2, %arg1[] : memref
+ return %live_1, %non_live_1, %non_live_2, %non_live_3 : i32, i32, i32, i32
+}
+
+func.func @main(%arg3 : i32, %arg4 : memref, %arg5 : i32) -> (i32) {
+ %live, %non_live_1, %non_live_2, %non_live_3 = func.call @clean_simple_ops(%arg3, %arg4, %arg5) : (i32, memref, i32) -> (i32, i32, i32, i32)
+ return %live : i32
+}
+
+// -----
+
+// scf.while is erased: it has no memory effects and its result isn't live.
+//
+// Note that this cleanup cannot be done by the `canonicalize` pass.
+//
+// CHECK-LABEL: func.func private @clean_region_branch_op_erase_it() {
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+// CHECK: func.func @main(%[[arg3:.*]]: i32, %[[arg4:.*]]: i1) {
+// CHECK-NEXT: call @clean_region_branch_op_erase_it() : () -> ()
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+func.func private @clean_region_branch_op_erase_it(%arg0 : i32, %arg1 : i1) -> (i32) {
+ %non_live = scf.while (%arg2 = %arg0) : (i32) -> (i32) {
+ scf.condition(%arg1) %arg2 : i32
+ } do {
+ ^bb0(%arg2: i32):
+ scf.yield %arg2 : i32
+ }
+ return %non_live : i32
+}
+
+func.func @main(%arg3 : i32, %arg4 : i1) {
+ %non_live_0 = func.call @clean_region_branch_op_erase_it(%arg3, %arg4) : (i32, i1) -> (i32)
+ return
+}