diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -19,6 +19,7 @@ #include "mlir/Transforms/ViewOpGraph.h" #include "llvm/Support/Debug.h" #include +#include namespace mlir { @@ -105,6 +106,9 @@ createInlinerPass(llvm::StringMap opPipelines, std::function defaultPipelineBuilder); +/// Creates an optimization pass to remove non-live values. +std::unique_ptr createRemoveNonLiveValuesPass(); + /// Creates a pass which performs sparse conditional constant propagation over /// nested operations. std::unique_ptr createSCCPPass(); diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -85,6 +85,19 @@ ]; } +def RemoveNonLiveValues : Pass<"remove-non-live-values"> { + let summary = "Remove non-live values"; + let description = [{ + This pass optimizes the IR by removing non-live values. The removal of such + values involves removal of extraneous operands, extraneous arguments, + extraneous return values, and even extraneous ops, all of which + unnecessarily increase the runtime, in theory. The pass relies on the + information provided by the liveness analysis utility to accomplish its + optimization goal. 
+ }]; + let constructor = "mlir::createRemoveNonLiveValuesPass()"; +} + def PrintIRPass : Pass<"print-ir"> { let summary = "Print IR on the debug stream"; let description = [{ diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt --- a/mlir/lib/Transforms/CMakeLists.txt +++ b/mlir/lib/Transforms/CMakeLists.txt @@ -11,6 +11,7 @@ Mem2Reg.cpp OpStats.cpp PrintIR.cpp + RemoveNonLiveValues.cpp SCCP.cpp SROA.cpp StripDebugInfo.cpp diff --git a/mlir/lib/Transforms/RemoveNonLiveValues.cpp b/mlir/lib/Transforms/RemoveNonLiveValues.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Transforms/RemoveNonLiveValues.cpp @@ -0,0 +1,603 @@ +//===- RemoveNonLiveValues.cpp - Remove Non-Live Values -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass optimizes the IR by removing non-live values. The removal of such +// values involves removal of extraneous operands, extraneous arguments, +// extraneous return values, and even extraneous ops, all of which unnecessarily +// increase the runtime, in theory. The pass relies on the information provided +// by the liveness analysis utility to accomplish its optimization goal. 
+// +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace mlir { +#define GEN_PASS_DEF_REMOVENONLIVEVALUES +#include "mlir/Transforms/Passes.h.inc" +} // namespace mlir + +using namespace mlir; +using namespace mlir::dataflow; + +//===----------------------------------------------------------------------===// +// RemoveNonLiveValues Pass +//===----------------------------------------------------------------------===// + +namespace { + +// Some helper functions... + +/// Return true iff at least one value in `values` is live, given the liveness +/// information in `la`. +static bool hasLive(const SmallVector &values, + RunLivenessAnalysis &la) { + for (Value value : values) { + // If there is a null value, it implies that it was dropped during the + // execution of this pass, implying that it was non-live. + if (!value) + continue; + + const Liveness *liveness = la.getLiveness(value); + if (!liveness || liveness->isLive) + return true; + } + return false; +} + +/// Return a BitVector of size `values.size()` where its i-th bit is 1 iff the +/// i-th value in `values` is live, given the liveness information in `la`. +static BitVector markLives(const SmallVector &values, + RunLivenessAnalysis &la) { + BitVector lives(values.size(), true); + + for (auto it : llvm::enumerate(values)) { + Value value = it.value(); + size_t index = it.index(); + + if (!value) { + lives.reset(index); + continue; + } + + const Liveness *liveness = la.getLiveness(value); + // It is important to note that when `liveness` is null, we can't tell if + // `value` is live or not. So, the safe option is to consider it live. 
Also, + // the execution of this pass might create new SSA values when erasing some + // of the results of an op and we know that these new values are live + // (because they weren't erased) and also their liveness is null because + // liveness analysis ran before their creation. + if (liveness && !liveness->isLive) + lives.reset(index); + } + + return lives; +} + +/// Drop the uses of the i-th result of `op` and then erase it iff toErase[i] +/// is 1. +/// +/// Results cannot be removed from an op in place, so this builds a replacement +/// op right after `op` with the same name, operands, and attributes but only +/// the surviving result types, hands the region bodies of `op` over to it, +/// redirects the uses of every surviving result to it, and finally erases `op`. +static void dropUsesAndEraseResults(Operation *op, BitVector toErase) { + assert(op->getNumResults() == toErase.size() && + "expected the number of results in `op` and the size of `toErase` to " + "be the same"); + + // Collect the types of the results that survive. + std::vector newResultTypes; + for (OpResult result : op->getResults()) { + if (!toErase[result.getResultNumber()]) { + newResultTypes.push_back(result.getType()); + } + } + OpBuilder builder(op); + builder.setInsertionPointAfter(op); + OperationState state(op->getLoc(), op->getName().getStringRef(), + op->getOperands(), newResultTypes, op->getAttrs()); + for (unsigned i = 0; i < op->getNumRegions(); ++i) { + state.addRegion(); + } + Operation *newOp = builder.create(state); + // Move (rather than deep-clone) each region body into the replacement op: + // `op` is erased at the end of this function, so a clone would only + // duplicate IR that is about to be destroyed. + for (const auto &indexed_regions : llvm::enumerate(op->getRegions())) { + newOp->getRegion(indexed_regions.index()) + .takeBody(indexed_regions.value()); + } + + // Kept results are renumbered densely in `newOp`, so track the next free + // slot while walking the original results in order. + unsigned indexOfNextNewCallOpResultToReplace = 0; + for (auto it : llvm::enumerate(op->getResults())) { + Value result = it.value(); + size_t index = it.index(); + + assert(result && "expected result to be non-null"); + if (toErase[index]) { + result.dropAllUses(); + } else { + result.replaceAllUsesWith( + newOp->getResult(indexOfNextNewCallOpResultToReplace++)); + } + } + op->erase(); +} + +/// Convert a list of `Operand`s to a list of `OpOperand`s. This function is +/// borrowed from the Analysis/DataFlow/SparseAnalysis.cpp file. 
+static MutableArrayRef operandsToOpOperands(OperandRange &operands) { + return MutableArrayRef(operands.getBase(), operands.size()); +} + +/// Clean a simple op `op`, given the liveness analysis information in `la`. +/// Here, cleaning means: +/// (1) Dropping all its uses, AND +/// (2) Erasing it +/// iff it has no memory effects and none of its results are live. +/// +/// It is assumed that `op` is simple. Here, a simple op is one which isn't a +/// symbol op, a symbol-user op, a region branch op, a branch op, a region +/// branch terminator op, or return-like. +static void cleanSimpleOp(Operation *op, RunLivenessAnalysis &la) { + if (!isMemoryEffectFree(op) || hasLive(op->getResults(), la)) + return; + + op->dropAllUses(); + op->erase(); +} + +/// Clean a function-like op `funcOp`, given the liveness information in `la` +/// and the IR in `irOp`. Here, cleaning means: +/// (1) Dropping the uses of its unnecessary (non-live) arguments, +/// (2) Erasing these arguments, +/// (3) Erasing their corresponding operands from its callers, +/// (4) Erasing its unnecessary terminator operands (return values that are +/// non-live across all callers), +/// (5) Dropping the uses of these return values from its callers, AND +/// (6) Erasing these return values +/// iff it is neither public nor external (a declaration without a body). +static void cleanFuncOp(FunctionOpInterface funcOp, Operation *irOp, + RunLivenessAnalysis &la) { + // The signature of a public function must stay as it is, and an external + // function has no body, so there is no terminator to read the return values + // from; both are left untouched. + if (funcOp.isPublic() || funcOp.isExternal()) + return; + + // Get the list of unnecessary (non-live) arguments in `nonLiveArgs`. + SmallVector arguments(funcOp.getArguments()); + BitVector nonLiveArgs = markLives(arguments, la); + nonLiveArgs = nonLiveArgs.flip(); + + // Do (1). + for (auto it : llvm::enumerate(arguments)) { + Value arg = it.value(); + if (arg && nonLiveArgs[it.index()]) + arg.dropAllUses(); + } + + // Do (2). + funcOp.eraseArguments(nonLiveArgs); + + // Do (3). 
+ // NOTE(review): the result of `getSymbolUses` is dereferenced unchecked -- + // confirm it is always available for the IR accepted by this pass. + SymbolTable::UseRange uses = *funcOp.getSymbolUses(irOp); + for (SymbolTable::SymbolUse use : uses) { + Operation *callOp = use.getUser(); + assert(isa(callOp) && "expected a call-like user"); + callOp->eraseOperands(nonLiveArgs); + } + + // Get the list of unnecessary terminator operands (return values that are + // non-live across all callers) in `nonLiveRets`. There is a very important + // subtlety here. Unnecessary terminator operands are NOT the operands of the + // terminator that are non-live. Instead, these are the return values of the + // callers such that a given return value is non-live across all callers. Such + // corresponding operands in the terminator could be live. An example to + // demonstrate this: + // func.func private @f(%arg0: memref) -> (i32, i32) { + // %c0_i32 = arith.constant 0 : i32 + // %0 = arith.addi %c0_i32, %c0_i32 : i32 + // memref.store %0, %arg0[] : memref + // return %c0_i32, %0 : i32, i32 + // } + // func.func @main(%arg0: i32, %arg1: memref) -> (i32) { + // %1:2 = call @f(%arg1) : (memref) -> (i32, i32) + // return %1#0 : i32 + // } + // Here, we can see that %1#1 is never used. It is non-live. Thus, @f doesn't + // need to return %0. But, %0 is live. And, still, we want to stop it from + // being returned, in order to optimize our IR. So, this demonstrates how we + // can make our optimization strong by even removing a live return value (%0), + // since it forwards only to non-live value(s) (%1#1). + Operation *lastReturnOp = funcOp.back().getTerminator(); + size_t numReturns = lastReturnOp->getNumOperands(); + BitVector nonLiveRets(numReturns, true); + for (SymbolTable::SymbolUse use : uses) { + Operation *callOp = use.getUser(); + assert(isa(callOp) && "expected a call-like user"); + BitVector liveCallRets = markLives(callOp->getResults(), la); + nonLiveRets &= liveCallRets.flip(); + } + + // Do (4). 
+ // Note that in the absence of control flow ops forcing the control to go from + // the entry (first) block to the other blocks, the control never reaches any + // block other than the entry block, because every block has a terminator. + for (Block &block : funcOp.getBlocks()) { + Operation *returnOp = block.getTerminator(); + if (returnOp && returnOp->getNumOperands() == numReturns) + returnOp->eraseOperands(nonLiveRets); + } + funcOp.eraseResults(nonLiveRets); + + // Do (5) and (6). + for (SymbolTable::SymbolUse use : uses) { + Operation *callOp = use.getUser(); + assert(isa(callOp) && "expected a call-like user"); + dropUsesAndEraseResults(callOp, nonLiveRets); + } +} + +/// Clean a region branch op `regionBranchOp`, given the liveness information in +/// `la`. Here, cleaning means: +/// (1') Dropping all its uses, AND +/// (2') Erasing it +/// if it has no memory effects and none of its results are live, AND +/// (1) Erasing its unnecessary operands (operands that are forwarded to +/// unnecessary results and arguments), +/// (2) Cleaning each of its regions, +/// (3) Dropping the uses of its unnecessary results (results that are +/// forwarded from unnecessary operands and terminator operands), AND +/// (4) Erasing these results +/// otherwise. +/// Note that here, cleaning a region means: +/// (2.a) Dropping the uses of its unnecessary arguments (arguments that are +/// forwarded from unnecessary operands and terminator operands), +/// (2.b) Erasing these arguments, AND +/// (2.c) Erasing its unnecessary terminator operands (terminator operands +/// that are forwarded to unnecessary results and arguments). +/// It is important to note that values in this op flow from operands and +/// terminator operands (successor operands) to arguments and results (successor +/// inputs). +static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp, + RunLivenessAnalysis &la) { + // Mark live results of `regionBranchOp` in `liveResults`. 
+ auto markLiveResults = [&](BitVector &liveResults) { + liveResults = markLives(regionBranchOp->getResults(), la); + }; + + // Mark live arguments in the regions of `regionBranchOp` in `liveArgs`. + // Empty regions (such as the omitted `else` region of an `scf.if`) have no + // entry block and are skipped. + auto markLiveArgs = [&](DenseMap &liveArgs) { + for (Region &region : regionBranchOp->getRegions()) { + if (region.empty()) + continue; + SmallVector arguments(region.front().getArguments()); + BitVector regionLiveArgs = markLives(arguments, la); + liveArgs[&region] = regionLiveArgs; + } + }; + + // Return the successors of `region` if the latter is not null. Else return + // the successors of `regionBranchOp`. + auto getSuccessors = [&](Region *region = nullptr) { + std::optional index = + region ? std::optional(region->getRegionNumber()) : std::nullopt; + SmallVector operandAttributes(regionBranchOp->getNumOperands(), + nullptr); + SmallVector successors; + regionBranchOp.getSuccessorRegions(index, operandAttributes, successors); + return successors; + }; + + // Return the operands of `terminator` that are forwarded to `successor` if + // the former is not null. Else return the operands of `regionBranchOp` + // forwarded to `successor`. + auto getForwardedOpOperands = [&](const RegionSuccessor &successor, + Operation *terminator = nullptr) { + Region *successorRegion = successor.getSuccessor(); + std::optional index = + successorRegion ? std::optional(successorRegion->getRegionNumber()) + : std::nullopt; + OperandRange operands = + terminator ? *getRegionBranchSuccessorOperands(terminator, index) + : regionBranchOp.getSuccessorEntryOperands(index); + MutableArrayRef opOperands = operandsToOpOperands(operands); + return opOperands; + }; + + // Mark the non-forwarded operands of `regionBranchOp` in + // `nonForwardedOperands`. 
+ auto markNonForwardedOperands = [&](BitVector &nonForwardedOperands) { + nonForwardedOperands.resize(regionBranchOp->getNumOperands(), true); + for (const RegionSuccessor &successor : getSuccessors()) { + for (OpOperand &opOperand : getForwardedOpOperands(successor)) + nonForwardedOperands.reset(opOperand.getOperandNumber()); + } + }; + + // Mark the non-forwarded terminator operands of the various regions of + // `regionBranchOp` in `nonForwardedRets`. The map is keyed by the terminator + // of each non-empty region's entry block. + auto markNonForwardedReturnValues = + [&](DenseMap &nonForwardedRets) { + for (Region &region : regionBranchOp->getRegions()) { + if (region.empty()) + continue; + Operation *terminator = region.front().getTerminator(); + nonForwardedRets[terminator] = + BitVector(terminator->getNumOperands(), true); + for (const RegionSuccessor &successor : getSuccessors(&region)) { + for (OpOperand &opOperand : + getForwardedOpOperands(successor, terminator)) + nonForwardedRets[terminator].reset(opOperand.getOperandNumber()); + } + } + }; + + // Update `valuesToKeep` (which is expected to correspond to operands or + // terminator operands) based on `resultsToKeep` and `argsToKeep`, given + // `region`. When `valuesToKeep` correspond to operands, `region` is null. + // Else, `region` is the parent region of the terminator. + auto updateOperandsOrTerminatorOperandsToKeep = + [&](BitVector &valuesToKeep, BitVector &resultsToKeep, + DenseMap &argsToKeep, Region *region = nullptr) { + Operation *terminator = + region ? region->front().getTerminator() : nullptr; + + for (const RegionSuccessor &successor : getSuccessors(region)) { + Region *successorRegion = successor.getSuccessor(); + for (auto [opOperand, input] : + llvm::zip(getForwardedOpOperands(successor, terminator), + successor.getSuccessorInputs())) { + size_t operandNum = opOperand.getOperandNumber(); + bool updateBasedOn = + successorRegion + ? 
argsToKeep[successorRegion] + [cast(input).getArgNumber()] + : resultsToKeep[cast(input).getResultNumber()]; + valuesToKeep[operandNum] = valuesToKeep[operandNum] | updateBasedOn; + } + } + }; + + // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep` and + // `terminatorOperandsToKeep`. Store true in `resultsOrArgsToKeepChanged` if a + // value is modified, else, false. + auto recomputeResultsAndArgsToKeep = + [&](BitVector &resultsToKeep, DenseMap &argsToKeep, + BitVector &operandsToKeep, + DenseMap &terminatorOperandsToKeep, + bool &resultsOrArgsToKeepChanged) { + resultsOrArgsToKeepChanged = false; + + // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep`. + for (const RegionSuccessor &successor : getSuccessors()) { + Region *successorRegion = successor.getSuccessor(); + for (auto [opOperand, input] : + llvm::zip(getForwardedOpOperands(successor), + successor.getSuccessorInputs())) { + bool recomputeBasedOn = + operandsToKeep[opOperand.getOperandNumber()]; + bool toRecompute = + successorRegion + ? argsToKeep[successorRegion] + [cast(input).getArgNumber()] + : resultsToKeep[cast(input).getResultNumber()]; + if (!toRecompute && recomputeBasedOn) + resultsOrArgsToKeepChanged = true; + if (successorRegion) { + argsToKeep[successorRegion][cast(input) + .getArgNumber()] = + argsToKeep[successorRegion] + [cast(input).getArgNumber()] | + recomputeBasedOn; + } else { + resultsToKeep[cast(input).getResultNumber()] = + resultsToKeep[cast(input).getResultNumber()] | + recomputeBasedOn; + } + } + } + + // Recompute `resultsToKeep` and `argsToKeep` based on + // `terminatorOperandsToKeep`. 
+ for (Region &region : regionBranchOp->getRegions()) { + if (region.empty()) + continue; + Operation *terminator = region.front().getTerminator(); + for (const RegionSuccessor &successor : getSuccessors(&region)) { + Region *successorRegion = successor.getSuccessor(); + for (auto [opOperand, input] : + llvm::zip(getForwardedOpOperands(successor, terminator), + successor.getSuccessorInputs())) { + // Look the terminator up by the same key (the entry block's + // terminator) under which its bits were recorded. + bool recomputeBasedOn = + terminatorOperandsToKeep[terminator] + [opOperand.getOperandNumber()]; + bool toRecompute = + successorRegion + ? argsToKeep[successorRegion] + [cast(input).getArgNumber()] + : resultsToKeep[cast(input).getResultNumber()]; + if (!toRecompute && recomputeBasedOn) + resultsOrArgsToKeepChanged = true; + if (successorRegion) { + argsToKeep[successorRegion][cast(input) + .getArgNumber()] = + argsToKeep[successorRegion] + [cast(input).getArgNumber()] | + recomputeBasedOn; + } else { + resultsToKeep[cast(input).getResultNumber()] = + resultsToKeep[cast(input).getResultNumber()] | + recomputeBasedOn; + } + } + } + } + }; + + // Mark the values that we want to keep in `resultsToKeep`, `argsToKeep`, + // `operandsToKeep`, and `terminatorOperandsToKeep`. + auto markValuesToKeep = + [&](BitVector &resultsToKeep, DenseMap &argsToKeep, + BitVector &operandsToKeep, + DenseMap &terminatorOperandsToKeep) { + bool resultsOrArgsToKeepChanged = true; + // We keep updating and recomputing the values until we reach a point + // where they stop changing. + while (resultsOrArgsToKeepChanged) { + // Update the operands that need to be kept. + updateOperandsOrTerminatorOperandsToKeep(operandsToKeep, + resultsToKeep, argsToKeep); + + // Update the terminator operands that need to be kept. + for (Region &region : regionBranchOp->getRegions()) { + if (region.empty()) + continue; + updateOperandsOrTerminatorOperandsToKeep( + terminatorOperandsToKeep[region.front().getTerminator()], + resultsToKeep, argsToKeep, &region); + } + + // Recompute the results and arguments that need to be kept. 
+ recomputeResultsAndArgsToKeep( + resultsToKeep, argsToKeep, operandsToKeep, + terminatorOperandsToKeep, resultsOrArgsToKeepChanged); + } + }; + + // Do (1') and (2'). This is the only case where the entire `regionBranchOp` + // is removed. It will not happen in any other scenario. Note that in this + // case, a non-forwarded operand of `regionBranchOp` could be live/non-live. + // It could never be live because of this op but its liveness could have been + // attributed to something else. + if (isMemoryEffectFree(regionBranchOp.getOperation()) && + !hasLive(regionBranchOp->getResults(), la)) { + regionBranchOp->dropAllUses(); + regionBranchOp->erase(); + return; + } + + // At this point, we know that every non-forwarded operand of `regionBranchOp` + // is live. + + // Stores the results of `regionBranchOp` that we want to keep. + BitVector resultsToKeep; + // Stores the mapping from regions of `regionBranchOp` to their arguments that + // we want to keep. + DenseMap argsToKeep; + // Stores the operands of `regionBranchOp` that we want to keep. + BitVector operandsToKeep; + // Stores the mapping from region terminators in `regionBranchOp` to their + // operands that we want to keep. + DenseMap terminatorOperandsToKeep; + + // Initializing the above variables... + + // The live results of `regionBranchOp` definitely need to be kept. + markLiveResults(resultsToKeep); + // Similarly, the live arguments of the regions in `regionBranchOp` definitely + // need to be kept. + markLiveArgs(argsToKeep); + // The non-forwarded operands of `regionBranchOp` definitely need to be kept. + // A live forwarded operand can be removed but no non-forwarded operand can be + // removed since it "controls" the flow of data in this control flow op. + markNonForwardedOperands(operandsToKeep); + // Similarly, the non-forwarded terminator operands of the regions in + // `regionBranchOp` definitely need to be kept. 
+ markNonForwardedReturnValues(terminatorOperandsToKeep); + + // Mark the values (results, arguments, operands, and terminator operands) + // that we want to keep. + markValuesToKeep(resultsToKeep, argsToKeep, operandsToKeep, + terminatorOperandsToKeep); + + // Do (1). + regionBranchOp->eraseOperands(operandsToKeep.flip()); + + // Do (2.a) and (2.b). + for (Region &region : regionBranchOp->getRegions()) { + // An empty region has no entry block, hence no arguments to clean. + if (region.empty()) + continue; + for (auto [index, arg] : llvm::enumerate(region.front().getArguments())) { + if (argsToKeep[&region][index]) + continue; + if (arg) + arg.dropAllUses(); + } + region.front().eraseArguments(argsToKeep[&region].flip()); + } + + // Do (2.c). + for (Region &region : regionBranchOp->getRegions()) { + // An empty region has no terminator to clean. + if (region.empty()) + continue; + Operation *terminator = region.front().getTerminator(); + terminator->eraseOperands(terminatorOperandsToKeep[terminator].flip()); + } + + // Do (3) and (4). + dropUsesAndEraseResults(regionBranchOp.getOperation(), resultsToKeep.flip()); +} + +struct RemoveNonLiveValues : public impl::RemoveNonLiveValuesBase { + void runOnOperation() override; +}; +} // namespace + +void RemoveNonLiveValues::runOnOperation() { + auto &la = getAnalysis(); + Operation *irOp = getOperation(); + + // The removal of non-live values is performed iff there are no branch ops, + // all symbol ops present in the IR are function-like, and all symbol user ops + // present in the IR are call-like. + WalkResult acceptableIR = + irOp->walk([&](Operation *op) { + if (isa(op) || + (isa(op) && !isa(op)) || + (isa(op) && !isa(op))) { + op->emitWarning() + << "Unacceptable IR encountered for the optimization pass " + "`remove-non-live-values`. 
Pass won't execute.\n"; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + + if (acceptableIR.wasInterrupted()) + return; + + irOp->walk([&](Operation *op) { + if (auto funcOp = dyn_cast(op)) { + cleanFuncOp(funcOp, irOp, la); + } else if (auto regionBranchOp = dyn_cast(op)) { + cleanRegionBranchOp(regionBranchOp, la); + } else if (op->hasTrait()) { + // Nothing to do because this terminator is associated with either a + // function op or a region branch op and gets cleaned when these ops are + // cleaned. + } else if (isa(op)) { + // Nothing to do because this terminator is associated with a region + // branch op and gets cleaned when the latter is cleaned. + } else if (isa(op)) { + // Nothing to do because this op is associated with a function op and gets + // cleaned when the latter is cleaned. + } else { + // Every other op is treated as a simple op: `cleanSimpleOp` erases it + // when it has no memory effects and none of its results are live. + cleanSimpleOp(op, la); + } + }); +} + +std::unique_ptr mlir::createRemoveNonLiveValuesPass() { + return std::make_unique(); +} diff --git a/mlir/test/Transforms/remove-non-live-values.mlir b/mlir/test/Transforms/remove-non-live-values.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Transforms/remove-non-live-values.mlir @@ -0,0 +1,294 @@ +// RUN: mlir-opt %s -remove-non-live-values -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: module @dont_touch_unacceptable_ir { +// CHECK-LABEL: func.func @has_cleanable_simple_op(%arg0: i32) { +// CHECK-NEXT: %0 = arith.addi %arg0, %arg0 : i32 +// CHECK-NEXT: return +// CHECK-NEXT: } +// CHECK-NEXT: } +// +// The IR remains untouched because of the presence of a non-function-like +// symbol op (module @dont_touch_unacceptable_ir). +// expected-warning @+1 {{Unacceptable IR encountered for the optimization pass `remove-non-live-values`. 
Pass won't execute.}} +module @dont_touch_unacceptable_ir { + func.func @has_cleanable_simple_op(%arg0 : i32) { + %0 = arith.addi %arg0, %arg0 : i32 + return + } +} + +// ----- + +// CHECK-LABEL: func.func @dont_touch_unacceptable_ir_has_cleanable_simple_op_with_branch_op(%arg0: i1) { +// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32 +// CHECK-NEXT: cf.cond_br %arg0, ^bb1(%c0_i32 : i32), ^bb2(%c0_i32 : i32) +// CHECK-NEXT: ^bb1(%0: i32): +// CHECK-NEXT: cf.br ^bb3 +// CHECK-NEXT: ^bb2(%1: i32): +// CHECK-NEXT: cf.br ^bb3 +// CHECK-NEXT: ^bb3: +// CHECK-NEXT: return +// CHECK-NEXT: } +// +// The IR remains untouched because of the presence of a branch op `cf.cond_br`. +func.func @dont_touch_unacceptable_ir_has_cleanable_simple_op_with_branch_op(%arg0: i1) { + %c0_i32 = arith.constant 0 : i32 + // expected-warning @+1 {{Unacceptable IR encountered for the optimization pass `remove-non-live-values`. Pass won't execute.}} + cf.cond_br %arg0, ^bb1(%c0_i32 : i32), ^bb2(%c0_i32 : i32) +^bb1(%0 : i32): + cf.br ^bb3 +^bb2(%1 : i32): + cf.br ^bb3 +^bb3: + return +} + +// ----- + +// CHECK-LABEL: func.func @clean_simple_op(%arg0: i32) { +// CHECK-NEXT: return +// CHECK-NEXT: } +// +// arith.addi (a simple op) gets removed but %arg0 doesn't get removed from the +// function because the signature of a public function is always left untouched. 
+func.func @clean_simple_op(%arg0 : i32) { + %0 = arith.addi %arg0, %arg0 : i32 + return +} + +// ----- + +// CHECK-LABEL: func.func private @f() { +// CHECK-NEXT: return +// CHECK-NEXT: } +// CHECK-LABEL: func.func @clean_func_op_remove_argument_and_return_value(%arg0: i32) { +// CHECK-NEXT: call @f() : () -> () +// CHECK-NEXT: return +// CHECK-NEXT: } +func.func private @f(%arg0: i32) -> (i32) { + return %arg0 : i32 +} +func.func @clean_func_op_remove_argument_and_return_value(%arg0 : i32) { + %0 = func.call @f(%arg0) : (i32) -> (i32) + return +} + +// ----- + +// CHECK-LABEL: func.func private @f() -> i32 { +// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32 +// CHECK-NEXT: return %c0_i32 : i32 +// CHECK-NEXT: } +// CHECK-LABEL: func.func @clean_func_op_remove_arguments(%arg0: memref, %arg1: i32) -> (i32, memref) { +// CHECK-NEXT: %0 = call @f() : () -> i32 +// CHECK-NEXT: return %0, %arg0 : i32, memref +// CHECK-NEXT: } +// +// %arg0 is not live in @f because it is never used. %arg1 is not live in @f +// because its user `arith.addi` doesn't have any uses and the %0#1 value that +// it is forwarded to also doesn't have any uses. 
+func.func private @f(%arg0 : memref, %arg1 : i32) -> (i32, i32) { + %c0_i32 = arith.constant 0 : i32 + %0 = arith.addi %arg1, %arg1 : i32 + return %c0_i32, %arg1 : i32, i32 +} +func.func @clean_func_op_remove_arguments(%arg0 : memref, %arg1 : i32) -> (i32, memref) { + %0:2 = func.call @f(%arg0, %arg1) : (memref, i32) -> (i32, i32) + return %0#0, %arg0 : i32, memref +} + +// ----- + +// CHECK-LABEL: func.func private @f(%arg0: memref) -> i32 { +// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32 +// CHECK-NEXT: memref.store %c0_i32, %arg0[] : memref +// CHECK-NEXT: return %c0_i32 : i32 +// CHECK-NEXT: } +// CHECK-LABEL: func.func @clean_func_op_remove_return_values(%arg0: memref) -> i32 { +// CHECK-NEXT: %0 = call @f(%arg0) : (memref) -> i32 +// CHECK-NEXT: %1 = call @f(%arg0) : (memref) -> i32 +// CHECK-NEXT: return %0 : i32 +// CHECK-NEXT: } +// +// Even though %4#0 was not live, the first return value of @f isn't removed +// because %0#0 was live (liveness is checked across all callers). +// +// Also, the second return value of @f is removed despite %c0_i32 being live +// because neither %0#1 nor %4#1 were live (removal doesn't depend on the +// liveness of the operand itself but on the liveness of where it is forwarded). 
+func.func private @f(%arg0 : memref) -> (i32, i32) { + %c0_i32 = arith.constant 0 : i32 + memref.store %c0_i32, %arg0[] : memref + return %c0_i32, %c0_i32 : i32, i32 +} +func.func @clean_func_op_remove_return_values(%arg0 : memref) -> (i32) { + %0:2 = func.call @f(%arg0) : (memref) -> (i32, i32) + %1 = arith.addi %0#0, %0#1 : i32 + %2 = arith.addi %1, %0#0 : i32 + %3 = arith.muli %2, %0#1 : i32 + %4:2 = func.call @f(%arg0) : (memref) -> (i32, i32) + return %0#0 : i32 +} + +// ----- + +// CHECK-LABEL: func.func private @f() -> (i32, i32) { +// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32 +// CHECK-NEXT: return %c0_i32, %c0_i32 : i32, i32 +// CHECK-NEXT: } +// CHECK-LABEL: func.func @clean_func_op_dont_remove_return_values() -> (i32, i32) { +// CHECK-NEXT: %0:2 = call @f() : () -> (i32, i32) +// CHECK-NEXT: %1:2 = call @f() : () -> (i32, i32) +// CHECK-NEXT: return %0#0, %1#1 : i32, i32 +// CHECK-NEXT: } +// +// None of the return values of @f can be removed because the first one is +// forwarded to a live value %0#0 and the second one is forwarded to a live +// value %1#1. +func.func private @f() -> (i32, i32) { + %c0_i32 = arith.constant 0 : i32 + return %c0_i32, %c0_i32 : i32, i32 +} +func.func @clean_func_op_dont_remove_return_values() -> (i32, i32) { + %0:2 = func.call @f() : () -> (i32, i32) + %1:2 = func.call @f() : () -> (i32, i32) + return %0#0, %1#1 : i32, i32 +} + +// ----- + +// CHECK-LABEL: func.func @clean_region_branch_op_erase_it(%arg0: i32, %arg1: i1) { +// CHECK-NEXT: return +// CHECK-NEXT: } +// +// The scf.while op has no memory effects and none of its results are live. 
+func.func @clean_region_branch_op_erase_it(%arg0 : i32, %arg1 : i1) { + %0 = scf.while (%arg2 = %arg0) : (i32) -> (i32) { + %1 = arith.muli %arg2, %arg2 : i32 + scf.condition(%arg1) %arg2 : i32 + } do { + ^bb0(%arg2: i32): + %2 = arith.muli %arg2, %arg2 : i32 + scf.yield %2 : i32 + } + %3 = arith.addi %0, %0 : i32 + %4 = arith.muli %0, %3 : i32 + return +} + +// ----- + +// CHECK-LABEL: func.func @clean_region_branch_op_keep_results_and_second_region_arg_and_remove_first_region_arg(%arg0: i1) -> i32 { +// CHECK-NEXT: %c1_i32 = arith.constant 1 : i32 +// CHECK-NEXT: %0:2 = scf.while (%arg1 = %c1_i32) : (i32) -> (i32, i32) { +// CHECK-NEXT: scf.condition(%arg0) %arg1, %arg1 : i32, i32 +// CHECK-NEXT: } do { +// CHECK-NEXT: ^bb0(%arg1: i32, %arg2: i32): +// CHECK-NEXT: scf.yield %arg2 : i32 +// CHECK-NEXT: } +// CHECK-NEXT: return %0#0 : i32 +// CHECK-NEXT: } +// +// Values kept: +// (1) %0#1 is not live. Yet, it is kept because %arg2 (the second instance in +// `scf.condition`) forwards to it and this instance of %arg2 has to be kept. +// The second instance of %arg2 in `scf.condition` has to be kept because it +// forwards to %arg4 which is live. +// +// (2) %arg3 is not live. Yet, it is kept because %arg2 (the first instance) +// forwards to it and this instance of %arg2 has to be kept. +// The first instance of %arg2 in `scf.condition` has to be kept because it +// forwards to %0#0 which is live. +// +// Values not kept: +// (1) %c0_i32 is not kept as an operand of `scf.while` because it only +// forwards to %arg1, which is not kept. %arg1 is not kept because only %c0_i32 +// and %arg4 forward to it, neither of them forward anywhere else and %arg1 is +// not live. Thus, %arg4 is also not kept (the first instance) in the +// `scf.yield` op. 
+func.func @clean_region_branch_op_keep_results_and_second_region_arg_and_remove_first_region_arg(%arg0: i1) -> (i32) { + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %0:2 = scf.while (%arg1 = %c0_i32, %arg2 = %c1_i32) : (i32, i32) -> (i32, i32) { + %1 = arith.addi %arg1, %arg2 : i32 + scf.condition(%arg0) %arg2, %arg2 : i32, i32 + } do { + ^bb0(%arg3: i32, %arg4: i32): + %2 = arith.addi %arg3, %arg3 : i32 + scf.yield %arg4, %arg4 : i32, i32 + } + return %0#0 : i32 +} + +// ----- + +// CHECK-LABEL: func.func @clean_region_branch_op_remove_results_and_second_region_arg_and_keep_first_region_arg(%arg0: i1) -> i32 { +// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32 +// CHECK-NEXT: %c1_i32 = arith.constant 1 : i32 +// CHECK-NEXT: %0:2 = scf.while (%arg1 = %c0_i32, %arg2 = %c1_i32) : (i32, i32) -> (i32, i32) { +// CHECK-NEXT: scf.condition(%arg0) %arg2, %arg1 : i32, i32 +// CHECK-NEXT: } do { +// CHECK-NEXT: ^bb0(%arg1: i32, %arg2: i32): +// CHECK-NEXT: scf.yield %arg1, %arg2 : i32, i32 +// CHECK-NEXT: } +// CHECK-NEXT: return %0#0 : i32 +// CHECK-NEXT: } +// +// Values kept: +// (1) %0#0 is kept because it is live. +// +// (2) %0#1 is not live. Yet, it is kept because %arg1 (the first instance in +// `scf.condition`) forwards to it and this instance of %arg1 has to be kept. +// The first instance of %arg1 in `scf.condition` has to be kept because it +// forwards to %arg4, which forwards to %arg2, which forwards to %0#0, which is +// live. +// +// Values not kept: +// (1) %0#2 is not kept because the second instance of %arg1 in `scf.condition` +// forwards to it, which forwards to only %0#2 and %arg5, where both these are +// not live and have no other value forwarding to them. +// +// (2) %0#3 is not kept because the third instance of %arg1 in `scf.condition` +// forwards to it, which forwards to only %0#3 and %arg6, where both these are +// not live and have no other value forwarding to them. 
+func.func @clean_region_branch_op_remove_results_and_second_region_arg_and_keep_first_region_arg(%arg0: i1) -> (i32) { + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %0:4 = scf.while (%arg1 = %c0_i32, %arg2 = %c1_i32) : (i32, i32) -> (i32, i32, i32, i32) { + scf.condition(%arg0) %arg2, %arg1, %arg1, %arg1 : i32, i32, i32, i32 + } do { + ^bb0(%arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32): + scf.yield %arg3, %arg4 : i32, i32 + } + return %0#0 : i32 +} + +// ----- + +// CHECK-LABEL: func.func @clean_region_branch_op_remove_result(%arg0: index, %arg1: memref) { +// CHECK-NEXT: scf.index_switch %arg0 +// CHECK-NEXT: case 1 { +// CHECK-NEXT: %c10_i32 = arith.constant 10 : i32 +// CHECK-NEXT: memref.store %c10_i32, %arg1[] : memref +// CHECK-NEXT: scf.yield +// CHECK-NEXT: } +// CHECK-NEXT: default { +// CHECK-NEXT: } +// CHECK-NEXT: return +// CHECK-NEXT: } +// +// The op isn't erased because it has memory effects but its unnecessary result +// is removed. +func.func @clean_region_branch_op_remove_result(%cond : index, %arg0 : memref) { + %1 = scf.index_switch %cond -> i32 + case 1 { + %c10_i32 = arith.constant 10 : i32 + memref.store %c10_i32, %arg0[] : memref + scf.yield %c10_i32 : i32 + } + default { + %c11_i32 = arith.constant 11 : i32 + scf.yield %c11_i32 : i32 + } + return +}