diff --git a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h --- a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h +++ b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h @@ -104,12 +104,42 @@ ArrayRef<DependenceType> depTypes = { DependenceType::RAW, DependenceType::WAW}) const; + /// Returns true if the `linalgOp` has dependences into it. + bool hasDependentOperationsInto(LinalgOp linalgOp, + ArrayRef<DependenceType> depTypes = { + DependenceType::RAW, + DependenceType::WAW}) const; + + /// Returns true if the `linalgOp` has dependences from it. + bool hasDependentOperationsFrom(LinalgOp linalgOp, + ArrayRef<DependenceType> depTypes = { + DependenceType::RAW, + DependenceType::WAW}) const; + /// Returns true if the `linalgOp` has dependences into or from it. bool hasDependentOperations(LinalgOp linalgOp, ArrayRef<DependenceType> depTypes = { DependenceType::RAW, DependenceType::WAW}) const; + /// Returns all operations that have a dependence into `linalgOp` of types + /// listed in `depTypes`. + SmallVector<LinalgDependenceGraphElem, 2> getDependentOperationsInto( + LinalgOp linalgOp, ArrayRef<DependenceType> depTypes = { + DependenceType::RAW, DependenceType::WAW}) const; + + /// Returns all operations that have a dependence from `linalgOp` of types + /// listed in `depTypes`. + SmallVector<LinalgDependenceGraphElem, 2> getDependentOperationsFrom( + LinalgOp linalgOp, ArrayRef<DependenceType> depTypes = { + DependenceType::RAW, DependenceType::WAW}) const; + + /// Returns all dependent operations (into and from) of `linalgOp`. + SmallVector<LinalgDependenceGraphElem, 2> + getDependentOperations(LinalgOp linalgOp, + ArrayRef<DependenceType> depTypes = { + DependenceType::RAW, DependenceType::WAW}) const; + private: // Keep dependences in both directions, this is not just a performance gain // but it also reduces usage errors.
diff --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp --- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp +++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp @@ -251,13 +251,64 @@ return false; } -bool LinalgDependenceGraph::hasDependentOperations( +bool LinalgDependenceGraph::hasDependentOperationsFrom( + LinalgOp linalgOp, + ArrayRef<DependenceType> depTypes) const { + for (auto dep : depTypes) { + if (!getDependencesFrom(linalgOp, dep).empty()) + return true; + } + return false; +} + +bool LinalgDependenceGraph::hasDependentOperationsInto( LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const { for (auto dep : depTypes) { - if (!getDependencesFrom(linalgOp, dep).empty() || - !getDependencesInto(linalgOp, dep).empty()) + if (!getDependencesInto(linalgOp, dep).empty()) return true; } return false; } + +bool LinalgDependenceGraph::hasDependentOperations( + LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const { + return hasDependentOperationsInto(linalgOp, depTypes) || + hasDependentOperationsFrom(linalgOp, depTypes); +} + +SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2> +LinalgDependenceGraph::getDependentOperationsInto( + LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const { + SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2> + dependentOperations; + for (auto dependenceType : depTypes) { + auto dependencies = getDependencesInto(linalgOp, dependenceType); + dependentOperations.append(dependencies.begin(), dependencies.end()); + } + return dependentOperations; +} + +SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2> +LinalgDependenceGraph::getDependentOperationsFrom( + LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const { + SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2> + dependentOperations; + for (auto dependenceType : depTypes) { + auto dependencies = getDependencesFrom(linalgOp, dependenceType); + dependentOperations.append(dependencies.begin(), dependencies.end()); + } + return dependentOperations; +} + +/// Returns all dependent operations (into and from) of `linalgOp`.
+SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2> +LinalgDependenceGraph::getDependentOperations( + LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const { + SmallVector<LinalgDependenceGraphElem, 2> dependentOperations = + getDependentOperationsInto(linalgOp, depTypes); + SmallVector<LinalgDependenceGraphElem, 2> dependencesFrom = + getDependentOperationsFrom(linalgOp, depTypes); + dependentOperations.append(dependencesFrom.begin(), dependencesFrom.end()); + return dependentOperations; +}