diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -13,6 +13,7 @@
 #ifndef MLIR_DIALECT_AFFINE_UTILS_H
 #define MLIR_DIALECT_AFFINE_UTILS_H
 
+#include "mlir/IR/AffineExpr.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SmallVector.h"
@@ -130,6 +131,15 @@
 /// early if the op is already in a normalized form.
 void normalizeAffineParallel(AffineParallelOp op);
 
+/// Traverse `e` and return an AffineExpr where all occurrences of `dim` have
+/// been replaced by either:
+///  - `min` if `positivePath` is true when we reach an occurrence of `dim`
+///  - `max` if `positivePath` is false when we reach an occurrence of `dim`
+/// `positivePath` is negated each time we hit a multiplicative or divisive
+/// binary op with a constant negative coefficient.
+AffineExpr substWithMin(AffineExpr e, AffineExpr dim, AffineExpr min,
+                        AffineExpr max, bool positivePath = true);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_AFFINE_UTILS_H
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -226,3 +226,30 @@
 
   return success();
 }
+
+// Replace `dim` in `e` with `min` (positive path) or `max` (negative path).
+AffineExpr mlir::substWithMin(AffineExpr e, AffineExpr dim, AffineExpr min,
+                              AffineExpr max, bool positivePath) {
+  if (e == dim)
+    return positivePath ? min : max;
+  if (auto bin = e.dyn_cast<AffineBinaryOpExpr>()) {
+    AffineExpr lhs = bin.getLHS();
+    AffineExpr rhs = bin.getRHS();
+    if (bin.getKind() == mlir::AffineExprKind::Add)
+      return substWithMin(lhs, dim, min, max, positivePath) +
+             substWithMin(rhs, dim, min, max, positivePath);
+
+    auto c1 = bin.getLHS().dyn_cast<AffineConstantExpr>();
+    auto c2 = bin.getRHS().dyn_cast<AffineConstantExpr>();
+    if (c1 && c1.getValue() < 0)
+      return getAffineBinaryOpExpr(
+          bin.getKind(), c1, substWithMin(rhs, dim, min, max, !positivePath));
+    if (c2 && c2.getValue() < 0)
+      return getAffineBinaryOpExpr(
+          bin.getKind(), substWithMin(lhs, dim, min, max, !positivePath), c2);
+    return getAffineBinaryOpExpr(
+        bin.getKind(), substWithMin(lhs, dim, min, max, positivePath),
+        substWithMin(rhs, dim, min, max, positivePath));
+  }
+  return e;
+}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
@@ -332,38 +333,6 @@
   return success();
 }
 
-/// Traverse `e` and return an AffineExpr where all occurrences of `dim` have
-/// been replaced by either:
-///  - `min` if `positivePath` is true when we reach an occurrence of `dim`
-///  - `max` if `positivePath` is true when we reach an occurrence of `dim`
-/// `positivePath` is negated each time we hit a multiplicative or divisive
-/// binary op with a constant negative coefficient.
-static AffineExpr substWithMin(AffineExpr e, AffineExpr dim, AffineExpr min,
-                               AffineExpr max, bool positivePath = true) {
-  if (e == dim)
-    return positivePath ? min : max;
-  if (auto bin = e.dyn_cast<AffineBinaryOpExpr>()) {
-    AffineExpr lhs = bin.getLHS();
-    AffineExpr rhs = bin.getRHS();
-    if (bin.getKind() == mlir::AffineExprKind::Add)
-      return substWithMin(lhs, dim, min, max, positivePath) +
-             substWithMin(rhs, dim, min, max, positivePath);
-
-    auto c1 = bin.getLHS().dyn_cast<AffineConstantExpr>();
-    auto c2 = bin.getRHS().dyn_cast<AffineConstantExpr>();
-    if (c1 && c1.getValue() < 0)
-      return getAffineBinaryOpExpr(
-          bin.getKind(), c1, substWithMin(rhs, dim, min, max, !positivePath));
-    if (c2 && c2.getValue() < 0)
-      return getAffineBinaryOpExpr(
-          bin.getKind(), substWithMin(lhs, dim, min, max, !positivePath), c2);
-    return getAffineBinaryOpExpr(
-        bin.getKind(), substWithMin(lhs, dim, min, max, positivePath),
-        substWithMin(rhs, dim, min, max, positivePath));
-  }
-  return e;
-}
-
 /// Given the `lbVal`, `ubVal` and `stepVal` of a loop, append `lbVal` and
 /// `ubVal` to `dims` and `stepVal` to `symbols`.
 /// Create new AffineDimExpr (`%lb` and `%ub`) and AffineSymbolExpr (`%step`)