diff --git a/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h @@ -0,0 +1,76 @@ +//===- LivenessAnalysis.h - Liveness analysis -----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements liveness analysis using the sparse backward data-flow +// analysis framework. Theoretically, liveness analysis assigns liveness to each +// (value, program point) pair in the program and it is thus a dense analysis. +// However, since values are immutable in MLIR, a sparse analysis, which will +// assign liveness to each value in the program, suffices here. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_DATAFLOW_LIVENESSANALYSIS_H +#define MLIR_ANALYSIS_DATAFLOW_LIVENESSANALYSIS_H + +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" +#include + +namespace mlir { +namespace dataflow { + +//===----------------------------------------------------------------------===// +// LivenessAnalysis +//===----------------------------------------------------------------------===// + +/// This lattice represents, for a given value, whether or not it is "live". A +/// value is considered "live" iff it is being written to memory using a +/// `memref.store` operation or is needed to compute a value that is written to +/// memory using a `memref.store` operation. +/// TODO(srisrivastava): Enhance the definition of "live" in this analysis to +/// make it more accurate. Currently some values will be marked "not live" which +/// are theoretically live. 
+struct Liveness : public AbstractSparseLattice { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(Liveness) + using AbstractSparseLattice::AbstractSparseLattice; + + void print(raw_ostream &os) const override; + + ChangeResult markLive(); + + ChangeResult meet(const AbstractSparseLattice &other) override; + + // At the beginning of the analysis, everything is marked "not live" and as + // the analysis progresses, values are marked "live" if they are found to be + // live. + bool isLive = false; +}; + +/// An analysis that, by going backwards along the dataflow graph, annotates +/// each value with a boolean storing true iff it is "live". +class LivenessAnalysis : public SparseBackwardDataFlowAnalysis { +public: + using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis; + + /// Flow the liveness backward starting from the `results` of the `op`. + /// `operands` here are the operands of `op`. + void backwardFlowLivenessFromResults(Operation *op, + ArrayRef operands, + ArrayRef results); + + void visitOperation(Operation *op, ArrayRef operands, + ArrayRef results) override; + + void visitBranchOperand(OpOperand &operand) override; + + void setToExitState(Liveness *lattice) override; +}; + +} // end namespace dataflow +} // end namespace mlir + +#endif // MLIR_ANALYSIS_DATAFLOW_LIVENESSANALYSIS_H diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt --- a/mlir/lib/Analysis/CMakeLists.txt +++ b/mlir/lib/Analysis/CMakeLists.txt @@ -13,6 +13,7 @@ DataFlow/DeadCodeAnalysis.cpp DataFlow/DenseAnalysis.cpp DataFlow/IntegerRangeAnalysis.cpp + DataFlow/LivenessAnalysis.cpp DataFlow/SparseAnalysis.cpp ) @@ -34,6 +35,7 @@ DataFlow/DeadCodeAnalysis.cpp DataFlow/DenseAnalysis.cpp DataFlow/IntegerRangeAnalysis.cpp + DataFlow/LivenessAnalysis.cpp DataFlow/SparseAnalysis.cpp ADDITIONAL_HEADER_DIRS diff --git a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp new file mode 100644 --- /dev/null +++ 
b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp @@ -0,0 +1,99 @@ +//===- LivenessAnalysis.cpp - Liveness analysis ---------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/DataFlow/LivenessAnalysis.h" + +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include + +using namespace mlir; +using namespace mlir::dataflow; + +void Liveness::print(raw_ostream &os) const { + os << (isLive ? "live" : "not live"); +} + +ChangeResult Liveness::markLive() { + bool wasLive = this->isLive; + this->isLive = true; + return wasLive ? ChangeResult::NoChange : ChangeResult::Change; +} + +ChangeResult Liveness::meet(const AbstractSparseLattice &other) { + const auto *otherLiveness = reinterpret_cast(&other); + return otherLiveness->isLive ? markLive() : ChangeResult::NoChange; +} + +//===----------------------------------------------------------------------===// +// LivenessAnalysis +//===----------------------------------------------------------------------===// + +void LivenessAnalysis::backwardFlowLivenessFromResults( + Operation *op, ArrayRef operands, + ArrayRef results) { + bool foundLiveResult = false; + for (const Liveness *r : results) { + if (r->isLive && !foundLiveResult) { + // By default, every result of an op depends on all its operands. Thus, if + // any result is live, each operand is live. + for (Liveness *operand : operands) + meet(operand, *r); + + // TODO(srisrivastava): Enhance this backward flow of liveness. One + // potential enhancement: If this op exists in a block which is in a + // region of a region-based control flow op, then mark the non-forwarded + // operands of that op as "live". 
+ + foundLiveResult = true; + } + addDependency(const_cast(r), op); + } + return; +} + +void LivenessAnalysis::visitOperation(Operation *op, + ArrayRef operands, + ArrayRef results) { + // TODO(srisrivastava): Enhance this base case of liveness analysis to make it + // more accurate. + if (auto store = dyn_cast(op)) { + propagateIfChanged(operands[0], operands[0]->markLive()); + return; + } + + backwardFlowLivenessFromResults(op, operands, results); + return; +} + +void LivenessAnalysis::visitBranchOperand(OpOperand &operand) { + // The lattices of the non-forwarded branch operands don't get updated like + // the forwarded branch operands or the non-branch operands. Thus they need + // to be handled separately. This is where we handle them. The liveness flows + // backward (or, in other words, the lattices get updated) in such operands by + // visiting their corresponding branch op (with all its operands). + + Operation *branchOp = operand.getOwner(); + + SmallVector operandsLiveness; + for (const Value operand : branchOp->getOperands()) { + operandsLiveness.push_back(getLatticeElement(operand)); + } + + SmallVector resultsLiveness; + for (const Value result : branchOp->getResults()) { + resultsLiveness.push_back(getLatticeElement(result)); + } + + backwardFlowLivenessFromResults(branchOp, operandsLiveness, resultsLiveness); + return; +} + +void LivenessAnalysis::setToExitState(Liveness *lattice) { + // Unsure about this but seems like there is nothing to do here, for computing + // liveness. +} diff --git a/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir b/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir @@ -0,0 +1,166 @@ +// RUN: mlir-opt -split-input-file -test-liveness-analysis %s 2>&1 | FileCheck %s + +// This is live because it is stored in memory if the `if` block executes. 
+// CHECK-LABEL: test_tag: c0:
+// CHECK-NEXT: result #0: live
+
+// This is not live because it is neither stored in memory nor used to compute
+// such a value.
+// CHECK-LABEL: test_tag: c1:
+// CHECK-NEXT: result #0: not live
+
+// This is live because it is stored in memory if the `else` block executes.
+// CHECK-LABEL: test_tag: c2:
+// CHECK-NEXT: result #0: live
+
+// These are live because they are used to decide whether the `if` block executes
+// or the `else` one, which in turn decides the value stored in memory.
+// Note that if `visitBranchOperand()` were left empty, they would have been
+// incorrectly marked as "not live".
+// CHECK-LABEL: test_tag: condition0:
+// CHECK-NEXT: operand #0: live
+// CHECK-NEXT: operand #1: live
+// CHECK-NEXT: result #0: live
+module {
+  func.func @test_simple_and_if(%arg0: memref<i32>, %arg1: memref<i32>, %arg2: i1) {
+    %c0_i32 = arith.constant {tag = "c0"} 0 : i32
+    %c1_i32 = arith.constant {tag = "c1"} 1 : i32
+    %c2_i32 = arith.constant {tag = "c2"} 2 : i32
+    %0 = arith.addi %arg2, %arg2 {tag = "condition0"} : i1
+    %1 = scf.if %0 -> (i32) {
+      scf.yield %c0_i32 : i32
+    } else {
+      scf.yield %c2_i32 : i32
+    }
+    memref.store %1, %arg0[] : memref<i32>
+    return
+  }
+}
+
+// -----
+
+// zero, ten, and one are live because they are used to decide the number of
+// times the `for` loop executes, which in turn decides the value stored in
+// memory.
+// Note that if `visitBranchOperand()` were left empty, they would have been
+// incorrectly marked as "not live".
+// CHECK-LABEL: test_tag: zero:
+// CHECK-NEXT: result #0: live
+// CHECK-LABEL: test_tag: ten:
+// CHECK-NEXT: result #0: live
+// CHECK-LABEL: test_tag: one:
+// CHECK-NEXT: result #0: live
+// CHECK-LABEL: test_tag: x:
+// CHECK-NEXT: result #0: live
+module {
+  func.func @test_for(%arg0: memref<i32>) {
+    %c0 = arith.constant {tag = "zero"} 0 : index
+    %c10 = arith.constant {tag = "ten"} 10 : index
+    %c1 = arith.constant {tag = "one"} 1 : index
+    %x = arith.constant {tag = "x"} 0 : i32
+    %0 = scf.for %arg1 = %c0 to %c10 step %c1 iter_args(%arg2 = %x) -> (i32) {
+      %1 = arith.addi %x, %x : i32
+      scf.yield %1 : i32
+    }
+    memref.store %0, %arg0[] : memref<i32>
+    return
+  }
+}
+
+// -----
+
+// This is live because it is used to decide which switch case executes, which
+// in turn decides the value stored in memory.
+// Note that if `visitBranchOperand()` were left empty, it would have been
+// incorrectly marked as "not live".
+// CHECK-LABEL: test_tag: switch:
+// CHECK-NEXT: operand #0: live
+module {
+  func.func @test_scf_switch(%arg0: index, %arg1: memref<i32>) {
+    %0 = scf.index_switch %arg0 {tag = "switch"} -> i32
+    case 1 {
+      %c10_i32 = arith.constant 10 : i32
+      scf.yield %c10_i32 : i32
+    }
+    case 2 {
+      %c20_i32 = arith.constant 20 : i32
+      scf.yield %c20_i32 : i32
+    }
+    default {
+      %c30_i32 = arith.constant 30 : i32
+      scf.yield %c30_i32 : i32
+    }
+    memref.store %0, %arg1[] : memref<i32>
+    return
+  }
+}
+
+// -----
+
+// FIXME: This branch operand should be "live". It is reported as "not live"
+// since `visitBranchOperand()` is never called for it; the cause is unknown.
+// CHECK-LABEL: test_tag: br:
+// CHECK-NEXT: operand #0: not live
+// CHECK-NEXT: operand #1: live
+// CHECK-NEXT: operand #2: live
+module {
+  func.func @test_blocks(%arg0: memref<i32>, %arg1: memref<i32>, %arg2: memref<i32>, %arg3: i1) {
+    %c0_i32 = arith.constant 0 : i32
+    cf.cond_br %arg3, ^bb1(%c0_i32 : i32), ^bb2(%c0_i32 : i32) {tag = "br"}
+  ^bb1(%0 : i32):
+    memref.store %0, %arg0[] : memref<i32>
+    cf.br ^bb3
+  ^bb2(%1 : i32):
+    memref.store %1, %arg1[] : memref<i32>
+    cf.br ^bb3
+  ^bb3:
+    return
+  }
+}
+
+// -----
+
+// FIXME: This branch operand should be "live". It is reported as "not live"
+// since `visitBranchOperand()` is never called for it; the cause is unknown.
+// CHECK-LABEL: test_tag: flag:
+// CHECK-NEXT: operand #0: not live
+module {
+  func.func @test_switch(%arg0: i32, %arg1: memref<i32>, %arg2: memref<i32>) {
+    %c0_i32 = arith.constant 0 : i32
+    cf.switch %arg0 : i32, [
+      default: ^bb1,
+      42: ^bb2
+    ] {tag = "flag"}
+  ^bb1:
+    memref.store %c0_i32, %arg1[] : memref<i32>
+    cf.br ^bb3
+  ^bb2:
+    memref.store %c0_i32, %arg2[] : memref<i32>
+    cf.br ^bb3
+  ^bb3:
+    return
+  }
+}
+
+// -----
+
+// FIXME: This branch operand should be "live". It is reported as "not live"
+// since `visitBranchOperand()` is never called for it; the cause is unknown.
+// CHECK-LABEL: test_tag: condition: +// CHECK-NEXT: operand #0: not live +// CHECK-NEXT: operand #1: live +module { + func.func @test_condition(%arg0: memref, %arg1: i32, %arg2: i1) { + %c0_i32 = arith.constant 0 : i32 + %0 = scf.while (%arg3 = %c0_i32) : (i32) -> (i32) { + memref.store %arg3, %arg0[] : memref + scf.condition(%arg2) {tag = "condition"} %arg3 : i32 + } do { + ^bb0(%arg3: i32): + memref.store %arg3, %arg0[] : memref + scf.yield %arg3 : i32 + } + memref.store %0, %arg0[] : memref + return + } +} diff --git a/mlir/test/lib/Analysis/CMakeLists.txt b/mlir/test/lib/Analysis/CMakeLists.txt --- a/mlir/test/lib/Analysis/CMakeLists.txt +++ b/mlir/test/lib/Analysis/CMakeLists.txt @@ -14,6 +14,7 @@ DataFlow/TestDeadCodeAnalysis.cpp DataFlow/TestDenseDataFlowAnalysis.cpp DataFlow/TestBackwardDataFlowAnalysis.cpp + DataFlow/TestLivenessAnalysis.cpp EXCLUDE_FROM_LIBMLIR diff --git a/mlir/test/lib/Analysis/DataFlow/TestLivenessAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestLivenessAnalysis.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Analysis/DataFlow/TestLivenessAnalysis.cpp @@ -0,0 +1,68 @@ +//===- TestLivenessAnalysis.cpp - Test liveness analysis ------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/DataFlow/LivenessAnalysis.h" + +#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" +#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; +using namespace mlir::dataflow; + +namespace { +struct TestLivenessAnalysisPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLivenessAnalysisPass) + + StringRef getArgument() const override { return "test-liveness-analysis"; } + + void runOnOperation() override { + Operation *op = getOperation(); + + SymbolTableCollection symbolTable; + + DataFlowSolver solver; + solver.load(); + solver.load(); + solver.load(symbolTable); + if (failed(solver.initializeAndRun(op))) + return signalPassFailure(); + + raw_ostream &os = llvm::outs(); + op->walk([&](Operation *op) { + auto tag = op->getAttrOfType("tag"); + if (!tag) + return; + os << "test_tag: " << tag.getValue() << ":\n"; + for (auto [index, operand] : llvm::enumerate(op->getOperands())) { + const Liveness *liveness = solver.lookupState(operand); + assert(liveness && "expected a sparse lattice"); + os << " operand #" << index << ": "; + liveness->print(os); + os << "\n"; + } + for (auto [index, operand] : llvm::enumerate(op->getResults())) { + const Liveness *liveness = solver.lookupState(operand); + assert(liveness && "expected a sparse lattice"); + os << " result #" << index << ": "; + liveness->print(os); + os << "\n"; + } + }); + } +}; +} // end anonymous namespace + +namespace mlir { +namespace test { +void registerTestLivenessAnalysisPass() { + PassRegistration(); +} +} // end namespace test +} // end namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -101,6 +101,7 @@ void 
registerTestLinalgElementwiseFusion(); void registerTestLinalgGreedyFusion(); void registerTestLinalgTransforms(); +void registerTestLivenessAnalysisPass(); void registerTestLivenessPass(); void registerTestLoopFusion(); void registerTestCFGLoopInfoPass(); @@ -218,6 +219,7 @@ mlir::test::registerTestLinalgElementwiseFusion(); mlir::test::registerTestLinalgGreedyFusion(); mlir::test::registerTestLinalgTransforms(); + mlir::test::registerTestLivenessAnalysisPass(); mlir::test::registerTestLivenessPass(); mlir::test::registerTestLoopFusion(); mlir::test::registerTestCFGLoopInfoPass();