Index: mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
===================================================================
--- mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -10,6 +10,7 @@
 #define DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H_
 
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/SCF/Utils.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/Identifier.h"
 #include "mlir/IR/PatternMatch.h"
@@ -933,24 +934,47 @@
                                 PatternRewriter &rewriter) const override;
 };
 
-/// Canonicalize AffineMinOp operations in the context of enclosing scf.for and
-/// scf.parallel by:
-///   1. building an affine map where uses of the induction variable of a loop
-///   are replaced by either the min (i.e. `%lb`) of the max
-///   (i.e. `%lb + %step * floordiv(%ub -1 - %lb, %step)`) expression, depending
-///   on whether the induction variable is used with a positive or negative
-///   coefficient.
+/// Callback returning the (min, max) affine expressions bounding `value`, or
+/// llvm::None when no range is known. The callback is expected to append any
+/// new operand used by the returned expressions to `dims` or `symbols`.
+using GetMinMaxExprFn =
+    std::function<Optional<std::pair<AffineExpr, AffineExpr>>(
+        Value value, SmallVectorImpl<Value> &dims,
+        SmallVectorImpl<Value> &symbols)>;
+
+/// Canonicalize AffineMinOp operations in the context of values with a known
+/// range by:
+///   1. building an affine map where uses of those values are replaced by
+///   their min and max expressions returned by the lambda `getMinMaxFn`.
 ///   2. checking whether any of the results of this affine map is known to be
 ///   greater than all other results.
 ///   3. replacing the AffineMinOp by the result of (2).
-// TODO: move to a more appropriate place when it is determined. For now Linalg
-// depends both on Affine and SCF but they do not depend on each other.
-struct AffineMinSCFCanonicalizationPattern
+struct AffineMinRangeCanonicalizationPattern
     : public OpRewritePattern<AffineMinOp> {
-  using OpRewritePattern<AffineMinOp>::OpRewritePattern;
-
+  AffineMinRangeCanonicalizationPattern(MLIRContext *context,
+                                        GetMinMaxExprFn getMinMaxFn)
+      : OpRewritePattern<AffineMinOp>(context), getMinMaxFn(getMinMaxFn) {}
   LogicalResult matchAndRewrite(AffineMinOp minOp,
                                 PatternRewriter &rewriter) const override;
+
+protected:
+  GetMinMaxExprFn getMinMaxFn;
+};
+
+/// Specialized version of `AffineMinRangeCanonicalizationPattern` using
+/// `getSCFMinMaxExpr` to know the min and max expressions of induction
+/// variables from scf loops.
+// TODO: move to a more appropriate place when it is determined. For now Linalg
+// depends both on Affine and SCF but they do not depend on each other.
+struct AffineMinSCFCanonicalizationPattern
+    : public AffineMinRangeCanonicalizationPattern {
+  static Optional<std::pair<AffineExpr, AffineExpr>>
+  getMinMax(Value value, SmallVectorImpl<Value> &dims,
+            SmallVectorImpl<Value> &symbols) {
+    return getSCFMinMaxExpr(value, dims, symbols);
+  }
+  AffineMinSCFCanonicalizationPattern(MLIRContext *context)
+      : AffineMinRangeCanonicalizationPattern(context, getMinMax) {}
 };
 
 /// Helper struct to return the results of `substituteMin`.
@@ -959,23 +983,22 @@
   SmallVector<Value> dims;
   SmallVector<Value> symbols;
 };
-/// Traverse the dims of the AffineMap of `affineMinOp` and substitute scf loop
-/// induction variables by new expressions involving the lower or upper bound:
-///   - If the AffineDimExpr mapped to a loop IV has a positive sign, it is
-///   replaced by the loop upper bound.
-///   - If the AffineDimExpr mapped to a loop IV has a negative sign, it is
-///   replaced by the loop lower bound.
-/// All loop induction variables are iteratively replaced, unless a
-/// `substituteOperation` hook is passed to more finely determine which
-/// operations are substituted.
+
+/// Traverse the dims of the AffineMap of `affineMinOp` and substitute
+/// dimensions with a known range by new expressions involving their min or
+/// max expression, as returned by `getMinMaxExpr`:
+///   - If the AffineDimExpr mapped to a known value has a positive sign, it
+///   is replaced by the min expression.
+///   - If the AffineDimExpr mapped to a known value has a negative sign, it is
+///   replaced by the max expression.
+/// All known values are iteratively replaced.
 /// This is used as an intermediate step in computing bounding boxes and
 /// canonicalize AffineMinOps. All dim and symbol operands are assumed to have
 /// positive values (positive orthant assumptions).
 /// Return a new AffineMap, dims and symbols that have been canonicalized and
 /// simplified.
