diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h
--- a/mlir/include/mlir/Analysis/AffineStructures.h
+++ b/mlir/include/mlir/Analysis/AffineStructures.h
@@ -210,6 +210,12 @@
                                      ValueRange operands, bool eq,
                                      bool lower = true);
 
+  /// Returns the constraint system as an integer set. Returns the universal
+  /// set (0 == 0) if the system has no constraints. Returns a null integer set
+  /// if a local identifier whose explicit representation isn't known appears
+  /// with a non-zero coefficient in any of the constraints.
+  IntegerSet getAsIntegerSet(MLIRContext *context) const;
+
   /// Computes the lower and upper bounds of the first 'num' dimensional
   /// identifiers (starting at 'offset') as an affine map of the remaining
   /// identifiers (dimensional and symbolic). This method is able to detect
@@ -484,7 +490,8 @@
   /// that can be detected as redundant as a result of differing only in their
   /// constant term part. A constraint of the form <non-negative constant> >= 0
   /// is considered trivially true. This method is a linear time method on the
-  /// constraints, does a single scan, and updates in place.
+  /// constraints, does a single scan, and updates in place. It also normalizes
+  /// constraints by their GCD and performs GCD tightening on inequalities.
   void removeTrivialRedundancy();
 
   /// A more expensive check to detect redundant inequalities thatn
diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h
--- a/mlir/include/mlir/Analysis/Utils.h
+++ b/mlir/include/mlir/Analysis/Utils.h
@@ -290,6 +290,12 @@
 /// Returns true if `forOp' is a parallel loop.
 bool isLoopParallel(AffineForOp forOp);
 
+/// Simplify the integer set by flattening its constraints, normalizing them by
+/// GCD, and dropping trivially true, duplicate, and constant-term-dominated
+/// constraints. Returns the canonical empty set if the set is proven empty.
+/// Unlike the redundancy removal, the emptiness check may be super-linear.
+IntegerSet simplifyIntegerSet(IntegerSet set);
+
 } // end namespace mlir
 
 #endif // MLIR_ANALYSIS_UTILS_H
diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -2191,13 +2191,16 @@
 // Uses a DenseSet to hash and detect duplicates followed by a linear scan to
 // remove duplicates in place.
 void FlatAffineConstraints::removeTrivialRedundancy() {
-  SmallDenseSet<ArrayRef<int64_t>, 8> rowSet;
+  GCDTightenInequalities();
+  normalizeConstraintsByGCD();
 
   // A map used to detect redundancy stemming from constraints that only differ
   // in their constant term. The value stored is <row position, const term>
   // for a given row.
   SmallDenseMap<ArrayRef<int64_t>, std::pair<unsigned, int64_t>>
       rowsWithoutConstTerm;
+  // Used to detect duplicate rows.
+  SmallDenseSet<ArrayRef<int64_t>, 8> rowSet;
 
   // Check if constraint is of the form <non-negative-constant> >= 0.
   auto isTriviallyValid = [&](unsigned r) -> bool {
@@ -2690,3 +2693,95 @@
 
   return success();
 }
+
+/// Compute an explicit representation for local vars. This succeeds for
+/// systems built from MLIR integer sets, maps, or expressions as long as the
+/// inequalities introduced to model their floordivs and mods are still intact.
+static LogicalResult computeLocalVars(const FlatAffineConstraints &cst,
+                                      SmallVectorImpl<AffineExpr> &memo,
+                                      MLIRContext *context) {
+  unsigned numDims = cst.getNumDimIds();
+  unsigned numSyms = cst.getNumSymbolIds();
+
+  // Initialize dimensional and symbolic identifiers.
+  for (unsigned i = 0; i < numDims; i++)
+    memo[i] = getAffineDimExpr(i, context);
+  for (unsigned i = numDims, e = numDims + numSyms; i < e; i++)
+    memo[i] = getAffineSymbolExpr(i - numDims, context);
+
+  // Iterate to a fixed point: detecting one local may depend on other locals
+  // having already been resolved in a previous round.
+  bool changed;
+  do {
+    changed = false;
+    for (unsigned i = 0, e = cst.getNumLocalIds(); i < e; ++i)
+      if (!memo[numDims + numSyms + i] &&
+          detectAsFloorDiv(cst, numDims + numSyms + i, context, memo))
+        changed = true;
+  } while (changed);
+
+  ArrayRef<AffineExpr> localExprs =
+      ArrayRef<AffineExpr>(memo).take_back(cst.getNumLocalIds());
+  return success(
+      llvm::all_of(localExprs, [](AffineExpr expr) { return expr; }));
+}
+
+/// Returns true if the pos^th column is all zero for both inequalities and
+/// equalities.
+static bool isColZero(const FlatAffineConstraints &cst, unsigned pos) {
+  assert(pos < cst.getNumCols() && "position out of bounds");
+  for (unsigned i = 0, e = cst.getNumInequalities(); i < e; ++i)
+    if (cst.atIneq(i, pos) != 0)
+      return false;
+
+  for (unsigned i = 0, e = cst.getNumEqualities(); i < e; ++i)
+    if (cst.atEq(i, pos) != 0)
+      return false;
+
+  return true;
+}
+
+IntegerSet FlatAffineConstraints::getAsIntegerSet(MLIRContext *context) const {
+  if (getNumConstraints() == 0)
+    // Return universal set (always true): 0 == 0.
+    return IntegerSet::get(getNumDimIds(), getNumSymbolIds(),
+                           getAffineConstantExpr(/*constant=*/0, context),
+                           /*eqFlags=*/true);
+
+  // Construct local references.
+  SmallVector<AffineExpr, 8> memo(getNumIds(), AffineExpr());
+
+  if (failed(computeLocalVars(*this, memo, context))) {
+    // Check if the local variables without an explicit representation have
+    // zero coefficients everywhere.
+    for (unsigned i = getNumDimAndSymbolIds(), e = getNumIds(); i < e; ++i) {
+      if (!memo[i] && !isColZero(*this, /*pos=*/i)) {
+        LLVM_DEBUG(llvm::dbgs() << "one or more local exprs do not have an "
+                                   "explicit representation\n");
+        return IntegerSet();
+      }
+    }
+  }
+
+  ArrayRef<AffineExpr> localExprs =
+      ArrayRef<AffineExpr>(memo).take_back(getNumLocalIds());
+
+  // Construct the IntegerSet from the equalities/inequalities.
+  unsigned numDims = getNumDimIds();
+  unsigned numSyms = getNumSymbolIds();
+
+  SmallVector<bool, 16> eqFlags(getNumConstraints());
+  std::fill(eqFlags.begin(), eqFlags.begin() + getNumEqualities(), true);
+  std::fill(eqFlags.begin() + getNumEqualities(), eqFlags.end(), false);
+
+  SmallVector<AffineExpr, 8> exprs;
+  exprs.reserve(getNumConstraints());
+
+  for (unsigned i = 0, e = getNumEqualities(); i < e; ++i)
+    exprs.push_back(getAffineExprFromFlatForm(getEquality(i), numDims, numSyms,
+                                              localExprs, context));
+  for (unsigned i = 0, e = getNumInequalities(); i < e; ++i)
+    exprs.push_back(getAffineExprFromFlatForm(getInequality(i), numDims,
+                                              numSyms, localExprs, context));
+  return IntegerSet::get(numDims, numSyms, exprs, eqFlags);
+}
diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/IntegerSet.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
@@ -1007,3 +1008,20 @@
   }
   return true;
 }
+
+IntegerSet mlir::simplifyIntegerSet(IntegerSet set) {
+  FlatAffineConstraints fac(set);
+  MLIRContext *context = set.getContext();
+  if (fac.isEmpty())
+    return IntegerSet::getEmptySet(set.getNumDims(), set.getNumSymbols(),
+                                   context);
+  fac.removeTrivialRedundancy();
+
+  // Dropping redundant constraints can also drop an inequality that defines a
+  // local (floordiv/mod) identifier: for (d0 mod 2 - 1 >= 0), 'd0 - 2q >= 0'
+  // is subsumed by 'd0 - 2q - 1 >= 0', after which the local's floordiv form
+  // can no longer be detected and getAsIntegerSet() returns a null set. Fall
+  // back to the unsimplified input in that case.
+  IntegerSet simplifiedSet = fac.getAsIntegerSet(context);
+  return simplifiedSet ? simplifiedSet : set;
+}
diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
--- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
@@ -11,12 +11,13 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Analysis/AffineStructures.h"
-#include "mlir/IR/IntegerSet.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
+#include "mlir/Analysis/Utils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
 #include "mlir/Dialect/Affine/Passes.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/IntegerSet.h"
