diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -497,6 +497,23 @@ PatternRewriter &rewriter) const override; }; +/// Canonicalize AffineMinOp operations in the context of enclosing scf.for and +/// scf.parallel by: +/// 1. building an affine map where uses of the induction variable of a loop +/// are replaced by a `%lb + %step * floorDiv(%iv - %lb, %step)` expression. +/// 2. checking whether any of the results of this affine map divides all the +/// other results (in which case it is also guaranteed to be the min). +/// 3. replacing the AffineMinOp by the result of (2). +// TODO: move to a more appropriate place when it is determined. For now Linalg +// depends both on Affine and SCF but they do not depend on each other. +struct AffineMinSCFCanonicalizationPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(AffineMinOp minOp, + PatternRewriter &rewriter) const override; +}; + //===----------------------------------------------------------------------===// // Support for staged pattern application. //===----------------------------------------------------------------------===// @@ -514,6 +531,7 @@ Operation *op, ArrayRef stage1Patterns, const OwningRewritePatternList &stage2Patterns, function_ref stage3Lambda = nullptr); + } // namespace linalg } // namespace mlir diff --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h --- a/mlir/include/mlir/IR/AffineExpr.h +++ b/mlir/include/mlir/IR/AffineExpr.h @@ -115,9 +115,20 @@ /// This method substitutes any uses of dimensions and symbols (e.g. /// dim#0 with dimReplacements[0]) and returns the modified expression tree. 
+  /// This is a dense replacement method: a replacement must be specified for
+  /// every single dim and symbol.
   AffineExpr replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
                                    ArrayRef<AffineExpr> symReplacements) const;
 
+  /// Sparse replace method. Replace `expr` by `replacement` and return the
+  /// modified expression tree.
+  AffineExpr replace(AffineExpr expr, AffineExpr replacement) const;
+
+  /// Sparse replace method. If `*this` appears in `map` replaces it by
+  /// `map[*this]` and return the modified expression tree. Otherwise traverse
+  /// `*this` and apply replace with `map` on its subexpressions.
+  AffineExpr replace(const DenseMap<AffineExpr, AffineExpr> &map) const;
+
   AffineExpr operator+(int64_t v) const;
   AffineExpr operator+(AffineExpr other) const;
   AffineExpr operator-() const;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -36,6 +36,7 @@
 using namespace mlir::linalg;
 
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
 //===----------------------------------------------------------------------===//
 // Transformations exposed as rewrite patterns.
 //===----------------------------------------------------------------------===//
@@ -235,3 +236,147 @@
   }
   return success();
 }