-AffineMapAndOperands substituteMin(
-    AffineMinOp affineMinOp,
-    llvm::function_ref<bool(Operation *)> substituteOperation = nullptr);
+AffineMapAndOperands substituteMin(AffineMinOp affineMinOp,
+                                   GetMinMaxExprFn getMinMaxExpr);
 
 /// Converts Convolution op into vector contraction.
 ///
Index: mlir/include/mlir/Dialect/SCF/Utils.h
===================================================================
--- mlir/include/mlir/Dialect/SCF/Utils.h
+++ mlir/include/mlir/Dialect/SCF/Utils.h
@@ -14,12 +14,15 @@
 #define MLIR_DIALECT_SCF_UTILS_H_
 
 #include "mlir/Support/LLVM.h"
+#include "llvm/ADT/STLExtras.h"
 
 namespace mlir {
 class FuncOp;
 class Operation;
 class OpBuilder;
 class ValueRange;
+class Value;
+class AffineExpr;
 
 namespace scf {
 class IfOp;
@@ -64,5 +67,14 @@
 bool getInnermostParallelLoops(Operation *rootOp,
                                SmallVectorImpl<scf::ParallelOp> &result);
 
+/// If `value` is an induction variable of an scf.for or scf.parallel loop,
+/// append the loop bounds to `dims` and its step to `symbols` and return the
+/// min/max expressions of `value`; otherwise return None. If `loopFilter` is
+/// passed, only loops it accepts are considered.
+Optional<std::pair<AffineExpr, AffineExpr>>
+getSCFMinMaxExpr(Value value, SmallVectorImpl<Value> &dims,
+                 SmallVectorImpl<Value> &symbols,
+                 llvm::function_ref<bool(Operation *)> loopFilter = nullptr);
+
 } // end namespace mlir
 #endif // MLIR_DIALECT_SCF_UTILS_H_
Index: mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
===================================================================
--- mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -574,9 +574,16 @@
       continue;
     }
     auto sliceMinOp = cast<AffineMinOp>(op);
+    // Only substitute IVs of `outer` and of loops nested in it.
+    GetMinMaxExprFn getSCFMinMax = [&](Value value,
+                                       SmallVectorImpl<Value> &dims,
+                                       SmallVectorImpl<Value> &symbols) {
+      return getSCFMinMaxExpr(value, dims, symbols, [&](Operation *loop) {
+        return outer->isAncestor(loop);
+      });
+    };
     // Perform the substitution of the operands of AffineMinOp.
-    auto mapAndOperands = substituteMin(
-        sliceMinOp, [&](Operation *op) { return outer->isAncestor(op); });
+    auto mapAndOperands = substituteMin(sliceMinOp, getSCFMinMax);
     SmallVector<Value> resultOperands = mapAndOperands.dims;
     llvm::append_range(resultOperands, mapAndOperands.symbols);
     AffineMap map = mapAndOperands.map;
Index: mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
===================================================================
--- mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -514,37 +514,11 @@
   return success();
 }
 
