diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -502,6 +502,26 @@
                                 PatternRewriter &rewriter) const override;
 };
 
+/// Canonicalize AffineMinOp operations in the context of enclosing scf.for and
+/// scf.parallel by:
+///   1. building an affine map where uses of the induction variable of a loop
+///      are replaced by either the min (i.e. `%lb`) or the max
+///      (i.e. `%lb + %step * floordiv(%ub - 1 - %lb, %step)`) expression,
+///      depending on whether the induction variable is used with a positive
+///      or negative coefficient.
+///   2. checking whether any of the results of this affine map is known to be
+///      less than or equal to all other results.
+///   3. replacing the AffineMinOp by the result of (2).
+// TODO: move to a more appropriate place when it is determined. For now Linalg
+// depends both on Affine and SCF but they do not depend on each other.
+struct AffineMinSCFCanonicalizationPattern
+    : public OpRewritePattern<AffineMinOp> {
+  using OpRewritePattern<AffineMinOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AffineMinOp minOp,
+                                PatternRewriter &rewriter) const override;
+};
+
 //===----------------------------------------------------------------------===//
 // Support for staged pattern application.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h
--- a/mlir/include/mlir/IR/AffineExpr.h
+++ b/mlir/include/mlir/IR/AffineExpr.h
@@ -115,9 +115,20 @@
 
   /// This method substitutes any uses of dimensions and symbols (e.g.
   /// dim#0 with dimReplacements[0]) and returns the modified expression tree.
+  /// This is a dense replacement method: a replacement must be specified for
+  /// every single dim and symbol.
   AffineExpr replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
                                    ArrayRef<AffineExpr> symReplacements) const;
 
+  /// Sparse replace method. Replace occurrences of `expr` in `*this` by
+  /// `replacement` and return the modified expression tree.
+  AffineExpr replace(AffineExpr expr, AffineExpr replacement) const;
+
+  /// Sparse replace method. If `*this` is a key of `map`, return
+  /// `map[*this]`. Otherwise traverse `*this` and apply `replace(map)` to its
+  /// subexpressions, returning the modified expression tree.
+  AffineExpr replace(const DenseMap<AffineExpr, AffineExpr> &map) const;
+
   /// Replace symbols[0 .. numDims - 1] by
   /// symbols[shift .. shift + numDims - 1].
   AffineExpr shiftSymbols(unsigned numSymbols, unsigned shift) const;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -235,3 +235,178 @@
   }
   return success();
 }