+
+/// Substitute the AffineExprDim at position `dimIdx`, which corresponds to a
+/// loop induction variable (e.g. scf.for %iv = %lb to %ub step %step) by the
+/// AffineExpr representing `%lb + %step * floorDiv(%iv - %lb, %step)` such
+/// that:
+///   1. the AffineExpr for %lb is either an AffineConstantExpr or an
+///   AffineDimExpr depending on whether the value is constant or not.
+///   2. the AffineExpr for %step is either an AffineConstantExpr or an
+///   AffineSymbolExpr depending on whether the value is constant or not.
+static void substituteLoop(unsigned dimIdx, Value lbVal, Value ubVal,
+                           Value stepVal, SmallVectorImpl<AffineExpr> &exprs,
+                           SmallVectorImpl<Value> &dims,
+                           SmallVectorImpl<Value> &symbols) {
+  MLIRContext *ctx = lbVal.getContext();
+
+  // 1. maybe add a new dim for the `lb`.
+  auto lbConstant = lbVal.getDefiningOp<ConstantIndexOp>();
+  AffineExpr lb = lbConstant ? getAffineConstantExpr(lbConstant.getValue(), ctx)
+                             : getAffineDimExpr(dims.size(), ctx);
+  if (!lbConstant)
+    dims.push_back(lbVal);
+
+  // 2. maybe add a new symbol for the `step`.
+  auto stepConstant = stepVal.getDefiningOp<ConstantIndexOp>();
+  AffineExpr step = stepConstant
+                        ? getAffineConstantExpr(stepConstant.getValue(), ctx)
+                        : getAffineSymbolExpr(symbols.size(), ctx);
+  if (!stepConstant)
+    symbols.push_back(stepVal);
+
+  // 3. Rewrite `exprs` in place: `dim[dimIdx]` -> `lb + step * floorDiv(iv - lb, step)`.
+  AffineExpr iv = getAffineDimExpr(dimIdx, ctx);
+  for (auto &e : exprs)
+    e = e.replace(iv, lb + step * (iv - lb).floorDiv(step));
+}
+
+/// Traverse the `dims` and substitute linear expressions in place of induction
+/// variables in `exprs`.
+static void substitute(SmallVectorImpl<AffineExpr> &exprs,
+                       SmallVectorImpl<Value> &dims,
+                       SmallVectorImpl<Value> &symbols) {
+  assert(!exprs.empty() && "Unexpected empty exprs");
+  LLVM_DEBUG(llvm::interleaveComma(dims, DBGS() << "Start subst with dims: "));
+  LLVM_DEBUG(llvm::dbgs() << "\n");
+
+  // Note: `dims` and `symbols` grow as we iterate, upper bound is dynamic.
+  for (unsigned dimIdx = 0; dimIdx < dims.size(); ++dimIdx) {
+    Value dim = dims[dimIdx];
+    LLVM_DEBUG(DBGS() << "Subst: " << dim << "\n");
+
+    // Replace dim @ pos[dimIdx] by `%lb + %step * new_dim`
+    // Where new dim / symbols are added depending on whether the values are
+    // static or not.
+    if (auto forOp = scf::getForInductionVarOwner(dim)) {
+      substituteLoop(dimIdx, forOp.lowerBound(), forOp.upperBound(),
+                     forOp.step(), exprs, dims, symbols);
+      continue;
+    }
+    if (auto parallelForOp = scf::getParallelForInductionVarOwner(dim)) {
+      for (unsigned idx = 0, e = parallelForOp.getNumLoops(); idx < e; ++idx)
+        substituteLoop(dimIdx, parallelForOp.lowerBound()[idx],
+                       parallelForOp.upperBound()[idx],
+                       parallelForOp.step()[idx], exprs, dims, symbols);
+      continue;
+    }
+  }
+
+  // Cleanup and simplify the results.
+  SmallVector<Value, 4> operands(dims.begin(), dims.end());
+  operands.append(symbols.begin(), symbols.end());
+  auto map = AffineMap::get(dims.size(), symbols.size(), exprs,
+                            exprs.front().getContext());
+  // Pull in affine.apply operations and compose them fully into the result.
+  fullyComposeAffineMapAndOperands(&map, &operands);
+  canonicalizeMapAndOperands(&map, &operands);
+  map = simplifyAffineMap(map);
+  // Assign the results.
+  exprs.assign(map.getResults().begin(), map.getResults().end());
+  dims.assign(operands.begin(), operands.begin() + map.getNumDims());
+  symbols.assign(operands.begin() + map.getNumDims(), operands.end());
+}
+
+LogicalResult AffineMinSCFCanonicalizationPattern::matchAndRewrite(
+    AffineMinOp minOp, PatternRewriter &rewriter) const {
+  // At least one loop is needed to canonicalize affine.min + SCF.
+  auto isLoopLike = [](Value v) {
+    return scf::getParallelForInductionVarOwner(v) ||
+           scf::getForInductionVarOwner(v);
+  };
+  if (llvm::none_of(minOp.getDimOperands(), isLoopLike))
+    return failure();
+
+  LLVM_DEBUG(DBGS() << "Canonicalize AffineMinSCF: " << *minOp.getOperation()
+                    << "\n");
+
+  auto exprs = llvm::to_vector<4>(minOp.getAffineMap().getResults());
+  SmallVector<Value, 4> dims(minOp.getDimOperands()),
+      symbols(minOp.getSymbolOperands());
+  substitute(exprs, dims, symbols);
+
+  MLIRContext *ctx = minOp.getContext();
+  auto map = AffineMap::get(dims.size(), symbols.size(), exprs, ctx);
+  LLVM_DEBUG(DBGS() << "Resulting map: " << map << "\n");
+
+  // Check whether any of the expressions divides all expressions. In which case
+  // it is guaranteed to be the min.
+  for (auto e : map.getResults()) {
+    LLVM_DEBUG(DBGS() << "Candidate mod: " << e << "\n");
+    if (!e.isSymbolicOrConstant())
+      continue;
+
+    LLVM_DEBUG(DBGS() << "Check whether mod: " << e << " is zero\n");
+    SmallVector<AffineExpr, 4> modExprs;
+    for (auto ee : map.getResults())
+      modExprs.push_back(ee % e);
+
+    AffineMap modMap = simplifyAffineMap(
+        AffineMap::get(map.getNumDims(), map.getNumSymbols(), modExprs, ctx));
+    LLVM_DEBUG(DBGS() << "simplified modMap: " << modMap << "\n");
+
+    auto isZero = [](AffineExpr e) {
+      if (auto cst = e.dyn_cast<AffineConstantExpr>())
+        return cst.getValue() == 0;
+      return false;
+    };
+    if (llvm::all_of(modMap.getResults(), isZero)) {
+      if (auto cst = e.dyn_cast<AffineConstantExpr>()) {
+        rewriter.replaceOpWithNewOp<ConstantIndexOp>(minOp, cst.getValue());
+      } else {
+        auto resultMap =
+            AffineMap::get(map.getNumDims(), map.getNumSymbols(), {e}, ctx);
+        SmallVector<Value, 4> resultOperands = dims;
+        resultOperands.append(symbols.begin(), symbols.end());
+        canonicalizeMapAndOperands(&resultMap, &resultOperands);
+        resultMap = simplifyAffineMap(resultMap);
+        rewriter.replaceOpWithNewOp<AffineApplyOp>(minOp, resultMap,
+                                                   resultOperands);
+      }
+      return success();
+    }
+  }
+
+  return failure();
+}
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
---
a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -93,6 +93,38 @@
   llvm_unreachable("Unknown AffineExpr");
 }
 
