Index: mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h =================================================================== --- mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -726,24 +726,45 @@ PatternRewriter &rewriter) const override; }; -/// Canonicalize AffineMinOp operations in the context of enclosing scf.for and -/// scf.parallel by: -/// 1. building an affine map where uses of the induction variable of a loop -/// are replaced by either the min (i.e. `%lb`) of the max -/// (i.e. `%lb + %step * floordiv(%ub -1 - %lb, %step)`) expression, depending -/// on whether the induction variable is used with a positive or negative -/// coefficient. +/// Canonicalize AffineMinOp operations in the context of ops with a +/// known range by: +/// 1. building an affine map where uses of the known ops are replaced by +/// the min or max expression returned by the callback `substituteFn`. /// 2. checking whether any of the results of this affine map is known to be /// greater than all other results. /// 3. replacing the AffineMinOp by the result of (2). -// TODO: move to a more appropriate place when it is determined. For now Linalg -// depends both on Affine and SCF but they do not depend on each other. -struct AffineMinSCFCanonicalizationPattern +struct AffineMinRangeCanonicalizationPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - + using SubstituteFn = std::function &, + SmallVectorImpl &)>; + AffineMinRangeCanonicalizationPattern(MLIRContext *context, + SubstituteFn substituteFn) + : OpRewritePattern(context), substituteFn(substituteFn) {} LogicalResult matchAndRewrite(AffineMinOp minOp, PatternRewriter &rewriter) const override; + +protected: + // Callback substituting a dimension based on the semantics of its definition. + SubstituteFn substituteFn; +}; + +/// Replace the induction variable of a scf loop by the min (i.e. `%lb`) or the +/// max (i.e. 
`%lb + %step * floordiv(%ub - 1 - %lb, %step)`) expression, +/// depending on whether the induction variable is used with a positive or +/// negative coefficient. +// TODO: move to a more appropriate place when it is determined. For now Linalg +// depends both on Affine and SCF but they do not depend on each other. +AffineExpr substituteSCFOps(Value dim, AffineExpr expr, AffineExpr dimExpr, + SmallVectorImpl &dims, + SmallVectorImpl &symbols); + +/// Specialization of `AffineMinRangeCanonicalizationPattern` that handles +/// scf.for and scf.parallel induction variables via `substituteSCFOps`. +struct AffineMinSCFCanonicalizationPattern + : public AffineMinRangeCanonicalizationPattern { + AffineMinSCFCanonicalizationPattern(MLIRContext *context) + : AffineMinRangeCanonicalizationPattern(context, substituteSCFOps) {} }; /// Converts Convolution op into vector contraction. Index: mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp =================================================================== --- mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -358,10 +358,32 @@ return ee; } -/// Traverse the `dims` and substitute known min or max expressions in place of -/// induction variables in `exprs`. -static AffineMap substitute(AffineMap map, SmallVectorImpl &dims, - SmallVectorImpl &symbols) { +/// Substitute the known min or max expression in place of the scf induction +/// variable `dim` in `expr`; returns a null AffineExpr for any other `dim`. 
+AffineExpr mlir::linalg::substituteSCFOps(Value dim, AffineExpr expr, + AffineExpr dimExpr, + SmallVectorImpl &dims, + SmallVectorImpl &symbols) { + if (auto forOp = scf::getForInductionVarOwner(dim)) + return substituteLoopInExpr(expr, dimExpr, forOp.lowerBound(), + forOp.upperBound(), forOp.step(), dims, + symbols); + if (auto parallelForOp = scf::getParallelForInductionVarOwner(dim)) { + for (unsigned idx = 0, e = parallelForOp.getNumLoops(); idx < e; ++idx) + if (parallelForOp.getInductionVars()[idx] == dim) // Match `dim`'s loop. + return substituteLoopInExpr( + expr, dimExpr, parallelForOp.lowerBound()[idx], + parallelForOp.upperBound()[idx], parallelForOp.step()[idx], dims, + symbols); + } + return AffineExpr(); // `dim` is not an scf loop induction variable. +} + +/// Traverse the `dims` and substitute known min or max expressions in `exprs`. +static AffineMap +substitute(AffineMap map, SmallVectorImpl &dims, + SmallVectorImpl &symbols, + AffineMinRangeCanonicalizationPattern::SubstituteFn substituteFn) { auto exprs = llvm::to_vector<4>(map.getResults()); for (AffineExpr &expr : exprs) { bool substituted = true; @@ -371,19 +393,8 @@ Value dim = dims[dimIdx]; AffineExpr dimExpr = getAffineDimExpr(dimIdx, expr.getContext()); LLVM_DEBUG(DBGS() << "Subst: " << dim << " @ " << dimExpr << "\n"); - AffineExpr substitutedExpr; - if (auto forOp = scf::getForInductionVarOwner(dim)) - substitutedExpr = substituteLoopInExpr( - expr, dimExpr, forOp.lowerBound(), forOp.upperBound(), - forOp.step(), dims, symbols); - - if (auto parallelForOp = scf::getParallelForInductionVarOwner(dim)) - for (unsigned idx = 0, e = parallelForOp.getNumLoops(); idx < e; - ++idx) - substitutedExpr = substituteLoopInExpr( - expr, dimExpr, parallelForOp.lowerBound()[idx], - parallelForOp.upperBound()[idx], parallelForOp.step()[idx], - dims, symbols); + AffineExpr substitutedExpr = + substituteFn(dim, expr, dimExpr, dims, symbols); if (!substitutedExpr) continue; @@ -420,14 +431,14 @@ return AffineMap::get(dims.size(), symbols.size(), exprs, 
map.getContext()); } -LogicalResult AffineMinSCFCanonicalizationPattern::matchAndRewrite( +LogicalResult AffineMinRangeCanonicalizationPattern::matchAndRewrite( AffineMinOp minOp, PatternRewriter &rewriter) const { LLVM_DEBUG(DBGS() << "Canonicalize AffineMinSCF: " << *minOp.getOperation() << "\n"); SmallVector dims(minOp.getDimOperands()), symbols(minOp.getSymbolOperands()); - AffineMap map = substitute(minOp.getAffineMap(), dims, symbols); + AffineMap map = substitute(minOp.getAffineMap(), dims, symbols, substituteFn); LLVM_DEBUG(DBGS() << "Resulting map: " << map << "\n");