+
+/// Traverse `e` and return an AffineExpr where occurrences of `dim` are
+/// replaced by either:
+///  - `min` if `positivePath` is true when we reach an occurrence of `dim`,
+///  - `max` if `positivePath` is false when we reach an occurrence of `dim`.
+/// `positivePath` is negated each time we hit a multiplicative or divisive
+/// binary op with a constant negative coefficient.
+static AffineExpr substWithMin(AffineExpr e, AffineExpr dim, AffineExpr min,
+                               AffineExpr max, bool positivePath = true) {
+  if (e == dim)
+    return positivePath ? min : max;
+  auto bin = e.dyn_cast<AffineBinaryOpExpr>();
+  if (!bin)
+    return e;
+
+  AffineExpr lhs = bin.getLHS();
+  AffineExpr rhs = bin.getRHS();
+  // Below a multiplication or division by a constant `c`, keep the current
+  // polarity when `c >= 0` and flip it when `c < 0`.
+  auto pathFor = [&](AffineConstantExpr c) {
+    return c.getValue() < 0 ? !positivePath : positivePath;
+  };
+  switch (bin.getKind()) {
+  case AffineExprKind::Add:
+    return substWithMin(lhs, dim, min, max, positivePath) +
+           substWithMin(rhs, dim, min, max, positivePath);
+  case AffineExprKind::Mul:
+    // Multiplication by a constant is monotonic in the other operand.
+    if (auto c = rhs.dyn_cast<AffineConstantExpr>())
+      return getAffineBinaryOpExpr(
+          bin.getKind(), substWithMin(lhs, dim, min, max, pathFor(c)), c);
+    if (auto c = lhs.dyn_cast<AffineConstantExpr>())
+      return getAffineBinaryOpExpr(
+          bin.getKind(), c, substWithMin(rhs, dim, min, max, pathFor(c)));
+    // The sign of a non-constant factor is unknown, so substituting inside
+    // it would not produce a bound: leave `e` untouched.
+    return e;
+  case AffineExprKind::FloorDiv:
+  case AffineExprKind::CeilDiv:
+    // Division by a constant is monotonic in the dividend.
+    if (auto c = rhs.dyn_cast<AffineConstantExpr>())
+      return getAffineBinaryOpExpr(
+          bin.getKind(), substWithMin(lhs, dim, min, max, pathFor(c)), c);
+    // Division by a non-constant of unknown sign: leave `e` untouched.
+    return e;
+  default:
+    // `mod` is not monotonic: substituting a bound inside it does not bound
+    // `e` (e.g. `(i + 3) mod 4` is 3 at i = 0 but 0 at i = 1), so leave `e`
+    // untouched.
+    return e;
+  }
+}
+
+/// Helper for substituting a loop induction variable: append the loop's lower
+/// and upper bounds to `dims` and its step to `symbols`, then replace the
+/// induction variable by its smallest or largest value (see `substWithMin`),
+/// expressed in terms of the newly appended dimension and symbol positions.
+/// Substitute the induction variable at dimension position `dimIdx` in every
+/// expression of `exprs` (updated in place), for a loop with bounds `lbVal`,
+/// `ubVal` and step `stepVal`. The bounds are appended to `dims`/`symbols`
+/// only once, whatever the number of expressions.
+static void substituteLoopInExprs(SmallVectorImpl<AffineExpr> &exprs,
+                                  unsigned dimIdx, Value lbVal, Value ubVal,
+                                  Value stepVal, SmallVectorImpl<Value> &dims,
+                                  SmallVectorImpl<Value> &symbols) {
+  MLIRContext *ctx = lbVal.getContext();
+  AffineExpr lb = getAffineDimExpr(dims.size(), ctx);
+  dims.push_back(lbVal);
+  AffineExpr ub = getAffineDimExpr(dims.size(), ctx);
+  dims.push_back(ubVal);
+  AffineExpr step = getAffineSymbolExpr(symbols.size(), ctx);
+  symbols.push_back(stepVal);
+  // Last value taken by the induction variable when the loop runs at least
+  // once (if it runs zero times, code using the induction variable never
+  // executes).
+  AffineExpr max = lb + step * ((ub - 1) - lb).floorDiv(step);
+  AffineExpr iv = getAffineDimExpr(dimIdx, ctx);
+  for (AffineExpr &expr : exprs) {
+    LLVM_DEBUG(DBGS() << "Before: " << expr << "\n");
+    expr = substWithMin(expr, iv, lb, max);
+    LLVM_DEBUG(DBGS() << "After: " << expr << "\n");
+  }
+}
+
+/// Fully compose `exprs` with the affine.apply ops defining `dims` and
+/// `symbols`, then canonicalize and simplify the resulting map. `exprs`,
+/// `dims` and `symbols` are overwritten with the result; `exprs` must not be
+/// empty.
+static void composeAndSimplifyInPlace(SmallVectorImpl<AffineExpr> &exprs,
+                                      SmallVectorImpl<Value> &dims,
+                                      SmallVectorImpl<Value> &symbols) {
+  SmallVector<Value, 4> operands(dims.begin(), dims.end());
+  operands.append(symbols.begin(), symbols.end());
+  auto map = AffineMap::get(dims.size(), symbols.size(), exprs,
+                            exprs.front().getContext());
+  LLVM_DEBUG(DBGS() << "Map to simplify: " << map << "\n");
+  // Pull in affine.apply operations and compose them fully into the result.
+  fullyComposeAffineMapAndOperands(&map, &operands);
+  canonicalizeMapAndOperands(&map, &operands);
+  map = simplifyAffineMap(map);
+  exprs.assign(map.getResults().begin(), map.getResults().end());
+  dims.assign(operands.begin(), operands.begin() + map.getNumDims());
+  symbols.assign(operands.begin() + map.getNumDims(), operands.end());
+  LLVM_DEBUG(DBGS() << "Map simplified: " << map << "\n");
+}
+
+/// Replace the scf.for / scf.parallel induction variables used by `exprs` by
+/// their smallest or largest value, chosen per occurrence by `substWithMin`,
+/// then simplify. Induction variables reached through affine.apply operands,
+/// or through loop bounds that are themselves induction variables, are
+/// handled too. On return, `isExact[i]` is true iff `exprs[i]` was left
+/// untouched by the substitution: it then still denotes the exact value of
+/// the i-th result rather than a substituted bound.
+static void substitute(SmallVectorImpl<AffineExpr> &exprs,
+                       SmallVectorImpl<Value> &dims,
+                       SmallVectorImpl<Value> &symbols,
+                       SmallVectorImpl<bool> &isExact) {
+  isExact.assign(exprs.size(), true);
+  if (exprs.empty())
+    return;
+
+  // Expose induction variables hidden behind affine.apply operands.
+  composeAndSimplifyInPlace(exprs, dims, symbols);
+  SmallVector<AffineExpr, 4> original(exprs.begin(), exprs.end());
+
+  // `dims` grows as loop bounds are appended: the loop condition is
+  // re-evaluated on purpose so that bounds which are themselves induction
+  // variables get substituted in turn.
+  for (unsigned dimIdx = 0; dimIdx < dims.size(); ++dimIdx) {
+    Value dim = dims[dimIdx];
+    LLVM_DEBUG(DBGS() << "Subst: " << dim << "\n");
+    if (auto forOp = scf::getForInductionVarOwner(dim)) {
+      substituteLoopInExprs(exprs, dimIdx, forOp.lowerBound(),
+                            forOp.upperBound(), forOp.step(), dims, symbols);
+      continue;
+    }
+    if (auto parallelOp = scf::getParallelForInductionVarOwner(dim)) {
+      // Use the bounds of the one loop dimension `dim` is the induction
+      // variable of, not those of the other dimensions of the same op.
+      unsigned loopIdx = dim.cast<BlockArgument>().getArgNumber();
+      substituteLoopInExprs(exprs, dimIdx, parallelOp.lowerBound()[loopIdx],
+                            parallelOp.upperBound()[loopIdx],
+                            parallelOp.step()[loopIdx], dims, symbols);
+    }
+  }
+
+  for (unsigned idx = 0, e = exprs.size(); idx < e; ++idx)
+    isExact[idx] = exprs[idx] == original[idx];
+
+  // Cleanup and simplify the results. This happens once all substitutions are
+  // done since it renumbers `dims` and `symbols`.
+  composeAndSimplifyInPlace(exprs, dims, symbols);
+}
+
+LogicalResult AffineMinSCFCanonicalizationPattern::matchAndRewrite(
+    AffineMinOp minOp, PatternRewriter &rewriter) const {
+  LLVM_DEBUG(DBGS() << "Canonicalize AffineMinSCF: " << *minOp.getOperation()
+                    << "\n");
+
+  auto exprs = llvm::to_vector<4>(minOp.getAffineMap().getResults());
+  SmallVector<Value, 4> dims(minOp.getDimOperands()),
+      symbols(minOp.getSymbolOperands());
+  SmallVector<bool, 4> isExact;
+  substitute(exprs, dims, symbols, isExact);
+
+  MLIRContext *ctx = minOp.getContext();
+  auto map = AffineMap::get(dims.size(), symbols.size(), exprs, ctx);
+  LLVM_DEBUG(DBGS() << "Resulting map: " << map << "\n");
+
+  // True unless `e` is a constant known to be >= 0.
+  auto isNotKnownNonNegative = [](AffineExpr e) {
+    if (auto cst = e.dyn_cast<AffineConstantExpr>())
+      return cst.getValue() < 0;
+    return true;
+  };
+
+  // Look for a result that, subtracted from every result, only produces
+  // constants >= 0: it is then the min for every loop iteration.
+  for (unsigned idx = 0, e = map.getNumResults(); idx < e; ++idx) {
+    AffineExpr candidate = map.getResult(idx);
+    LLVM_DEBUG(DBGS() << "Candidate min: " << candidate << "\n");
+    // A substituted result is only a bound of the actual value and cannot
+    // replace the op; the replacement must not depend on any dimension either.
+    if (!isExact[idx] || !candidate.isSymbolicOrConstant())
+      continue;
+
+    SmallVector<AffineExpr, 4> subExprs;
+    subExprs.reserve(map.getNumResults());
+    for (AffineExpr other : map.getResults())
+      subExprs.push_back(other - candidate);
+    AffineMap subMap = simplifyAffineMap(
+        AffineMap::get(map.getNumDims(), map.getNumSymbols(), subExprs, ctx));
+    LLVM_DEBUG(DBGS() << "simplified subMap: " << subMap << "\n");
+    // `candidate` is not provably the min; another result may still be.
+    if (llvm::any_of(subMap.getResults(), isNotKnownNonNegative))
+      continue;
+
+    // Static min found.
+    if (auto cst = candidate.dyn_cast<AffineConstantExpr>()) {
+      rewriter.replaceOpWithNewOp<ConstantIndexOp>(minOp, cst.getValue());
+    } else {
+      auto resultMap = AffineMap::get(map.getNumDims(), map.getNumSymbols(),
+                                      {candidate}, ctx);
+      SmallVector<Value, 4> resultOperands = dims;
+      resultOperands.append(symbols.begin(), symbols.end());
+      canonicalizeMapAndOperands(&resultMap, &resultOperands);
+      resultMap = simplifyAffineMap(resultMap);
+      rewriter.replaceOpWithNewOp<AffineApplyOp>(minOp, resultMap,
+                                                 resultOperands);
+    }
+    return success();
+  }
+
+  return failure();
+}
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -101,6 +101,37 @@
   return replaceDimsAndSymbols({}, symbols);
 }
 