+/// Sparse replace method. Return the modified expression tree.
+AffineExpr
+AffineExpr::replace(const DenseMap<AffineExpr, AffineExpr> &map) const {
+  auto it = map.find(*this);
+  if (it != map.end())
+    return it->second;
+  switch (getKind()) {
+  default:
+    return *this;
+  case AffineExprKind::Add:
+  case AffineExprKind::Mul:
+  case AffineExprKind::FloorDiv:
+  case AffineExprKind::CeilDiv:
+  case AffineExprKind::Mod:
+    auto binOp = cast<AffineBinaryOpExpr>();
+    auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
+    auto newLHS = lhs.replace(map);
+    auto newRHS = rhs.replace(map);
+    if (newLHS == lhs && newRHS == rhs)
+      return *this;
+    return getAffineBinaryOpExpr(getKind(), newLHS, newRHS);
+  }
+  llvm_unreachable("Unknown AffineExpr");
+}
+
+/// Sparse replace method. Return the modified expression tree.
+AffineExpr AffineExpr::replace(AffineExpr expr, AffineExpr replacement) const {
+  DenseMap<AffineExpr, AffineExpr> map;
+  map.insert(std::make_pair(expr, replacement));
+  return replace(map);
+}
+
 /// Returns true if this expression is made out of only symbols and
 /// constants (no dimensional identifiers).
