diff --git a/mlir/include/mlir/Analysis/DataFlowAnalysis.h b/mlir/include/mlir/Analysis/DataFlowAnalysis.h --- a/mlir/include/mlir/Analysis/DataFlowAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlowAnalysis.h @@ -250,6 +250,18 @@ ArrayRef operands, SmallVectorImpl &successors) = 0; + /// Given an operation with successor regions, one of those regions, + /// and the lattice elements corresponding to the operation's + /// arguments, compute the lattice values for block arguments + /// that are not accounted for by the branching control flow (e.g. the + /// bounds of loops). By default, this method marks all such lattice elements + /// as having reached a pessimistic fixpoint. The region in the + /// RegionSuccessor and the operand lattice elements are guaranteed to be + /// non-null. + virtual ChangeResult + visitNonControlFlowArguments(Operation *op, const RegionSuccessor &region, + ArrayRef operands) = 0; + /// Create a new uninitialized lattice element. An optional value is provided /// which, if valid, should be used to initialize the known conservative state /// of the lattice. @@ -347,6 +359,33 @@ branch.getSuccessorRegions(sourceIndex, constantOperands, successors); } + /// Given an operation with successor regions, one of those regions, + /// and the lattice elements corresponding to the operation's + /// arguments, compute the lattice values for block arguments + /// that are not accounted for by the branching control flow (e.g. the + /// bounds of loops). By default, this method marks all such lattice elements + /// as having reached a pessimistic fixpoint. The region in the + /// RegionSuccessor and the operand lattice elements are guaranteed to be + /// non-null. 
+ virtual ChangeResult + visitNonControlFlowArguments(Operation *op, const RegionSuccessor &successor, + ArrayRef *> operands) { + ChangeResult ret = ChangeResult::NoChange; + Region *region = successor.getSuccessor(); + ValueRange succArgs = successor.getSuccessorInputs(); + Block *block = &region->front(); + Block::BlockArgListType arguments = block->getArguments(); + if (arguments.size() != succArgs.size()) { + unsigned firstArgIdx = + succArgs.empty() ? succArgs.size() + : succArgs[0].cast().getArgNumber(); + ret |= markAllPessimisticFixpoint(arguments.take_front(firstArgIdx)); + ret |= markAllPessimisticFixpoint( + arguments.drop_front(firstArgIdx + succArgs.size())); + } + return ret; + } + private: /// Type-erased wrappers that convert the abstract lattice operands to derived /// lattices and invoke the virtual hooks operating on the derived lattices. @@ -379,6 +418,14 @@ branch, sourceIndex, llvm::makeArrayRef(derivedOperandBase, operands.size()), successors); } + ChangeResult visitNonControlFlowArguments( + Operation *op, const RegionSuccessor &region, + ArrayRef operands) final { + LatticeElement *const *derivedOperandBase = + reinterpret_cast *const *>(operands.data()); + return visitNonControlFlowArguments( + op, region, llvm::makeArrayRef(derivedOperandBase, operands.size())); + } /// Create a new uninitialized lattice element. An optional value is provided, /// which if valid, should be used to initialize the known conservative state diff --git a/mlir/lib/Analysis/DataFlowAnalysis.cpp b/mlir/lib/Analysis/DataFlowAnalysis.cpp --- a/mlir/lib/Analysis/DataFlowAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlowAnalysis.cpp @@ -10,6 +10,7 @@ #include "mlir/IR/Operation.h" #include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallPtrSet.h" #include @@ -113,6 +114,7 @@ /// the parent operation results. 
void visitRegionSuccessors( Operation *parentOp, ArrayRef regionSuccessors, + ArrayRef operandLattices, function_ref)> getInputsForRegion); /// Visit the given terminator operation and compute any necessary lattice @@ -460,7 +462,7 @@ if (successors.empty()) return markAllPessimisticFixpoint(branch, branch->getResults()); return visitRegionSuccessors( - branch, successors, [&](Optional index) { + branch, successors, operandLattices, [&](Optional index) { assert(index && "expected valid region index"); return branch.getSuccessorEntryOperands(*index); }); @@ -468,6 +470,7 @@ void ForwardDataFlowSolver::visitRegionSuccessors( Operation *parentOp, ArrayRef regionSuccessors, + ArrayRef operandLattices, function_ref)> getInputsForRegion) { for (const RegionSuccessor &it : regionSuccessors) { Region *region = it.getSuccessor(); @@ -514,22 +517,17 @@ if (llvm::all_of(arguments, [&](Value arg) { return isAtFixpoint(arg); })) continue; - // Mark any arguments that do not receive inputs as having reached a - // pessimistic fixpoint, we won't be able to discern if they are constant. - // TODO: This isn't exactly ideal. There may be situations in which a - // region operation can provide information for certain results that - // aren't part of the control flow. if (succArgs.size() != arguments.size()) { - if (succArgs.empty()) { - markAllPessimisticFixpoint(arguments); - continue; + if (analysis.visitNonControlFlowArguments( + parentOp, it, operandLattices) == ChangeResult::Change) { + unsigned firstArgIdx = + succArgs.empty() ? 
succArgs.size() + : succArgs[0].cast().getArgNumber(); + for (Value v : arguments.take_front(firstArgIdx)) + visitUsers(v); + for (Value v : arguments.drop_front(firstArgIdx + succArgs.size())) + visitUsers(v); } - - unsigned firstArgIdx = succArgs[0].cast().getArgNumber(); - markAllPessimisticFixpointAndVisitUsers( - arguments.take_front(firstArgIdx)); - markAllPessimisticFixpointAndVisitUsers( - arguments.drop_front(firstArgIdx + succArgs.size())); } // Update the lattice of arguments that have inputs from the predecessor. @@ -573,12 +571,14 @@ // Try to get "region-like" successor operands if possible in order to // propagate the operand states to the successors. if (isRegionReturnLike(op)) { - return visitRegionSuccessors( - parentOp, regionSuccessors, [&](Optional regionIndex) { - // Determine the individual region successor operands for the given - // region index (if any). - return *getRegionBranchSuccessorOperands(op, regionIndex); - }); + return visitRegionSuccessors(parentOp, regionSuccessors, operandLattices, + [&](Optional regionIndex) { + // Determine the individual region + // successor operands for the given region + // index (if any). 
+ return *getRegionBranchSuccessorOperands( + op, regionIndex); + }); } // If this terminator is not "region-like", conservatively mark all of the diff --git a/mlir/test/Analysis/test-data-flow.mlir b/mlir/test/Analysis/test-data-flow.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Analysis/test-data-flow.mlir @@ -0,0 +1,22 @@ +// RUN: mlir-opt -test-data-flow --allow-unregistered-dialect %s 2>&1 | FileCheck %s + +// CHECK-LABEL: Testing : "loop-arg-pessimistic" +module attributes {test.name = "loop-arg-pessimistic"} { + func @f() -> index { + // CHECK: Visiting : %{{.*}} = arith.constant 0 + // CHECK-NEXT: Result 0 moved from uninitialized to 1 + %c0 = arith.constant 0 : index + // CHECK: Visiting : %{{.*}} = arith.constant 1 + // CHECK-NEXT: Result 0 moved from uninitialized to 1 + %c1 = arith.constant 1 : index + %0 = scf.for %arg1 = %c0 to %c1 step %c1 iter_args(%arg2 = %c0) -> index { + // CHECK: Visiting : %{{.*}} = arith.addi %{{.*}}, %{{.*}} + // CHECK-NEXT: Arg 0 : 0 + // CHECK-NEXT: Arg 1 : 1 + // CHECK-NEXT: Result 0 moved from uninitialized to 0 + %10 = arith.addi %arg1, %arg2 : index + scf.yield %10 : index + } + return %0 : index + } +} diff --git a/mlir/test/lib/Analysis/CMakeLists.txt b/mlir/test/lib/Analysis/CMakeLists.txt --- a/mlir/test/lib/Analysis/CMakeLists.txt +++ b/mlir/test/lib/Analysis/CMakeLists.txt @@ -2,6 +2,7 @@ add_mlir_library(MLIRTestAnalysis TestAliasAnalysis.cpp TestCallGraph.cpp + TestDataFlow.cpp TestLiveness.cpp TestMatchReduction.cpp TestMemRefBoundCheck.cpp diff --git a/mlir/test/lib/Analysis/TestDataFlow.cpp b/mlir/test/lib/Analysis/TestDataFlow.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Analysis/TestDataFlow.cpp @@ -0,0 +1,101 @@ +//===- TestDataFlow.cpp - Test data flow analysis system -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains test passes for defining and running a dataflow analysis. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/DataFlowAnalysis.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/STLExtras.h" + +using namespace mlir; + +namespace { +struct WasAnalyzed { + bool wasAnalyzed; + WasAnalyzed(bool wasAnalyzed) : wasAnalyzed(wasAnalyzed) {} + + static WasAnalyzed join(const WasAnalyzed &a, const WasAnalyzed &b) { + return a.wasAnalyzed && b.wasAnalyzed; + } + + static WasAnalyzed getPessimisticValueState(MLIRContext *context) { + return false; + } + + static WasAnalyzed getPessimisticValueState(Value v) { + return getPessimisticValueState(v.getContext()); + } + + inline bool operator==(const WasAnalyzed &other) const { + return wasAnalyzed == other.wasAnalyzed; + } +}; + +struct TestAnalysis : public ForwardDataFlowAnalysis { + using ForwardDataFlowAnalysis::ForwardDataFlowAnalysis; + ~TestAnalysis() override = default; + + ChangeResult + visitOperation(Operation *op, + ArrayRef *> operands) final { + ChangeResult ret = ChangeResult::NoChange; + llvm::errs() << "Visiting : "; + op->print(llvm::errs()); + llvm::errs() << "\n"; + + WasAnalyzed result(true); + for (auto &pair : llvm::enumerate(operands)) { + LatticeElement *elem = pair.value(); + llvm::errs() << "Arg " << pair.index(); + if (!elem->isUninitialized()) { + llvm::errs() << " : " << elem->getValue().wasAnalyzed << "\n"; + result = WasAnalyzed::join(result, elem->getValue()); + } else { + llvm::errs() << " uninitialized\n"; + } + } + for (const auto &pair : llvm::enumerate(op->getResults())) { + LatticeElement &lattice = getLatticeElement(pair.value()); + llvm::errs() << "Result " << pair.index() << " moved from 
"; + if (lattice.isUninitialized()) + llvm::errs() << "uninitialized"; + else + llvm::errs() << lattice.getValue().wasAnalyzed; + ret |= lattice.join({result}); + llvm::errs() << " to " << lattice.getValue().wasAnalyzed << "\n"; + } + return ret; + } +}; + +struct TestDataFlowPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDataFlowPass) + + StringRef getArgument() const final { return "test-data-flow"; } + StringRef getDescription() const final { + return "Print the actions taken during a dataflow analysis."; + } + void runOnOperation() override { + llvm::errs() << "Testing : " << getOperation()->getAttr("test.name") + << "\n"; + TestAnalysis analysis(getOperation().getContext()); + analysis.run(getOperation()); + } +}; +} // namespace + +namespace mlir { +namespace test { +void registerTestDataFlowPass() { PassRegistration(); } +} // namespace test +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -69,6 +69,7 @@ void registerTestControlFlowSink(); void registerTestGpuSerializeToCubinPass(); void registerTestGpuSerializeToHsacoPass(); +void registerTestDataFlowPass(); void registerTestDataLayoutQuery(); void registerTestDecomposeCallGraphTypes(); void registerTestDiagnosticsPass(); @@ -165,6 +166,7 @@ mlir::test::registerTestGpuSerializeToHsacoPass(); #endif mlir::test::registerTestDecomposeCallGraphTypes(); + mlir::test::registerTestDataFlowPass(); mlir::test::registerTestDataLayoutQuery(); mlir::test::registerTestDominancePass(); mlir::test::registerTestDynamicPipelinePass();