+#include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/Utils.h"
 
 #define DEBUG_TYPE "simplify-affine-structure"
@@ -58,15 +59,8 @@
     op->setAttr(name, simplified);
   }
 
-  /// Performs basic integer set simplifications. Checks if it's empty, and
-  /// replaces it with the canonical empty set if it is.
-  IntegerSet simplify(IntegerSet set) {
-    FlatAffineConstraints fac(set);
-    if (fac.isEmpty())
-      return IntegerSet::getEmptySet(set.getNumDims(), set.getNumSymbols(),
-                                     &getContext());
-    return set;
-  }
+  /// Performs basic integer set simplifications (see simplifyIntegerSet).
+  IntegerSet simplify(IntegerSet set) { return simplifyIntegerSet(set); }
 
   /// Performs basic affine map simplifications.
   AffineMap simplify(AffineMap map) {
diff --git a/mlir/test/Dialect/Affine/simplify-affine-structures.mlir b/mlir/test/Dialect/Affine/simplify-affine-structures.mlir
--- a/mlir/test/Dialect/Affine/simplify-affine-structures.mlir
+++ b/mlir/test/Dialect/Affine/simplify-affine-structures.mlir
@@ -1,7 +1,7 @@
-// RUN: mlir-opt %s -simplify-affine-structures | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -simplify-affine-structures | FileCheck %s
 
 // CHECK-DAG: [[SET_EMPTY_2D:#set[0-9]+]] = affine_set<(d0, d1) : (1 == 0)>
-// CHECK-DAG: #set1 = affine_set<(d0, d1) : (d0 - 100 == 0, d1 - 10 == 0, -d0 + 100 >= 0, d1 >= 0, d1 + 101 >= 0)>
+// CHECK-DAG: #set1 = affine_set<(d0, d1) : (d0 - 100 == 0, d1 - 10 == 0, -d0 + 100 >= 0, d1 >= 0)>
 // CHECK-DAG: #set2 = affine_set<(d0, d1)[s0, s1] : (1 == 0)>
 // CHECK-DAG: #set3 = affine_set<(d0, d1)[s0, s1] : (d0 * 7 + d1 * 5 + s0 * 11 + s1 == 0, d0 * 5 - d1 * 11 + s0 * 7 + s1 == 0, d0 * 11 + d1 * 7 - s0 * 5 + s1 == 0, d0 * 7 + d1 * 5 + s0 * 11 + s1 == 0)>
 // CHECK-DAG: [[SET_EMPTY_1D:#set[0-9]+]] = affine_set<(d0) : (1 == 0)>
@@ -236,3 +236,29 @@
 
   return
 }
+
+// -----
+
+// CHECK-DAG: #[[SET1:.*]] = affine_set<(d0, d1) : (d0 >= 0, -d0 + 50 >= 0)
+// CHECK-DAG: #[[SET2:.*]] = affine_set<(d0, d1) : (1 == 0)
+// CHECK-DAG: #[[SET3:.*]] = affine_set<(d0, d1) : (0 == 0)
+// CHECK-DAG: #[[SET4:.*]] = affine_set<(d0, d1) : (d0 mod 2 - 1 >= 0)
+
+// CHECK-LABEL: func @simplify_set
+func @simplify_set(%a : index, %b : index) {
+  // CHECK: affine.if #[[SET1]]
+  affine.if affine_set<(d0, d1) : (d0 - d1 + d1 + d0 >= 0, 2 >= 0, d0 >= 0, -d0 + 50 >= 0, -d0 + 100 >= 0)>(%a, %b) {
+  }
+  // CHECK: affine.if #[[SET2]]
+  affine.if affine_set<(d0, d1) : (d0 mod 2 - 1 == 0, d0 - 2 * (d0 floordiv 2) == 0)>(%a, %b) {
+  }
+  // CHECK: affine.if #[[SET3]]
+  affine.if affine_set<(d0, d1) : (1 >= 0, 3 >= 0)>(%a, %b) {
+  }
+  // The constraint subsumes 'd0 - 2q >= 0', one of the inequalities defining
+  // the local id q = d0 floordiv 2; the set must be left unchanged.
+  // CHECK: affine.if #[[SET4]]
+  affine.if affine_set<(d0, d1) : (d0 mod 2 - 1 >= 0)>(%a, %b) {
+  }
+  return
+}