+/// Sparse replace method. Return the modified expression tree.
+AffineExpr
+AffineExpr::replace(const DenseMap<AffineExpr, AffineExpr> &map) const {
+  auto it = map.find(*this);
+  if (it != map.end())
+    return it->second;
+  switch (getKind()) {
+  default:
+    return *this;
+  case AffineExprKind::Add:
+  case AffineExprKind::Mul:
+  case AffineExprKind::FloorDiv:
+  case AffineExprKind::CeilDiv:
+  case AffineExprKind::Mod:
+    auto binOp = cast<AffineBinaryOpExpr>();
+    auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
+    auto newLHS = lhs.replace(map);
+    auto newRHS = rhs.replace(map);
+    if (newLHS == lhs && newRHS == rhs)
+      return *this;
+    return getAffineBinaryOpExpr(getKind(), newLHS, newRHS);
+  }
+  llvm_unreachable("Unknown AffineExpr");
+}
+
+/// Sparse replace method. Return the modified expression tree.
+AffineExpr AffineExpr::replace(AffineExpr expr, AffineExpr replacement) const {
+  DenseMap<AffineExpr, AffineExpr> map;
+  map.try_emplace(expr, replacement);
+  return replace(map);
+}
 /// Returns true if this expression is made out of only symbols and
 /// constants (no dimensional identifiers).
 bool AffineExpr::isSymbolicOrConstant() const {
diff --git a/mlir/test/Dialect/Linalg/fold-affine-min-scf.mlir b/mlir/test/Dialect/Linalg/fold-affine-min-scf.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/fold-affine-min-scf.mlir
@@ -0,0 +1,141 @@
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-affine-min-scf-canonicalization-patterns | FileCheck %s
+
+// CHECK-LABEL: scf_for
+func @scf_for(%A : memref<i64>, %step : index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c7 = constant 7 : index
+  %c4 = constant 4 : index
+  %c16 = constant 16 : index
+  %c1024 = constant 1024 : index
+
+  // CHECK: scf.for
+  // CHECK-NEXT: %[[C2:.*]] = constant 2 : index
+  // CHECK-NEXT: %[[C2I64:.*]] = index_cast %[[C2]]
+  // CHECK-NEXT: store %[[C2I64]], %{{.*}}[] : memref<i64>
+  scf.for %i = %c0 to %c4 step %c2 {
+    %1 = affine.min affine_map<(d0, d1)[] -> (2, d1 - d0)> (%i, %c4)
+    %2 = index_cast %1: index to i64
+    store %2, %A[]: memref<i64>
+  }
+
+  // CHECK: scf.for
+  // CHECK-NEXT: %[[C2:.*]] = constant 2 : index
+  // CHECK-NEXT: %[[C2I64:.*]] = index_cast %[[C2]]
+  // CHECK-NEXT: store %[[C2I64]], %{{.*}}[] : memref<i64>
+  scf.for %i = %c1 to %c7 step %c2 {
+    %1 = affine.min affine_map<(d0)[s0] -> (2, s0 - d0)> (%i)[%c7]
+    %2 = index_cast %1: index to i64
+    store %2, %A[]: memref<i64>
+  }
+
+  // This should not canonicalize because: 4 - %i may take the value 1 < 2.
+  // CHECK: scf.for
+  // CHECK: affine.min
+  // CHECK: index_cast
+  scf.for %i = %c1 to %c4 step %c2 {
+    %1 = affine.min affine_map<(d0)[s0] -> (2, s0 - d0)> (%i)[%c4]
+    %2 = index_cast %1: index to i64
+    store %2, %A[]: memref<i64>
+  }
+
+  // This should not canonicalize because: 16 - %i may take the value 15 < 1024.
+  // CHECK: scf.for
+  // CHECK: affine.min
+  // CHECK: index_cast
+  scf.for %i = %c1 to %c16 step %c1024 {
+    %1 = affine.min affine_map<(d0) -> (1024, 16 - d0)> (%i)
+    %2 = index_cast %1: index to i64
+    store %2, %A[]: memref<i64>
+  }
+
+  // This example should simplify but affine_map is currently missing
+  // semi-affine canonicalizations: `((s0 * 42 - 1) floordiv s0) * s0`
+  // should evaluate to 41 * s0.
+  // Note that this may require positivity assumptions on `s0`.
+  // Revisit when support is added.
+  // CHECK: scf.for
+  // CHECK: index_cast
+  %ub = affine.apply affine_map<(d0) -> (42 * d0)> (%step)
+  scf.for %i = %c0 to %ub step %step {
+    %1 = affine.min affine_map<(d0, d1, d2) -> (d0, d1 - d2)> (%step, %ub, %i)
+    %2 = index_cast %1: index to i64
+    store %2, %A[]: memref<i64>
+  }
+
+  // This example should simplify but affine_map is currently missing
+  // semi-affine canonicalizations: `-(((s0 * s0 - 1) floordiv s0) * s0)`
+  // should evaluate to `-((s0 - 1) * s0)`.
+  // Note that this may require positivity assumptions on `s0`.
+  // Revisit when support is added.
+  // CHECK: scf.for
+  // CHECK: index_cast
+  %ub2 = affine.apply affine_map<(d0)[s0] -> (s0 * d0)> (%step)[%step]
+  scf.for %i = %c0 to %ub2 step %step {
+    %1 = affine.min affine_map<(d0, d1, d2) -> (d0, d2 - d1)> (%step, %i, %ub2)
+    %2 = index_cast %1: index to i64
+    store %2, %A[]: memref<i64>
+  }
+
+  return
+}
+
+// CHECK-LABEL: scf_parallel
+func @scf_parallel(%A : memref<i64>, %step : index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c7 = constant 7 : index
+  %c4 = constant 4 : index
+
+  // CHECK: scf.parallel
+  // CHECK-NEXT: %[[C2:.*]] = constant 2 : index
+  // CHECK-NEXT: %[[C2I64:.*]] = index_cast %[[C2]]
+  // CHECK-NEXT: store %[[C2I64]], %{{.*}}[] : memref<i64>
+  scf.parallel (%i) = (%c0) to (%c4) step (%c2) {
+    %1 = affine.min affine_map<(d0, d1)[] -> (2, d1 - d0)> (%i, %c4)
+    %2 = index_cast %1: index to i64
+    store %2, %A[]: memref<i64>
+  }
+
+  // CHECK: scf.parallel
+  // CHECK-NEXT: %[[C2:.*]] = constant 2 : index
+  // CHECK-NEXT: %[[C2I64:.*]] = index_cast %[[C2]]
+  // CHECK-NEXT: store %[[C2I64]], %{{.*}}[] : memref<i64>
+  scf.parallel (%i) = (%c1) to (%c7) step (%c2) {
+    %1 = affine.min affine_map<(d0)[s0] -> (2, s0 - d0)> (%i)[%c7]
+    %2 = index_cast %1: index to i64
+    store %2, %A[]: memref<i64>
+  }
+
+  // This example should simplify but affine_map is currently missing
+  // semi-affine canonicalizations.
+  // This affine map does not currently evaluate to (0, 0):
+  //   (d0)[s0] -> (s0 mod s0, (-((d0 floordiv s0) * s0) + s0 * 42) mod s0)
+  // TODO: Revisit when support is added.
+  // CHECK: scf.parallel
+  // CHECK: index_cast
+  %ub = affine.apply affine_map<(d0) -> (42 * d0)> (%step)
+  scf.parallel (%i) = (%c0) to (%ub) step (%step) {
+    %1 = affine.min affine_map<(d0, d1, d2) -> (d0, d2 - d1)> (%step, %i, %ub)
+    %2 = index_cast %1: index to i64
+    store %2, %A[]: memref<i64>
+  }
+
+  // This example should simplify but affine_map is currently missing
+  // semi-affine canonicalizations.
+  // This affine map does not currently evaluate to (0, 0):
+  //   (d0)[s0] -> (s0 mod s0, (-((d0 floordiv s0) * s0) + s0 * s0) mod s0)
+  // TODO: Revisit when support is added.
+  // CHECK: scf.parallel
+  // CHECK: index_cast
+  %ub2 = affine.apply affine_map<(d0)[s0] -> (s0 * d0)> (%step)[%step]
+  scf.parallel (%i) = (%c0) to (%ub2) step (%step) {
+    %1 = affine.min affine_map<(d0, d1, d2) -> (d0, d2 - d1)> (%step, %i, %ub2)
+    %2 = index_cast %1: index to i64
+    store %2, %A[]: memref<i64>
+  }
+
+  return
+}
diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -59,6 +59,10 @@
       llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
                      "in vector.contract form"),
       llvm::cl::init(false)};
