diff --git a/mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h b/mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h --- a/mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h +++ b/mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h @@ -39,6 +39,20 @@ using LoopMatcherFn = function_ref; +/// Match "for loop"-like operations from the SCF dialect. If `iv` is the +/// induction variable (or thread index) of an scf.for, scf.parallel or +/// scf.foreach_thread op, set `lb`, `ub` and `step` to its range and return +/// success; otherwise, return failure. +LogicalResult matchForLikeLoop(Value iv, OpFoldResult &lb, OpFoldResult &ub, + OpFoldResult &step); + +/// Populate the given constraint set with induction variable constraints of a +/// "for" loop with the given range and step. Return failure if `step` is not a +/// constant or if the upper bound constraint cannot be added. +LogicalResult addLoopRangeConstraints(FlatAffineValueConstraints &cstr, + Value iv, OpFoldResult lb, + OpFoldResult ub, OpFoldResult step); + /// Try to canonicalize the given affine.min/max operation in the context of /// for `loops` with a known range. /// diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp --- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp @@ -158,41 +158,7 @@ LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { - auto loopMatcher = [](Value iv, OpFoldResult &lb, OpFoldResult &ub, - OpFoldResult &step) { - if (scf::ForOp forOp = scf::getForInductionVarOwner(iv)) { - lb = forOp.getLowerBound(); - ub = forOp.getUpperBound(); - step = forOp.getStep(); - return success(); - } - if (scf::ParallelOp parOp = scf::getParallelForInductionVarOwner(iv)) { - for (unsigned idx = 0; idx < parOp.getNumLoops(); ++idx) { - if (parOp.getInductionVars()[idx] == iv) { - lb = parOp.getLowerBound()[idx]; - ub = parOp.getUpperBound()[idx]; - step = parOp.getStep()[idx]; - return success(); - } - } - return failure(); - } - if (scf::ForeachThreadOp foreachThreadOp = - scf::getForeachThreadOpThreadIndexOwner(iv)) { - for (int64_t idx = 0; idx < 
foreachThreadOp.getRank(); ++idx) { - if (foreachThreadOp.getThreadIndices()[idx] == iv) { - lb = OpBuilder(iv.getContext()).getIndexAttr(0); - ub = foreachThreadOp.getNumThreads()[idx]; - step = OpBuilder(iv.getContext()).getIndexAttr(1); - return success(); - } - } - return failure(); - } - return failure(); - }; - - return scf::canonicalizeMinMaxOpInLoop(rewriter, op, loopMatcher); + return scf::canonicalizeMinMaxOpInLoop(rewriter, op, scf::matchForLikeLoop); } }; diff --git a/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp b/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp --- a/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp @@ -29,6 +29,40 @@ using namespace mlir; using namespace presburger; +LogicalResult scf::matchForLikeLoop(Value iv, OpFoldResult &lb, + OpFoldResult &ub, OpFoldResult &step) { + if (scf::ForOp forOp = scf::getForInductionVarOwner(iv)) { + lb = forOp.getLowerBound(); + ub = forOp.getUpperBound(); + step = forOp.getStep(); + return success(); + } + if (scf::ParallelOp parOp = scf::getParallelForInductionVarOwner(iv)) { + for (unsigned idx = 0; idx < parOp.getNumLoops(); ++idx) { + if (parOp.getInductionVars()[idx] == iv) { + lb = parOp.getLowerBound()[idx]; + ub = parOp.getUpperBound()[idx]; + step = parOp.getStep()[idx]; + return success(); + } + } + return failure(); + } + if (scf::ForeachThreadOp foreachThreadOp = + scf::getForeachThreadOpThreadIndexOwner(iv)) { + for (int64_t idx = 0; idx < foreachThreadOp.getRank(); ++idx) { + if (foreachThreadOp.getThreadIndices()[idx] == iv) { + lb = Builder(iv.getContext()).getIndexAttr(0); + ub = foreachThreadOp.getNumThreads()[idx]; + step = Builder(iv.getContext()).getIndexAttr(1); + return success(); + } + } + return failure(); + } + return failure(); +} + static FailureOr canonicalizeMinMaxOp(RewriterBase &rewriter, Operation *op, FlatAffineValueConstraints constraints) { @@ -42,37 +76,38 @@ op, 
simplified->getAffineMap(), simplified->getOperands()); } -static LogicalResult -addLoopRangeConstraints(FlatAffineValueConstraints &constraints, Value iv, - OpFoldResult lb, OpFoldResult ub, OpFoldResult step, - RewriterBase &rewriter) { +LogicalResult scf::addLoopRangeConstraints(FlatAffineValueConstraints &cstr, + Value iv, OpFoldResult lb, + OpFoldResult ub, OpFoldResult step) { + Builder b(iv.getContext()); + // IntegerPolyhedron does not support semi-affine expressions. // Therefore, only constant step values are supported. auto stepInt = getConstantIntValue(step); if (!stepInt) return failure(); - unsigned dimIv = constraints.appendDimVar(iv); + unsigned dimIv = cstr.appendDimVar(iv); auto lbv = lb.dyn_cast(); - unsigned symLb = lbv ? constraints.appendSymbolVar(lbv) - : constraints.appendSymbolVar(/*num=*/1); + unsigned symLb = + lbv ? cstr.appendSymbolVar(lbv) : cstr.appendSymbolVar(/*num=*/1); auto ubv = ub.dyn_cast(); - unsigned symUb = ubv ? constraints.appendSymbolVar(ubv) - : constraints.appendSymbolVar(/*num=*/1); + unsigned symUb = + ubv ? cstr.appendSymbolVar(ubv) : cstr.appendSymbolVar(/*num=*/1); // If loop lower/upper bounds are constant: Add EQ constraint. std::optional lbInt = getConstantIntValue(lb); std::optional ubInt = getConstantIntValue(ub); if (lbInt) - constraints.addBound(IntegerPolyhedron::EQ, symLb, *lbInt); + cstr.addBound(IntegerPolyhedron::EQ, symLb, *lbInt); if (ubInt) - constraints.addBound(IntegerPolyhedron::EQ, symUb, *ubInt); + cstr.addBound(IntegerPolyhedron::EQ, symUb, *ubInt); // Lower bound: iv >= lb (equiv.: iv - lb >= 0) - SmallVector ineqLb(constraints.getNumCols(), 0); + SmallVector ineqLb(cstr.getNumCols(), 0); ineqLb[dimIv] = 1; ineqLb[symLb] = -1; - constraints.addInequality(ineqLb); + cstr.addInequality(ineqLb); // Upper bound AffineExpr ivUb; @@ -81,26 +116,23 @@ // iv < lb + 1 // TODO: Try to derive this constraint by simplifying the expression in // the else-branch. 
- ivUb = - rewriter.getAffineSymbolExpr(symLb - constraints.getNumDimVars()) + 1; + ivUb = b.getAffineSymbolExpr(symLb - cstr.getNumDimVars()) + 1; } else { // The loop may have more than one iteration. // iv < lb + step * ((ub - lb - 1) floorDiv step) + 1 AffineExpr exprLb = - lbInt - ? rewriter.getAffineConstantExpr(*lbInt) - : rewriter.getAffineSymbolExpr(symLb - constraints.getNumDimVars()); + lbInt ? b.getAffineConstantExpr(*lbInt) + : b.getAffineSymbolExpr(symLb - cstr.getNumDimVars()); AffineExpr exprUb = - ubInt - ? rewriter.getAffineConstantExpr(*ubInt) - : rewriter.getAffineSymbolExpr(symUb - constraints.getNumDimVars()); + ubInt ? b.getAffineConstantExpr(*ubInt) + : b.getAffineSymbolExpr(symUb - cstr.getNumDimVars()); ivUb = exprLb + 1 + (*stepInt * ((exprUb - exprLb - 1).floorDiv(*stepInt))); } auto map = AffineMap::get( - /*dimCount=*/constraints.getNumDimVars(), - /*symbolCount=*/constraints.getNumSymbolVars(), /*result=*/ivUb); + /*dimCount=*/cstr.getNumDimVars(), + /*symbolCount=*/cstr.getNumSymbolVars(), /*result=*/ivUb); - return constraints.addBound(IntegerPolyhedron::UB, dimIv, map); + return cstr.addBound(IntegerPolyhedron::UB, dimIv, map); } /// Canonicalize min/max operations in the context of for loops with a known @@ -132,8 +164,7 @@ continue; allIvs.insert(iv); - if (failed( - addLoopRangeConstraints(constraints, iv, lb, ub, step, rewriter))) + if (failed(addLoopRangeConstraints(constraints, iv, lb, ub, step))) return failure(); }