-/// Given the `lbVal`, `ubVal` and `stepVal` of a loop, append `lbVal` and
-/// `ubVal` to `dims` and `stepVal` to `symbols`.
-/// Create new AffineDimExpr (`%lb` and `%ub`) and AffineSymbolExpr (`%step`)
-/// with positions matching the newly appended values. Substitute occurrences of
-/// `dimExpr` by either the min expression (i.e. `%lb`) or the max expression
-/// (i.e. `%lb + %step * floordiv(%ub -1 - %lb, %step)`), depending on whether
-/// the induction variable is used with a positive or negative coefficient.
-static AffineExpr substituteLoopInExpr(AffineExpr expr, AffineExpr dimExpr,
-                                       Value lbVal, Value ubVal, Value stepVal,
-                                       SmallVectorImpl<Value> &dims,
-                                       SmallVectorImpl<Value> &symbols) {
-  MLIRContext *ctx = lbVal.getContext();
-  AffineExpr lb = getAffineDimExpr(dims.size(), ctx);
-  dims.push_back(lbVal);
-  AffineExpr ub = getAffineDimExpr(dims.size(), ctx);
-  dims.push_back(ubVal);
-  AffineExpr step = getAffineSymbolExpr(symbols.size(), ctx);
-  symbols.push_back(stepVal);
-  LLVM_DEBUG(DBGS() << "Before: " << expr << "\n");
-  AffineExpr ee = substWithMin(expr, dimExpr, lb,
-                               lb + step * ((ub - 1) - lb).floorDiv(step));
-  LLVM_DEBUG(DBGS() << "After: " << expr << "\n");
-  return ee;
-}
-
-/// Traverse the `dims` and substitute known min or max expressions in place of
-/// induction variables in `exprs`.
-static AffineMap substitute(
-    AffineMap map, SmallVectorImpl<Value> &dims,
-    SmallVectorImpl<Value> &symbols,
-    llvm::function_ref<bool(Operation *)> substituteOperation = nullptr) {
+/// Traverse the `dims` and substitute known min or max expressions returned by
+/// `getMinMaxExpr`; dims for which it returns no range are left unchanged.
+static AffineMap substitute(AffineMap map, SmallVectorImpl<Value> &dims,
+                            SmallVectorImpl<Value> &symbols,
+                            GetMinMaxExprFn getMinMaxExpr) {
   auto exprs = llvm::to_vector<4>(map.getResults());
   for (AffineExpr &expr : exprs) {
     bool substituted = true;
@@ -552,27 +526,18 @@
       substituted = false;
       for (unsigned dimIdx = 0; dimIdx < dims.size(); ++dimIdx) {
         Value dim = dims[dimIdx];
+        auto minMax = getMinMaxExpr(dim, dims, symbols);
+        if (!minMax)
+          continue;
         AffineExpr dimExpr = getAffineDimExpr(dimIdx, expr.getContext());
         LLVM_DEBUG(DBGS() << "Subst: " << dim << " @ " << dimExpr << "\n");
-        AffineExpr substitutedExpr;
-        if (auto forOp = scf::getForInductionVarOwner(dim))
-          if (!substituteOperation || substituteOperation(forOp))
-            substitutedExpr = substituteLoopInExpr(
-                expr, dimExpr, forOp.lowerBound(), forOp.upperBound(),
-                forOp.step(), dims, symbols);
-
-        if (auto parallelForOp = scf::getParallelForInductionVarOwner(dim))
-          if (!substituteOperation || substituteOperation(parallelForOp))
-            for (unsigned idx = 0, e = parallelForOp.getNumLoops(); idx < e;
-                 ++idx)
-              substitutedExpr = substituteLoopInExpr(
-                  expr, dimExpr, parallelForOp.lowerBound()[idx],
-                  parallelForOp.upperBound()[idx], parallelForOp.step()[idx],
-                  dims, symbols);
-
-        if (!substitutedExpr)
-          continue;
-
+        LLVM_DEBUG(DBGS() << "Before: " << expr << "\n");
+        // Substitute occurrences of `dimExpr` by either the min expression or
+        // the max expression depending on whether the value is used with a
+        // positive or negative coefficient.
+        AffineExpr substitutedExpr =
+            substWithMin(expr, dimExpr, minMax->first, minMax->second);
+        LLVM_DEBUG(DBGS() << "After: " << substitutedExpr << "\n");
         substituted = (substitutedExpr != expr);
         expr = substitutedExpr;
       }
@@ -610,37 +575,36 @@
   return AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext());
 }
 
-/// Traverse the dims of the AffineMap of `affineMinOp` and substitute scf loop
-/// induction variables by new expressions involving the lower or upper bound:
-///   - If the AffineDimExpr mapped to a loop IV has a positive sign, it is
-///   replaced by the loop upper bound.
-///   - If the AffineDimExpr mapped to a loop IV has a negative sign, it is
-///   replaced by the loop lower bound.
-/// All loop induction variables are iteratively replaced, unless a
-/// `substituteOperation` hook is passed to more finely determine which
-/// operations are substituted.
+/// Traverse the dims of the AffineMap of `affineMinOp` and substitute
+/// dimensions with a known range by new expressions involving their min or
+/// max expression, as returned by `getMinMaxExpr`:
+///   - If the AffineDimExpr mapped to a known value has a positive sign, it
+///   is replaced by the min expression.
+///   - If the AffineDimExpr mapped to a known value has a negative sign, it is
+///   replaced by the max expression.
+/// All known values are iteratively replaced.
 /// This is used as an intermediate step in computing bounding boxes and
 /// canonicalize AffineMinOps. All dim and symbol operands are assumed to have
 /// positive values (positive orthant assumptions).
 /// Return a new AffineMap, dims and symbols that have been canonicalized and
 /// simplified.
-AffineMapAndOperands mlir::linalg::substituteMin(
-    AffineMinOp affineMinOp,
-    llvm::function_ref<bool(Operation *)> substituteOperation) {
+AffineMapAndOperands
+mlir::linalg::substituteMin(AffineMinOp affineMinOp,
+                            GetMinMaxExprFn getMinMaxExpr) {
   AffineMapAndOperands res{affineMinOp.getAffineMap(),
                            SmallVector<Value>(affineMinOp.getDimOperands()),
                            SmallVector<Value>(affineMinOp.getSymbolOperands())};
   res.map = substitute(affineMinOp.getAffineMap(), res.dims, res.symbols,
-                       substituteOperation);
+                       getMinMaxExpr);
   return res;
 }
 
-LogicalResult AffineMinSCFCanonicalizationPattern::matchAndRewrite(
+LogicalResult AffineMinRangeCanonicalizationPattern::matchAndRewrite(
     AffineMinOp minOp, PatternRewriter &rewriter) const {
   LLVM_DEBUG(DBGS() << "Canonicalize AffineMinSCF: " << *minOp.getOperation()
                     << "\n");
 
-  auto affineMapAndOperands = substituteMin(minOp);
+  auto affineMapAndOperands = substituteMin(minOp, getMinMaxFn);
   AffineMap map = affineMapAndOperands.map;
 
   LLVM_DEBUG(DBGS() << "Resulting map: " << map << "\n");
Index: mlir/lib/Dialect/SCF/Transforms/Utils.cpp
===================================================================
--- mlir/lib/Dialect/SCF/Transforms/Utils.cpp
+++ mlir/lib/Dialect/SCF/Transforms/Utils.cpp
@@ -145,3 +145,52 @@
   }
   return rootEnclosesPloops;
 }
+
+/// Given the `lbVal`, `ubVal` and `stepVal` of a loop, append `lbVal` and
+/// `ubVal` to `dims` and `stepVal` to `symbols`.
+/// Create new AffineDimExpr (`%lb` and `%ub`) and AffineSymbolExpr (`%step`)
+/// with positions matching the newly appended values. Then create and return
+/// the min expression (i.e. `%lb`) and the max expression
+/// (i.e. `%lb + %step * floordiv(%ub -1 - %lb, %step)`) as a (min, max) pair.
+static std::pair<AffineExpr, AffineExpr>
+getMinMaxLoopIndVar(Value lbVal, Value ubVal, Value stepVal,
+                    SmallVectorImpl<Value> &dims,
+                    SmallVectorImpl<Value> &symbols) {
+  MLIRContext *ctx = lbVal.getContext();
+  AffineExpr lb = getAffineDimExpr(dims.size(), ctx);
+  dims.push_back(lbVal);
+  AffineExpr ub = getAffineDimExpr(dims.size(), ctx);
+  dims.push_back(ubVal);
+  AffineExpr step = getAffineSymbolExpr(symbols.size(), ctx);
+  symbols.push_back(stepVal);
+  return std::make_pair(lb, lb + step * ((ub - 1) - lb).floorDiv(step));
+}
+
+/// If `value` is an induction variable of an scf.for or scf.parallel loop,
+/// append the loop bounds to `dims` and its step to `symbols` and return the
+/// min/max expressions of `value`; otherwise return None. If `loopFilter` is
+/// passed, only loops it accepts are considered: induction variables of loops
+/// it rejects are ignored (None is returned for them).
+Optional<std::pair<AffineExpr, AffineExpr>> mlir::getSCFMinMaxExpr(
+    Value value, SmallVectorImpl<Value> &dims, SmallVectorImpl<Value> &symbols,
+    llvm::function_ref<bool(Operation *)> loopFilter) {
+  if (auto forOp = scf::getForInductionVarOwner(value)) {
+    // Honor the filter: the caller may restrict which loops are substituted.
+    if (loopFilter && !loopFilter(forOp.getOperation()))
+      return llvm::None;
+    return getMinMaxLoopIndVar(forOp.lowerBound(), forOp.upperBound(),
+                               forOp.step(), dims, symbols);
+  }
+
+  if (auto parallelForOp = scf::getParallelForInductionVarOwner(value)) {
+    if (loopFilter && !loopFilter(parallelForOp.getOperation()))
+      return llvm::None;
+    // Only the bounds of the dimension `value` iterates over are relevant.
+    for (unsigned idx = 0, e = parallelForOp.getNumLoops(); idx < e; ++idx)
+      if (parallelForOp.getInductionVars()[idx] == value)
+        return getMinMaxLoopIndVar(parallelForOp.lowerBound()[idx],
+                                   parallelForOp.upperBound()[idx],
+                                   parallelForOp.step()[idx], dims, symbols);
+  }
+  return llvm::None;
+}