+  Option<bool> testAffineMinSCFCanonicalizationPatterns{
+      *this, "test-affine-min-scf-canonicalization-patterns",
+      llvm::cl::desc("Test affine-min + scf canonicalization patterns."),
+      llvm::cl::init(false)};
 };
 } // end anonymous namespace
 
@@ -316,6 +320,16 @@
   applyPatternsAndFoldGreedily(funcOp, patterns);
 }
 
+static void applyAffineMinSCFCanonicalizationPatterns(FuncOp funcOp) {
+  OwningRewritePatternList foldPattern;
+  foldPattern.insert<AffineMinSCFCanonicalizationPattern>(funcOp.getContext());
+  // Explicitly walk and apply the pattern locally to avoid more general folding
+  // on the rest of the IR.
+  funcOp.walk([&foldPattern](AffineMinOp minOp) {
+    applyOpPatternsAndFold(minOp, foldPattern);
+  });
+}
+
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnFunction() {
   auto lambda = [&](void *) {
@@ -341,6 +355,8 @@
     return applyVectorTransferForwardingPatterns(getFunction());
   if (testGenericToVectorPattern)
     return applyContractionToVectorPatterns(getFunction());
+  if (testAffineMinSCFCanonicalizationPatterns)
+    return applyAffineMinSCFCanonicalizationPatterns(getFunction());
 }
 
 namespace mlir {