diff --git a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
--- a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
+++ b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
@@ -98,6 +98,19 @@
                                                  LinalgOp dstLinalgOp,
                                                  Value view) const;
 
+  /// Returns true if there is a dependence of any of the kinds listed in
+  /// `depTypes` from `srcLinalgOp` to `dstLinalgOp`.
+  bool hasDependenceFrom(LinalgOp srcLinalgOp, LinalgOp dstLinalgOp,
+                         ArrayRef<DependenceType> depTypes = {
+                             DependenceType::RAW, DependenceType::WAW}) const;
+
+  /// Returns true if `linalgOp` is the source or the destination of at least
+  /// one dependence of any of the kinds listed in `depTypes`.
+  bool hasDependentOperations(LinalgOp linalgOp,
+                              ArrayRef<DependenceType> depTypes = {
+                                  DependenceType::RAW,
+                                  DependenceType::WAW}) const;
+
 private:
   // Keep dependences in both directions, this is not just a performance gain
   // but it also reduces usage errors.
diff --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
--- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
+++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
@@ -242,3 +242,26 @@
   }
   return res;
 }
+
+bool LinalgDependenceGraph::hasDependenceFrom(
+    LinalgOp srcLinalgOp, LinalgOp dstLinalgOp,
+    ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const {
+  for (auto dep : depTypes) {
+    for (const auto &dependence : getDependencesInto(dstLinalgOp, dep)) {
+      if (dependence.dependentOpView.op == srcLinalgOp.getOperation())
+        return true;
+    }
+  }
+  return false;
+}
+
+bool LinalgDependenceGraph::hasDependentOperations(
+    LinalgOp linalgOp,
+    ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const {
+  for (auto dep : depTypes) {
+    if (!getDependencesFrom(linalgOp, dep).empty() ||
+        !getDependencesInto(linalgOp, dep).empty())
+      return true;
+  }
+  return false;
+}