bool AffineExpr::isSymbolicOrConstant() const { diff --git a/mlir/test/Dialect/Linalg/fold-affine-min-scf.mlir b/mlir/test/Dialect/Linalg/fold-affine-min-scf.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/fold-affine-min-scf.mlir @@ -0,0 +1,119 @@ +// RUN: mlir-opt %s -test-linalg-transform-patterns=test-affine-min-scf-canonicalization-patterns | FileCheck %s + +// CHECK-LABEL: scf_for +func @scf_for(%A : memref, %step : index) { + %c0 = constant 0 : index + %c1 = constant 1 : index + %c2 = constant 2 : index + %c7 = constant 7 : index + %c4 = constant 4 : index + + // CHECK: scf.for + // CHECK-NEXT: %[[C2:.*]] = constant 2 : index + // CHECK-NEXT: %[[C2I64:.*]] = index_cast %[[C2:.*]] + // CHECK-NEXT: store %[[C2I64]], %{{.*}}[] : memref + scf.for %i = %c0 to %c4 step %c2 { + %1 = affine.min affine_map<(d0, d1)[] -> (2, d1 - d0)> (%i, %c4) + %2 = index_cast %1: index to i64 + store %2, %A[]: memref + } + + // CHECK: scf.for + // CHECK-NEXT: %[[C2:.*]] = constant 2 : index + // CHECK-NEXT: %[[C2I64:.*]] = index_cast %[[C2:.*]] + // CHECK-NEXT: store %[[C2I64]], %{{.*}}[] : memref + scf.for %i = %c1 to %c7 step %c2 { + %1 = affine.min affine_map<(d0)[s0] -> (2, s0 - d0)> (%i)[%c7] + %2 = index_cast %1: index to i64 + store %2, %A[]: memref + } + + // This example should simplify but affine_map is currently missing + // semi-affine canonicalizations. + // This affine map does not currently evaluate to 0: + // (d0)[s0] -> (s0 mod s0, (-((d0 floordiv s0) * s0) + s0 * 42) mod s0) + // Revisit when support is added. + // CHECK: scf.for + // CHECK: index_cast + %ub = affine.apply affine_map<(d0) -> (42 * d0)> (%step) + scf.for %i = %c0 to %ub step %step { + %1 = affine.min affine_map<(d0, d1, d2) -> (d0, d2 - d1)> (%step, %i, %ub) + %2 = index_cast %1: index to i64 + store %2, %A[]: memref + } + + // This example should simplify but affine_map is currently missing + // semi-affine canonicalizations. 
+ // This affine map does not currently evaluate to 0: + // (d0)[s0] -> (s0 mod s0, (-((d0 floordiv s0) * s0) + s0 * s0) mod s0) + // Revisit when support is added. + // CHECK: scf.for + // CHECK: index_cast + %ub2 = affine.apply affine_map<(d0)[s0] -> (s0 * d0)> (%step)[%step] + scf.for %i = %c0 to %ub2 step %step { + %1 = affine.min affine_map<(d0, d1, d2) -> (d0, d2 - d1)> (%step, %i, %ub2) + %2 = index_cast %1: index to i64 + store %2, %A[]: memref + } + + return +} + +// CHECK-LABEL: scf_parallel +func @scf_parallel(%A : memref, %step : index) { + %c0 = constant 0 : index + %c1 = constant 1 : index + %c2 = constant 2 : index + %c7 = constant 7 : index + %c4 = constant 4 : index + + // CHECK: scf.parallel + // CHECK-NEXT: %[[C2:.*]] = constant 2 : index + // CHECK-NEXT: %[[C2I64:.*]] = index_cast %[[C2:.*]] + // CHECK-NEXT: store %[[C2I64]], %{{.*}}[] : memref + scf.parallel (%i) = (%c0) to (%c4) step (%c2) { + %1 = affine.min affine_map<(d0, d1)[] -> (2, d1 - d0)> (%i, %c4) + %2 = index_cast %1: index to i64 + store %2, %A[]: memref + } + + // CHECK: scf.parallel + // CHECK-NEXT: %[[C2:.*]] = constant 2 : index + // CHECK-NEXT: %[[C2I64:.*]] = index_cast %[[C2:.*]] + // CHECK-NEXT: store %[[C2I64]], %{{.*}}[] : memref + scf.parallel (%i) = (%c1) to (%c7) step (%c2) { + %1 = affine.min affine_map<(d0)[s0] -> (2, s0 - d0)> (%i)[%c7] + %2 = index_cast %1: index to i64 + store %2, %A[]: memref + } + + // This example should simplify but affine_map is currently missing + // semi-affine canonicalizations. + // This affine map does not currently evaluate to (0, 0): + // (d0)[s0] -> (s0 mod s0, (-((d0 floordiv s0) * s0) + s0 * 42) mod s0) + // TODO: Revisit when support is added. 
+ // CHECK: scf.parallel + // CHECK: index_cast + %ub = affine.apply affine_map<(d0) -> (42 * d0)> (%step) + scf.parallel (%i) = (%c0) to (%ub) step (%step) { + %1 = affine.min affine_map<(d0, d1, d2) -> (d0, d2 - d1)> (%step, %i, %ub) + %2 = index_cast %1: index to i64 + store %2, %A[]: memref + } + + // This example should simplify but affine_map is currently missing + // semi-affine canonicalizations. + // This affine map does not currently evaluate to (0, 0): + // (d0)[s0] -> (s0 mod s0, (-((d0 floordiv s0) * s0) + s0 * s0) mod s0) + // TODO: Revisit when support is added. + // CHECK: scf.parallel + // CHECK: index_cast + %ub2 = affine.apply affine_map<(d0)[s0] -> (s0 * d0)> (%step)[%step] + scf.parallel (%i) = (%c0) to (%ub2) step (%step) { + %1 = affine.min affine_map<(d0, d1, d2) -> (d0, d2 - d1)> (%step, %i, %ub2) + %2 = index_cast %1: index to i64 + store %2, %A[]: memref + } + + return +} diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp --- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp @@ -54,6 +54,10 @@ llvm::cl::desc( "Test a fused pass that forwards linalg.copy to vector.transfer"), llvm::cl::init(false)}; + Option testAffineMinSCFCanonicalizationPatterns{ + *this, "test-affine-min-scf-canonicalization-patterns", + llvm::cl::desc("Test affine-min + scf canonicalization patterns."), + llvm::cl::init(false)}; }; } // end anonymous namespace @@ -300,6 +304,16 @@ applyPatternsAndFoldGreedily(funcOp, forwardPattern); } +static void applyAffineMinSCFCanonicalizationPatterns(FuncOp funcOp) { + OwningRewritePatternList foldPattern; + foldPattern.insert(funcOp.getContext()); + // Explicitly walk and apply the pattern locally to avoid more general folding + // on the rest of the IR. + funcOp.walk([&foldPattern](AffineMinOp minOp) { + applyOpPatternsAndFold(minOp, foldPattern); + }); +} + /// Apply transformations specified as patterns. 
void TestLinalgTransforms::runOnFunction() { auto lambda = [&](void *) { @@ -323,6 +337,8 @@ testMatmulToVectorPatterns2dTiling); if (testVectorTransferForwardingPatterns) return applyVectorTransferForwardingPatterns(getFunction()); + if (testAffineMinSCFCanonicalizationPatterns) + return applyAffineMinSCFCanonicalizationPatterns(getFunction()); } namespace mlir {