diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -16,6 +16,7 @@
 #define MLIR_ANALYSIS_DATAFLOW_SPARSEANALYSIS_H
 
 #include "mlir/Analysis/DataFlowFramework.h"
+#include "mlir/IR/SymbolTable.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "llvm/ADT/SmallPtrSet.h"
 
@@ -39,7 +40,15 @@
 
   /// Join the information contained in 'rhs' into this lattice. Returns
   /// if the value of the lattice changed.
-  virtual ChangeResult join(const AbstractSparseLattice &rhs) = 0;
+  virtual ChangeResult join(const AbstractSparseLattice &rhs) {
+    return ChangeResult::NoChange;
+  }
+
+  /// Meet (intersect) the information in this lattice with 'rhs'. Returns
+  /// if the value of the lattice changed.
+  virtual ChangeResult meet(const AbstractSparseLattice &rhs) {
+    return ChangeResult::NoChange;
+  }
 
   /// When the lattice gets updated, propagate an update to users of the value
   /// using its use-def chain to subscribed analyses.
@@ -86,14 +95,18 @@
     return const_cast<Lattice<ValueT> *>(this)->getValue();
   }
 
+  using LatticeT = Lattice<ValueT>;
+
   /// Join the information contained in the 'rhs' lattice into this
   /// lattice. Returns if the state of the current lattice changed.
   ChangeResult join(const AbstractSparseLattice &rhs) override {
-    const Lattice<ValueT> &rhsLattice =
-        static_cast<const Lattice<ValueT> &>(rhs);
+    return join(static_cast<const LatticeT &>(rhs).getValue());
+  }
 
-    // Join the rhs value into this lattice.
-    return join(rhsLattice.getValue());
+  /// Meet (intersect) the information contained in the 'rhs' lattice with
+  /// this lattice. Returns if the state of the current lattice changed.
+  ChangeResult meet(const AbstractSparseLattice &rhs) override {
+    return meet(static_cast<const LatticeT &>(rhs).getValue());
   }
 
   /// Join the information contained in the 'rhs' value into this
@@ -114,6 +127,38 @@
     return ChangeResult::Change;
   }
 
+  /// Trait to check if `T` provides a static `meet` method. Needed since for
+  /// forward analysis, lattices will only have a `join`, no `meet`, but we
+  /// want to use the same `Lattice` class for both directions.
+  template <typename T>
+  using has_meet = decltype(T::meet(std::declval<T>(), std::declval<T>()));
+  template <typename T>
+  using lattice_has_meet = llvm::is_detected<has_meet, T>;
+
+  /// Meet (intersect) the information contained in the 'rhs' value with this
+  /// lattice. Returns if the state of the current lattice changed. If the
+  /// lattice elements don't have a `meet` method, this is a no-op (see below).
+  template <typename VT,
+            std::enable_if_t<lattice_has_meet<VT>::value> * = nullptr>
+  ChangeResult meet(const VT &rhs) {
+    ValueT newValue = ValueT::meet(value, rhs);
+    assert(ValueT::meet(newValue, value) == newValue &&
+           "expected `meet` to be monotonic");
+    assert(ValueT::meet(newValue, rhs) == newValue &&
+           "expected `meet` to be monotonic");
+
+    // Update the current optimistic value if something changed.
+    if (newValue == value)
+      return ChangeResult::NoChange;
+
+    value = newValue;
+    return ChangeResult::Change;
+  }
+
+  template <typename VT,
+            std::enable_if_t<!lattice_has_meet<VT>::value> * = nullptr>
+  ChangeResult meet(const VT &rhs) {
+    return ChangeResult::NoChange;
+  }
+
   /// Print the lattice element.
   void print(raw_ostream &os) const override { value.print(os); }
 
@@ -289,6 +334,135 @@
   }
 };
 
+//===----------------------------------------------------------------------===//
+// AbstractSparseBackwardDataFlowAnalysis
+//===----------------------------------------------------------------------===//
+
+/// Base class for sparse (backward) data-flow analyses. Similar to
+/// AbstractSparseDataFlowAnalysis, but walks bottom to top.
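+///
+/// In contrast to the forward case, lattice elements are combined with `meet`
+/// rather than `join`. To use `Lattice<ValueT>` in a backward analysis,
+/// `ValueT` must therefore provide a static `meet`. A minimal sketch (the
+/// name `Facts` is illustrative, not part of this patch):
+///
+///   struct Facts {
+///     /// Greatest lower bound: keep only the facts common to both sides.
+///     static Facts meet(const Facts &lhs, const Facts &rhs) {
+///       return Facts{lhs.bits & rhs.bits};
+///     }
+///     /// `Lattice<ValueT>` also references `join`; backward analyses
+///     /// don't call it.
+///     static Facts join(const Facts &lhs, const Facts &rhs);
+///     bool operator==(const Facts &rhs) const { return bits == rhs.bits; }
+///     void print(raw_ostream &os) const;
+///     uint64_t bits;
+///   };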
+class AbstractSparseBackwardDataFlowAnalysis : public DataFlowAnalysis {
+public:
+  /// Initialize the analysis by visiting the operation and everything nested
+  /// under it.
+  LogicalResult initialize(Operation *top) override;
+
+  /// Visit a program point. If this is a call operation or an operation with
+  /// block or region control-flow, then operand lattices are set accordingly.
+  /// Otherwise, invokes the operation transfer function
+  /// (`visitOperationImpl`).
+  LogicalResult visit(ProgramPoint point) override;
+
+protected:
+  explicit AbstractSparseBackwardDataFlowAnalysis(
+      DataFlowSolver &solver, SymbolTableCollection &symbolTable);
+
+  /// The operation transfer function. Given the result lattices, this
+  /// function is expected to set the operand lattices.
+  virtual void visitOperationImpl(
+      Operation *op, ArrayRef<AbstractSparseLattice *> operandLattices,
+      ArrayRef<const AbstractSparseLattice *> resultLattices) = 0;
+
+  /// Visit operands on branch instructions that are not forwarded.
+  virtual void visitBranchOperand(OpOperand &operand) = 0;
+
+  /// Set the given lattice element to the state at a control-flow exit point.
+  virtual void setToExitState(AbstractSparseLattice *lattice) = 0;
+
+  /// Set all of the given lattice elements to the state at control-flow exit
+  /// points.
+  void setAllToExitStates(ArrayRef<AbstractSparseLattice *> lattices);
+
+  /// Get the lattice element for a value.
+  virtual AbstractSparseLattice *getLatticeElement(Value value) = 0;
+
+  /// Get the lattice elements for a range of values.
+  SmallVector<AbstractSparseLattice *> getLatticeElements(ValueRange values);
+
+  /// Meet the lattice element `lhs` with `rhs` and propagate an update if it
+  /// changed.
+  void meet(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs);
+
+private:
+  /// Recursively initialize the analysis on nested operations and blocks.
+  LogicalResult initializeRecursively(Operation *op);
+
+  /// Visit an operation. If this is a call operation or an operation with
+  /// region control-flow, then its operand lattices are set accordingly.
+  /// Otherwise, the operation transfer function is invoked.
+  void visitOperation(Operation *op);
+
+  /// Visit a block.
+  void visitBlock(Block *block);
+
+  /// Visit an op with regions (like e.g. `scf.while`).
+  void visitRegionSuccessors(RegionBranchOpInterface branch,
+                             ArrayRef<AbstractSparseLattice *> operands);
+
+  /// Get the lattice element for a value, and also set up
+  /// dependencies so that the analysis on the given ProgramPoint is re-invoked
+  /// if the value changes.
+  const AbstractSparseLattice *getLatticeElementFor(ProgramPoint point,
+                                                    Value value);
+
+  /// Get the lattice elements for a range of values, and also set up
+  /// dependencies so that the analysis on the given ProgramPoint is re-invoked
+  /// if any of the values change.
+  SmallVector<const AbstractSparseLattice *>
+  getLatticeElementsFor(ProgramPoint point, ValueRange values);
+
+  SymbolTableCollection &symbolTable;
+};
+
+//===----------------------------------------------------------------------===//
+// SparseBackwardDataFlowAnalysis
+//===----------------------------------------------------------------------===//
+
+/// A sparse (backward) data-flow analysis for propagating SSA value lattices
+/// backwards across the IR by implementing transfer functions for operations.
+///
+/// `StateT` is expected to be a subclass of `AbstractSparseLattice`.
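+///
+/// A minimal sketch of a derived analysis (the names are illustrative; see
+/// TestBackwardDataFlowAnalysis.cpp for a complete, working example):
+///
+///   class MyBackwardAnalysis
+///       : public SparseBackwardDataFlowAnalysis<MyLattice> {
+///   public:
+///     using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
+///
+///     void visitOperation(Operation *op, ArrayRef<MyLattice *> operands,
+///                         ArrayRef<const MyLattice *> results) override {
+///       // Propagate information from the result lattices into the operand
+///       // lattices, e.g. by calling `meet(operands[i], *results[j])`.
+///     }
+///     void visitBranchOperand(OpOperand &operand) override {}
+///     void setToExitState(MyLattice *lattice) override {}
+///   };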
+template <typename StateT>
+class SparseBackwardDataFlowAnalysis
+    : public AbstractSparseBackwardDataFlowAnalysis {
+public:
+  explicit SparseBackwardDataFlowAnalysis(DataFlowSolver &solver,
+                                          SymbolTableCollection &symbolTable)
+      : AbstractSparseBackwardDataFlowAnalysis(solver, symbolTable) {}
+
+  /// Visit an operation with the lattices of its results. This function is
+  /// expected to set the lattices of the operation's operands.
+  virtual void visitOperation(Operation *op, ArrayRef<StateT *> operands,
+                              ArrayRef<const StateT *> results) = 0;
+
+protected:
+  /// Get the lattice element for a value.
+  StateT *getLatticeElement(Value value) override {
+    return getOrCreate<StateT>(value);
+  }
+
+  /// Set the given lattice element(s) to the state at control-flow exit
+  /// point(s).
+  virtual void setToExitState(StateT *lattice) = 0;
+  void setToExitState(AbstractSparseLattice *lattice) override {
+    return setToExitState(reinterpret_cast<StateT *>(lattice));
+  }
+  void setAllToExitStates(ArrayRef<StateT *> lattices) {
+    AbstractSparseBackwardDataFlowAnalysis::setAllToExitStates(
+        {reinterpret_cast<AbstractSparseLattice *const *>(lattices.begin()),
+         lattices.size()});
+  }
+
+private:
+  /// Type-erased wrappers that convert the abstract lattice operands to
+  /// derived lattices and invoke the virtual hooks operating on the derived
+  /// lattices.
+  void visitOperationImpl(
+      Operation *op, ArrayRef<AbstractSparseLattice *> operandLattices,
+      ArrayRef<const AbstractSparseLattice *> resultLattices) override {
+    visitOperation(
+        op,
+        {reinterpret_cast<StateT *const *>(operandLattices.begin()),
+         operandLattices.size()},
+        {reinterpret_cast<const StateT *const *>(resultLattices.begin()),
+         resultLattices.size()});
+  }
+};
+
 } // end namespace dataflow
 } // end namespace mlir
 
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -41,7 +41,7 @@
     if (region.empty())
       continue;
     for (Value argument : region.front().getArguments())
-      setAllToEntryStates(getLatticeElement(argument));
+      setToEntryState(getLatticeElement(argument));
   }
 
   return initializeRecursively(top);
@@ -280,3 +280,271 @@
                                           const AbstractSparseLattice &rhs) {
   propagateIfChanged(lhs, lhs->join(rhs));
 }
+
+//===----------------------------------------------------------------------===//
+// AbstractSparseBackwardDataFlowAnalysis
+//===----------------------------------------------------------------------===//
+
+AbstractSparseBackwardDataFlowAnalysis::AbstractSparseBackwardDataFlowAnalysis(
+    DataFlowSolver &solver, SymbolTableCollection &symbolTable)
+    : DataFlowAnalysis(solver), symbolTable(symbolTable) {
+  registerPointKind<CFGEdge>();
+}
+
+LogicalResult
+AbstractSparseBackwardDataFlowAnalysis::initialize(Operation *top) {
+  return initializeRecursively(top);
+}
+
+LogicalResult
+AbstractSparseBackwardDataFlowAnalysis::initializeRecursively(Operation *op) {
+  visitOperation(op);
+  for (Region &region : op->getRegions()) {
+    for (Block &block : region) {
+      getOrCreate<Executable>(&block)->blockContentSubscribe(this);
+      // Initialize ops in reverse order, so we can do as much initial
+      // propagation as possible without having to go through the
+      // solver queue.
+      for (auto it = block.rbegin(); it != block.rend(); it++)
+        if (failed(initializeRecursively(&*it)))
+          return failure();
+    }
+  }
+  return success();
+}
+
+LogicalResult
+AbstractSparseBackwardDataFlowAnalysis::visit(ProgramPoint point) {
+  if (Operation *op = point.dyn_cast<Operation *>())
+    visitOperation(op);
+  else if (point.dyn_cast<Block *>())
+    // For backward dataflow, we don't have to do any work for the blocks
+    // themselves.
+    // CFG edges between blocks are processed by the BranchOp logic in
+    // `visitOperation`, and entry blocks for functions are tied to the
+    // CallOp arguments by `visitOperation`.
+    return success();
+  else
+    return failure();
+  return success();
+}
+
+SmallVector<AbstractSparseLattice *>
+AbstractSparseBackwardDataFlowAnalysis::getLatticeElements(ValueRange values) {
+  SmallVector<AbstractSparseLattice *> resultLattices;
+  resultLattices.reserve(values.size());
+  for (Value result : values) {
+    AbstractSparseLattice *resultLattice = getLatticeElement(result);
+    resultLattices.push_back(resultLattice);
+  }
+  return resultLattices;
+}
+
+SmallVector<const AbstractSparseLattice *>
+AbstractSparseBackwardDataFlowAnalysis::getLatticeElementsFor(
+    ProgramPoint point, ValueRange values) {
+  SmallVector<const AbstractSparseLattice *> resultLattices;
+  resultLattices.reserve(values.size());
+  for (Value result : values) {
+    const AbstractSparseLattice *resultLattice =
+        getLatticeElementFor(point, result);
+    resultLattices.push_back(resultLattice);
+  }
+  return resultLattices;
+}
+
+static MutableArrayRef<OpOperand>
+operandsToOpOperands(OperandRange &operands) {
+  return MutableArrayRef<OpOperand>(operands.getBase(), operands.size());
+}
+
+void AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
+  // If we're in a dead block, bail out.
+  if (!getOrCreate<Executable>(op->getBlock())->isLive())
+    return;
+
+  SmallVector<AbstractSparseLattice *> operandLattices =
+      getLatticeElements(op->getOperands());
+  SmallVector<const AbstractSparseLattice *> resultLattices =
+      getLatticeElementsFor(op, op->getResults());
+
+  // Block arguments of region branch operations flow back into the operands
+  // of the parent op.
+  if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
+    visitRegionSuccessors(branch, operandLattices);
+    return;
+  }
+
+  if (auto branch = dyn_cast<BranchOpInterface>(op)) {
+    // Block arguments of successor blocks flow back into our operands.
+
+    // We remember all operands not forwarded to any block in a BitVector.
+    // We can't just cut out a range here, since the non-forwarded operands
+    // might be non-contiguous (if there's more than one successor).
+    BitVector unaccounted(op->getNumOperands(), true);
+
+    for (auto [index, block] : llvm::enumerate(op->getSuccessors())) {
+      SuccessorOperands successorOperands = branch.getSuccessorOperands(index);
+      OperandRange forwarded = successorOperands.getForwardedOperands();
+      if (!forwarded.empty()) {
+        MutableArrayRef<OpOperand> operands = op->getOpOperands().slice(
+            forwarded.getBeginOperandIndex(), forwarded.size());
+        for (OpOperand &operand : operands) {
+          unaccounted.reset(operand.getOperandNumber());
+          if (Optional<BlockArgument> blockArg =
+                  detail::getBranchSuccessorArgument(
+                      successorOperands, operand.getOperandNumber(), block)) {
+            meet(getLatticeElement(operand.get()),
+                 *getLatticeElementFor(op, *blockArg));
+          }
+        }
+      }
+    }
+    // Operands not forwarded to successor blocks are typically parameters
+    // of the branch operation itself (for example the boolean for if/else).
+    for (int index : unaccounted.set_bits()) {
+      OpOperand &operand = op->getOpOperand(index);
+      visitBranchOperand(operand);
+    }
+    return;
+  }
+
+  // For function calls, connect the arguments of the callee's entry block
+  // to the operands of the call op.
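+  // For example, given a call `%r = func.call @f(%x)`, the lattice of `%x`
+  // is met with the lattice of the entry block argument of `@f` that `%x`
+  // maps to.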
+  if (auto call = dyn_cast<CallOpInterface>(op)) {
+    Operation *callableOp = call.resolveCallable(&symbolTable);
+    if (auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) {
+      Region *region = callable.getCallableRegion();
+      if (region && !region->empty()) {
+        Block &block = region->front();
+        for (auto [blockArg, operand] :
+             llvm::zip(block.getArguments(), operandLattices)) {
+          meet(operand, *getLatticeElementFor(op, blockArg));
+        }
+      }
+      return;
+    }
+  }
+
+  // The block arguments of the branched-to region flow back into the
+  // operands of the yield operation.
+  if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
+    if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
+      SmallVector<RegionSuccessor> successors;
+      SmallVector<Attribute> operands(op->getNumOperands(), nullptr);
+      branch.getSuccessorRegions(op->getParentRegion()->getRegionNumber(),
+                                 operands, successors);
+      // All operands not forwarded to any successor. This set can be
+      // non-contiguous in the presence of multiple successors.
+      BitVector unaccounted(op->getNumOperands(), true);
+
+      for (const RegionSuccessor &successor : successors) {
+        ValueRange inputs = successor.getSuccessorInputs();
+        Region *region = successor.getSuccessor();
+        OperandRange operands =
+            region ? terminator.getSuccessorOperands(region->getRegionNumber())
+                   : terminator.getSuccessorOperands({});
+        MutableArrayRef<OpOperand> opoperands = operandsToOpOperands(operands);
+        for (auto [opoperand, input] : llvm::zip(opoperands, inputs)) {
+          meet(getLatticeElement(opoperand.get()),
+               *getLatticeElementFor(op, input));
+          unaccounted.reset(
+              const_cast<OpOperand &>(opoperand).getOperandNumber());
+        }
+      }
+      // Visit operands of the branch op not forwarded to the next region
+      // (like e.g. the boolean of `scf.condition`).
+      for (int index : unaccounted.set_bits()) {
+        visitBranchOperand(op->getOpOperand(index));
+      }
+      return;
+    }
+  }
+
+  // Yield-like ops usually don't implement
+  // `RegionBranchTerminatorOpInterface`, since they behave like a return in
+  // the sense that they forward to the results of some other (here: the
+  // parent) op.
+  if (op->hasTrait<OpTrait::ReturnLike>()) {
+    if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
+      OperandRange operands = op->getOperands();
+      ResultRange results = op->getParentOp()->getResults();
+      assert(results.size() == operands.size() &&
+             "Can't derive arg mapping for yield-like op.");
+      for (auto [operand, result] : llvm::zip(operands, results))
+        meet(getLatticeElement(operand), *getLatticeElementFor(op, result));
+      return;
+    }
+
+    // Going backwards, the operands of the return are derived from the
+    // results of all CallOps calling this CallableOp.
+    if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp())) {
+      const PredecessorState *callsites =
+          getOrCreateFor<PredecessorState>(op, callable);
+      if (callsites->allPredecessorsKnown()) {
+        for (Operation *call : callsites->getKnownPredecessors()) {
+          SmallVector<const AbstractSparseLattice *> callResultLattices =
+              getLatticeElementsFor(op, call->getResults());
+          for (auto [operandLattice, result] :
+               llvm::zip(operandLattices, callResultLattices))
+            meet(operandLattice, *result);
+        }
+      } else {
+        // If we don't know all the callers, we can't know where the
+        // returned values go. Note that, in particular, this will trigger
+        // for the return ops of any public functions.
+        setAllToExitStates(operandLattices);
+      }
+      return;
+    }
+  }
+
+  visitOperationImpl(op, operandLattices, resultLattices);
+}
+
+void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
+    RegionBranchOpInterface branch,
+    ArrayRef<AbstractSparseLattice *> operandLattices) {
+  Operation *op = branch.getOperation();
+  SmallVector<RegionSuccessor> successors;
+  SmallVector<Attribute> operands(op->getNumOperands(), nullptr);
+  branch.getSuccessorRegions(/*index=*/{}, operands, successors);
+
+  // All operands not forwarded to any successor. This set can be
+  // non-contiguous in the presence of multiple successors.
+  BitVector unaccounted(op->getNumOperands(), true);
+
+  for (RegionSuccessor &successor : successors) {
+    Region *region = successor.getSuccessor();
+    OperandRange operands =
+        region ? branch.getSuccessorEntryOperands(region->getRegionNumber())
+               : branch.getSuccessorEntryOperands({});
+    MutableArrayRef<OpOperand> opoperands = operandsToOpOperands(operands);
+    ValueRange inputs = successor.getSuccessorInputs();
+    for (auto [operand, input] : llvm::zip(opoperands, inputs)) {
+      meet(getLatticeElement(operand.get()),
+           *getLatticeElementFor(op, input));
+      unaccounted.reset(operand.getOperandNumber());
+    }
+  }
+  // All operands not forwarded to regions are typically parameters of the
+  // branch operation itself (for example the boolean for if/else).
+  for (int index : unaccounted.set_bits()) {
+    visitBranchOperand(op->getOpOperand(index));
+  }
+}
+
+const AbstractSparseLattice *
+AbstractSparseBackwardDataFlowAnalysis::getLatticeElementFor(
+    ProgramPoint point, Value value) {
+  AbstractSparseLattice *state = getLatticeElement(value);
+  addDependency(state, point);
+  return state;
+}
+
+void AbstractSparseBackwardDataFlowAnalysis::setAllToExitStates(
+    ArrayRef<AbstractSparseLattice *> lattices) {
+  for (AbstractSparseLattice *lattice : lattices)
+    setToExitState(lattice);
+}
+
+void AbstractSparseBackwardDataFlowAnalysis::meet(
+    AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs) {
+  propagateIfChanged(lhs, lhs->meet(rhs));
+}
diff --git a/mlir/test/Analysis/DataFlow/test-written-to.mlir b/mlir/test/Analysis/DataFlow/test-written-to.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Analysis/DataFlow/test-written-to.mlir
@@ -0,0 +1,248 @@
+// RUN: mlir-opt -split-input-file -test-written-to %s 2>&1 | FileCheck %s
+
+// CHECK-LABEL: test_tag: constant0
+// CHECK: result #0: [a]
+// CHECK-LABEL: test_tag: constant1
+// CHECK: result #0: [b]
+func.func @test_two_writes(%m0: memref<i32>, %m1: memref<i32>) -> (memref<i32>, memref<i32>) {
+  %c0 = arith.constant {tag = "constant0"} 0 : i32
+  %c1 = arith.constant {tag = "constant1"} 1 : i32
+  memref.store %c0, %m0[] {tag_name = "a"} : memref<i32>
+  memref.store %c1, %m1[] {tag_name = "b"} : memref<i32>
+  return %m0, %m1 : memref<i32>, memref<i32>
+}
+
+// -----
+
+// CHECK-LABEL: test_tag: c0
+// CHECK: result #0: [b]
+// CHECK-LABEL: test_tag: c1
+// CHECK: result #0: [b]
+// CHECK-LABEL: test_tag: condition
+// CHECK: result #0: [brancharg0]
+// CHECK-LABEL: test_tag: c2
+// CHECK: result #0: [a]
+// CHECK-LABEL: test_tag: c3
+// CHECK: result #0: [a]
+func.func @test_if(%m0: memref<i32>, %m1: memref<i32>, %condition: i1) {
+  %c0 = arith.constant {tag = "c0"} 2 : i32
+  %c1 = arith.constant {tag = "c1"} 3 : i32
+  %condition2 = arith.addi %condition, %condition {tag = "condition"} : i1
+  %0, %1 = scf.if %condition2 -> (i32, i32) {
+    %c2 = arith.constant {tag = "c2"} 0 : i32
+    scf.yield %c2, %c0 : i32, i32
+  } else {
+    %c3 = arith.constant {tag = "c3"} 1 : i32
+    scf.yield %c3, %c1 : i32, i32
+  }
+  memref.store %0, %m0[] {tag_name = "a"} : memref<i32>
+  memref.store %1, %m1[] {tag_name = "b"} : memref<i32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_tag: c0
+// CHECK: result #0: [a c]
+// CHECK-LABEL: test_tag: c1
+// CHECK: result #0: [b c]
+// CHECK-LABEL: test_tag: br
+// CHECK: operand #0: [brancharg0]
+func.func @test_blocks(%m0: memref<i32>,
+                       %m1: memref<i32>,
+                       %m2: memref<i32>, %cond : i1) {
+  %0 = arith.constant {tag = "c0"} 0 : i32
+  %1 = arith.constant {tag = "c1"} 1 : i32
+  cf.cond_br %cond, ^a(%0 : i32), ^b(%1 : i32) {tag = "br"}
+^a(%a0: i32):
+  memref.store %a0, %m0[] {tag_name = "a"} : memref<i32>
+  cf.br ^c(%a0 : i32)
+^b(%b0: i32):
+  memref.store %b0, %m1[] {tag_name = "b"} : memref<i32>
+  cf.br ^c(%b0 : i32)
+^c(%c0 : i32):
+  memref.store %c0, %m2[] {tag_name = "c"} : memref<i32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_tag: two
+// CHECK: result #0: [a]
+func.func @test_infinite_loop(%m0: memref<i32>) {
+  %0 = arith.constant 0 : i32
+  %1 = arith.constant 1 : i32
+  %2 = arith.constant {tag = "two"} 2 : i32
+  %3 = arith.constant -1 : i32
+  cf.br ^loop(%0, %1, %2 : i32, i32, i32)
+^loop(%a: i32, %b: i32, %c: i32):
+  memref.store %a, %m0[] {tag_name = "a"} : memref<i32>
+  cf.br ^loop(%b, %c, %3 : i32, i32, i32)
+}
+
+// -----
+
+// CHECK-LABEL: test_tag: c0
+// CHECK: result #0: [a b c]
+func.func @test_switch(%flag: i32, %m0: memref<i32>) {
+  %0 = arith.constant {tag = "c0"} 0 : i32
+  cf.switch %flag : i32, [
+    default: ^a(%0 : i32),
+    42: ^b(%0 : i32),
+    43: ^c(%0 : i32)
+  ]
+^a(%a0: i32):
+  memref.store %a0, %m0[] {tag_name = "a"} : memref<i32>
+  cf.br ^c(%a0 : i32)
+^b(%b0: i32):
+  memref.store %b0, %m0[] {tag_name = "b"} : memref<i32>
+  cf.br ^c(%b0 : i32)
+^c(%c0 : i32):
+  memref.store %c0, %m0[] {tag_name = "c"} : memref<i32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_tag: add
+// CHECK: result #0: [a]
+func.func @test_caller(%m0: memref<f32>, %arg: f32) {
+  %0 = arith.addf %arg, %arg {tag = "add"} : f32
+  %1 = func.call @callee(%0) : (f32) -> f32
+  %2 = arith.mulf %1, %1 : f32
+  %3 = arith.mulf %2, %2 : f32
+  %4 = arith.mulf %3, %3 : f32
+  memref.store %4, %m0[] {tag_name = "a"} : memref<f32>
+  return
+}
+
+func.func private @callee(%0 : f32) -> f32 {
+  %1 = arith.mulf %0, %0 : f32
+  %2 = arith.mulf %1, %1 : f32
+  func.return %2 : f32
+}
+
+// -----
+
+func.func private @callee(%0 : f32) -> f32 {
+  %1 = arith.mulf %0, %0 : f32
+  func.return %1 : f32
+}
+
+// CHECK-LABEL: test_tag: sub
+// CHECK: result #0: [a]
+func.func @test_caller_below_callee(%m0: memref<f32>, %arg: f32) {
+  %0 = arith.subf %arg, %arg {tag = "sub"} : f32
+  %1 = func.call @callee(%0) : (f32) -> f32
+  memref.store %1, %m0[] {tag_name = "a"} : memref<f32>
+  return
+}
+
+// -----
+
+func.func private @callee1(%0 : f32) -> f32 {
+  %1 = func.call @callee2(%0) : (f32) -> f32
+  func.return %1 : f32
+}
+
+func.func private @callee2(%0 : f32) -> f32 {
+  %1 = func.call @callee3(%0) : (f32) -> f32
+  func.return %1 : f32
+}
+
+func.func private @callee3(%0 : f32) -> f32 {
+  func.return %0 : f32
+}
+
+// CHECK-LABEL: test_tag: mul
+// CHECK: result #0: [a]
+func.func @test_callchain(%m0: memref<f32>, %arg: f32) {
+  %0 = arith.mulf %arg, %arg {tag = "mul"} : f32
+  %1 = func.call @callee1(%0) : (f32) -> f32
+  memref.store %1, %m0[] {tag_name = "a"} : memref<f32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_tag: zero
+// CHECK: result #0: [c]
+// CHECK-LABEL: test_tag: init
+// CHECK: result #0: [a b]
+// CHECK-LABEL: test_tag: condition
+// CHECK: operand #0: [brancharg0]
+func.func @test_while(%m0: memref<i32>, %init : i32, %cond: i1) {
+  %zero = arith.constant {tag = "zero"} 0 : i32
+  %init2 = arith.addi %init, %init {tag = "init"} : i32
+  %0, %1 = scf.while (%arg1 = %zero, %arg2 = %init2) : (i32, i32) -> (i32, i32) {
+    memref.store %arg2, %m0[] {tag_name = "a"} : memref<i32>
+    scf.condition(%cond) {tag = "condition"} %arg1, %arg2 : i32, i32
+  } do {
+  ^bb0(%arg1: i32, %arg2: i32):
+    memref.store %arg1, %m0[] {tag_name = "c"} : memref<i32>
+    %res = arith.addi %arg2, %arg2 : i32
+    scf.yield %arg1, %res : i32, i32
+  }
+  memref.store %1, %m0[] {tag_name = "b"} : memref<i32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_tag: zero
+// CHECK: result #0: [brancharg0]
+// CHECK-LABEL: test_tag: ten
+// CHECK: result #0: [brancharg1]
+// CHECK-LABEL: test_tag: one
+// CHECK: result #0: [brancharg2]
+// CHECK-LABEL: test_tag: x
+// CHECK: result #0: [a]
+func.func @test_for(%m0: memref<i32>) {
+  %zero = arith.constant {tag = "zero"} 0 : index
+  %ten = arith.constant {tag = "ten"} 10 : index
+  %one = arith.constant {tag = "one"} 1 : index
+  %x = arith.constant {tag = "x"} 0 : i32
+  %0 = scf.for %i = %zero to %ten step %one iter_args(%ix = %x) -> (i32) {
+    scf.yield %ix : i32
+  }
+  memref.store %0, %m0[] {tag_name = "a"} : memref<i32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_tag: default_a
+// CHECK: result #0: [a]
+// CHECK-LABEL: test_tag: default_b
+// CHECK: result #0: [b]
+// CHECK-LABEL: test_tag: 1a
+// CHECK: result #0: [a]
+// CHECK-LABEL: test_tag: 1b
+// CHECK: result #0: [b]
+// CHECK-LABEL: test_tag: 2a
+// CHECK: result #0: [a]
+// CHECK-LABEL: test_tag: 2b
+// CHECK: result #0: [b]
+// CHECK-LABEL: test_tag: switch
+// CHECK: operand #0: [brancharg0]
+func.func @test_switch(%arg0 : index, %m0: memref<i32>) {
+  %0, %1 = scf.index_switch %arg0 {tag = "switch"} -> i32, i32
+  case 1 {
+    %2 = arith.constant {tag = "1a"} 10 : i32
+    %3 = arith.constant {tag = "1b"} 100 : i32
+    scf.yield %2, %3 : i32, i32
+  }
+  case 2 {
+    %4 = arith.constant {tag = "2a"} 20 : i32
+    %5 = arith.constant {tag = "2b"} 200 : i32
+    scf.yield %4, %5 : i32, i32
+  }
+  default {
+    %6 = arith.constant {tag = "default_a"} 30 : i32
+    %7 = arith.constant {tag = "default_b"} 300 : i32
+    scf.yield %6, %7 : i32, i32
+  }
+  memref.store %0, %m0[] {tag_name = "a"} : memref<i32>
+  memref.store %1, %m0[] {tag_name = "b"} : memref<i32>
+  return
+}
diff --git a/mlir/test/lib/Analysis/CMakeLists.txt b/mlir/test/lib/Analysis/CMakeLists.txt
--- a/mlir/test/lib/Analysis/CMakeLists.txt
+++ b/mlir/test/lib/Analysis/CMakeLists.txt
@@ -12,6 +12,7 @@
+  DataFlow/TestBackwardDataFlowAnalysis.cpp
   DataFlow/TestDeadCodeAnalysis.cpp
   DataFlow/TestDenseDataFlowAnalysis.cpp
 
   EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/test/lib/Analysis/DataFlow/TestBackwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestBackwardDataFlowAnalysis.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Analysis/DataFlow/TestBackwardDataFlowAnalysis.cpp
@@ -0,0 +1,142 @@
+//===- TestBackwardDataFlowAnalysis.cpp - Test backward dataflow ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
+#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+using namespace mlir::dataflow;
+
+namespace {
+
+/// This lattice represents, for a given value, the set of memory resources
+/// that this value, or anything derived from this value, is potentially
+/// written to.
+struct WrittenTo : public AbstractSparseLattice {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WrittenTo)
+  using AbstractSparseLattice::AbstractSparseLattice;
+
+  void print(raw_ostream &os) const override {
+    os << "[";
+    llvm::interleave(
+        writes, os, [&](const StringAttr &a) { os << a.str(); }, " ");
+    os << "]";
+  }
+  ChangeResult addWrites(const SetVector<StringAttr> &writes) {
+    int sizeBefore = this->writes.size();
+    this->writes.insert(writes.begin(), writes.end());
+    int sizeAfter = this->writes.size();
+    return sizeBefore == sizeAfter ? ChangeResult::NoChange
+                                   : ChangeResult::Change;
+  }
+  ChangeResult meet(const AbstractSparseLattice &other) override {
+    const auto *rhs = reinterpret_cast<const WrittenTo *>(&other);
+    return addWrites(rhs->writes);
+  }
+
+  SetVector<StringAttr> writes;
+};
+
+/// An analysis that, by going backwards along the dataflow graph, annotates
+/// each value with all the memory resources it (or anything derived from it)
+/// is eventually written to.
+class WrittenToAnalysis : public SparseBackwardDataFlowAnalysis<WrittenTo> {
+public:
+  using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
+
+  void visitOperation(Operation *op, ArrayRef<WrittenTo *> operands,
+                      ArrayRef<const WrittenTo *> results) override;
+
+  void visitBranchOperand(OpOperand &operand) override;
+
+  void setToExitState(WrittenTo *lattice) override { lattice->writes.clear(); }
+};
+
+void WrittenToAnalysis::visitOperation(Operation *op,
+                                       ArrayRef<WrittenTo *> operands,
+                                       ArrayRef<const WrittenTo *> results) {
+  if (auto store = dyn_cast<memref::StoreOp>(op)) {
+    SetVector<StringAttr> newWrites;
+    newWrites.insert(op->getAttrOfType<StringAttr>("tag_name"));
+    propagateIfChanged(operands[0], operands[0]->addWrites(newWrites));
+    return;
+  }
+  // By default, every result of an op depends on every operand.
+  for (const WrittenTo *r : results) {
+    for (WrittenTo *operand : operands) {
+      meet(operand, *r);
+    }
+    addDependency(const_cast<WrittenTo *>(r), op);
+  }
+}
+
+void WrittenToAnalysis::visitBranchOperand(OpOperand &operand) {
+  // Mark branch operands as "brancharg%d", with %d the operand number.
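+  // For example, the condition operand of a `cf.cond_br` is its operand #0,
+  // so it is annotated as "brancharg0".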
+  WrittenTo *lattice = getLatticeElement(operand.get());
+  SetVector<StringAttr> newWrites;
+  newWrites.insert(
+      StringAttr::get(operand.getOwner()->getContext(),
+                      "brancharg" + Twine(operand.getOperandNumber())));
+  propagateIfChanged(lattice, lattice->addWrites(newWrites));
+}
+
+} // end anonymous namespace
+
+namespace {
+struct TestWrittenToPass
+    : public PassWrapper<TestWrittenToPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWrittenToPass)
+
+  StringRef getArgument() const override { return "test-written-to"; }
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+
+    SymbolTableCollection symbolTable;
+
+    DataFlowSolver solver;
+    solver.load<DeadCodeAnalysis>();
+    solver.load<SparseConstantPropagation>();
+    solver.load<WrittenToAnalysis>(symbolTable);
+    if (failed(solver.initializeAndRun(op)))
+      return signalPassFailure();
+
+    raw_ostream &os = llvm::outs();
+    op->walk([&](Operation *op) {
+      auto tag = op->getAttrOfType<StringAttr>("tag");
+      if (!tag)
+        return;
+      os << "test_tag: " << tag.getValue() << ":\n";
+      for (auto [index, operand] : llvm::enumerate(op->getOperands())) {
+        const WrittenTo *writtenTo = solver.lookupState<WrittenTo>(operand);
+        assert(writtenTo && "expected a sparse lattice");
+        os << " operand #" << index << ": ";
+        writtenTo->print(os);
+        os << "\n";
+      }
+      for (auto [index, result] : llvm::enumerate(op->getResults())) {
+        const WrittenTo *writtenTo = solver.lookupState<WrittenTo>(result);
+        assert(writtenTo && "expected a sparse lattice");
+        os << " result #" << index << ": ";
+        writtenTo->print(os);
+        os << "\n";
+      }
+    });
+  }
+};
+} // end anonymous namespace
+
+namespace mlir {
+namespace test {
+void registerTestWrittenToPass() { PassRegistration<TestWrittenToPass>(); }
+} // end namespace test
+} // end namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -118,6 +118,7 @@
 void registerTestTopologicalSortAnalysisPass();
 void registerTestTransformDialectEraseSchedulePass();
 void registerTestTransformDialectInterpreterPass();
+void registerTestWrittenToPass();
 void registerTestVectorLowerings();
 void registerTestNvgpuLowerings();
 } // namespace test
@@ -223,6 +224,7 @@
   mlir::test::registerTestTransformDialectInterpreterPass();
   mlir::test::registerTestVectorLowerings();
   mlir::test::registerTestNvgpuLowerings();
+  mlir::test::registerTestWrittenToPass();
 }
 #endif