diff --git a/mlir/include/mlir/Dialect/SDBM/SDBM.h b/mlir/include/mlir/Dialect/SDBM/SDBM.h deleted file mode 100644 --- a/mlir/include/mlir/Dialect/SDBM/SDBM.h +++ /dev/null @@ -1,197 +0,0 @@ -//===- SDBM.h - MLIR SDBM declaration ---------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// A striped difference-bound matrix (SDBM) is a set in Z^N (or R^N) defined -// as {(x_1, ... x_n) | f(x_1, ... x_n) >= 0} where f is an SDBM expression. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_DIALECT_SDBM_SDBM_H -#define MLIR_DIALECT_SDBM_SDBM_H - -#include "mlir/Support/LLVM.h" -#include "llvm/ADT/DenseMap.h" - -namespace mlir { - -class MLIRContext; -class SDBMDialect; -class SDBMExpr; -class SDBMTermExpr; - -/// A utility class for SDBM to represent an integer with potentially infinite -/// positive value. This uses the largest value of int64_t to represent infinity -/// and redefines the arithmetic operators so that the infinity "saturates": -/// inf + x = inf, -/// inf - x = inf. -/// If a sum of two finite values reaches the largest value of int64_t, the -/// behavior of IntInfty is undefined (in practice, it asserts), similarly to -/// regular signed integer overflow. -class IntInfty { -public: - constexpr static int64_t infty = std::numeric_limits::max(); - - /*implicit*/ IntInfty(int64_t v) : value(v) {} - - IntInfty &operator=(int64_t v) { - value = v; - return *this; - } - - static IntInfty infinity() { return IntInfty(infty); } - - int64_t getValue() const { return value; } - explicit operator int64_t() const { return value; } - - bool isFinite() { return value != infty; } - -private: - int64_t value; -}; - -inline IntInfty operator+(IntInfty lhs, IntInfty rhs) { - if (!lhs.isFinite() || !rhs.isFinite()) - return IntInfty::infty; - - // Check for overflows, treating the sum of two values adding up to INT_MAX as - // overflow. Convert values to unsigned to get an extra bit and avoid the - // undefined behavior of signed integer overflows. - assert((lhs.getValue() <= 0 || rhs.getValue() <= 0 || - static_cast(lhs.getValue()) + - static_cast(rhs.getValue()) < - static_cast(std::numeric_limits::max())) && - "IntInfty overflow"); - // Check for underflows by converting values to unsigned to avoid undefined - // behavior of signed integers perform the addition (bitwise result is same - // because numbers are required to be two's complement in C++) and check if - // the sign bit remains negative. - assert((lhs.getValue() >= 0 || rhs.getValue() >= 0 || - ((static_cast(lhs.getValue()) + - static_cast(rhs.getValue())) >> - 63) == 1) && - "IntInfty underflow"); - - return lhs.getValue() + rhs.getValue(); -} - -inline bool operator<(IntInfty lhs, IntInfty rhs) { - return lhs.getValue() < rhs.getValue(); -} - -inline bool operator<=(IntInfty lhs, IntInfty rhs) { - return lhs.getValue() <= rhs.getValue(); -} - -inline bool operator==(IntInfty lhs, IntInfty rhs) { - return lhs.getValue() == rhs.getValue(); -} - -inline bool operator!=(IntInfty lhs, IntInfty rhs) { return !(lhs == rhs); } - -/// Striped difference-bound matrix is a representation of an integer set bound -/// by a system of SDBMExprs interpreted as inequalities "expr <= 0". -class SDBM { -public: - /// Obtain an SDBM from a list of SDBM expressions treated as inequalities and - /// equalities with zero. - static SDBM get(ArrayRef inequalities, - ArrayRef equalities); - - void getSDBMExpressions(SDBMDialect *dialect, - SmallVectorImpl &inequalities, - SmallVectorImpl &equalities); - - void print(raw_ostream &os); - void dump(); - - IntInfty operator()(int i, int j) { return at(i, j); } - -private: - /// Get the given element of the difference bounds matrix. First index - /// corresponds to the negative term of the difference, second index - /// corresponds to the positive term of the difference. - IntInfty &at(int i, int j) { return matrix[i * getNumVariables() + j]; } - - /// Populate `inequalities` and `equalities` based on the values at(row,col) - /// and at(col,row) of the DBM. Depending on the values being finite and - /// being subsumed by stripe expressions, this may or may not add elements to - /// the lists of equalities and inequalities. - void convertDBMElement(unsigned row, unsigned col, SDBMTermExpr rowExpr, - SDBMTermExpr colExpr, - SmallVectorImpl &inequalities, - SmallVectorImpl &equalities); - - /// Populate `inequalities` based on the value at(pos,pos) of the DBM. Only - /// adds new inequalities if the inequality is not trivially true. - void convertDBMDiagonalElement(unsigned pos, SDBMTermExpr expr, - SmallVectorImpl &inequalities); - - /// Get the total number of elements in the matrix. - unsigned getNumVariables() const { - return 1 + numDims + numSymbols + numTemporaries; - } - - /// Get the position in the matrix that corresponds to the given dimension. - unsigned getDimPosition(unsigned position) const { return 1 + position; } - - /// Get the position in the matrix that corresponds to the given symbol. - unsigned getSymbolPosition(unsigned position) const { - return 1 + numDims + position; - } - - /// Get the position in the matrix that corresponds to the given temporary. - unsigned getTemporaryPosition(unsigned position) const { - return 1 + numDims + numSymbols + position; - } - - /// Number of dimensions in the system, - unsigned numDims; - /// Number of symbols in the system. - unsigned numSymbols; - /// Number of temporary variables in the system. - unsigned numTemporaries; - - /// Difference bounds matrix, stored as a linearized row-major vector. - /// Each value in this matrix corresponds to an inequality - /// - /// v@col - v@row <= at(row, col) - /// - /// where v@col and v@row are the variables that correspond to the linearized - /// position in the matrix. The positions correspond to - /// - /// - constant 0 (producing constraints v@col <= X and -v@row <= Y); - /// - SDBM expression dimensions (d0, d1, ...); - /// - SDBM expression symbols (s0, s1, ...); - /// - temporary variables (t0, t1, ...). - /// - /// Temporary variables are introduced to represent expressions that are not - /// trivially a difference between two variables. For example, if one side of - /// a difference expression is itself a stripe expression, it will be replaced - /// with a temporary variable assigned equal to this expression. - /// - /// Infinite entries in the matrix correspond correspond to an absence of a - /// constraint: - /// - /// v@col - v@row <= infinity - /// - /// is trivially true. Negated values at symmetric positions in the matrix - /// allow one to couple two inequalities into a single equality. - std::vector matrix; - - /// The mapping between the indices of variables in the DBM and the stripe - /// expressions they are equal to. These expressions are stored as they - /// appeared when constructing an SDBM from a SDBMExprs, in particular no - /// temporaries can appear in these expressions. This removes the need to - /// iteratively substitute definitions of the temporaries in the reverse - /// conversion. - DenseMap stripeToPoint; -}; - -} // namespace mlir - -#endif // MLIR_DIALECT_SDBM_SDBM_H diff --git a/mlir/include/mlir/Dialect/SDBM/SDBMDialect.h b/mlir/include/mlir/Dialect/SDBM/SDBMDialect.h deleted file mode 100644 --- a/mlir/include/mlir/Dialect/SDBM/SDBMDialect.h +++ /dev/null @@ -1,37 +0,0 @@ -//===- SDBMDialect.h - Dialect for striped DBMs -----------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_DIALECT_SDBM_SDBMDIALECT_H -#define MLIR_DIALECT_SDBM_SDBMDIALECT_H - -#include "mlir/IR/Dialect.h" -#include "mlir/Support/StorageUniquer.h" - -namespace mlir { -class MLIRContext; - -class SDBMDialect : public Dialect { -public: - SDBMDialect(MLIRContext *context); - - /// Since there are no other virtual methods in this derived class, override - /// the destructor so that key methods get defined in the corresponding - /// module. - ~SDBMDialect() override; - - static StringRef getDialectNamespace() { return "sdbm"; } - - /// Get the uniquer for SDBM expressions. This should not be used directly. - StorageUniquer &getUniquer() { return uniquer; } - -private: - StorageUniquer uniquer; -}; -} // namespace mlir - -#endif // MLIR_DIALECT_SDBM_SDBMDIALECT_H diff --git a/mlir/include/mlir/Dialect/SDBM/SDBMExpr.h b/mlir/include/mlir/Dialect/SDBM/SDBMExpr.h deleted file mode 100644 --- a/mlir/include/mlir/Dialect/SDBM/SDBMExpr.h +++ /dev/null @@ -1,576 +0,0 @@ -//===- SDBMExpr.h - MLIR SDBM Expression ------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// A striped difference-bound matrix (SDBM) expression is a constant expression, -// an identifier, a binary expression with constant RHS and +, stripe operators -// or a difference expression between two identifiers. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_DIALECT_SDBM_SDBMEXPR_H -#define MLIR_DIALECT_SDBM_SDBMEXPR_H - -#include "mlir/Support/LLVM.h" -#include "llvm/ADT/DenseMapInfo.h" - -namespace mlir { - -class AffineExpr; -class MLIRContext; - -enum class SDBMExprKind { Add, Stripe, Diff, Constant, DimId, SymbolId, Neg }; - -namespace detail { -struct SDBMExprStorage; -struct SDBMBinaryExprStorage; -struct SDBMDiffExprStorage; -struct SDBMTermExprStorage; -struct SDBMConstantExprStorage; -struct SDBMNegExprStorage; -} // namespace detail - -class SDBMConstantExpr; -class SDBMDialect; -class SDBMDimExpr; -class SDBMSymbolExpr; -class SDBMTermExpr; - -/// Striped Difference-Bounded Matrix (SDBM) expression is a base left-hand side -/// expression for the SDBM framework. SDBM expressions are a subset of affine -/// expressions supporting low-complexity algorithms for the operations used in -/// loop transformations. In particular, are supported: -/// - constant expressions; -/// - single variables (dimensions and symbols) with +1 or -1 coefficient; -/// - stripe expressions: "x # C", where "x" is a single variable or another -/// stripe expression, "#" is the stripe operator, and "C" is a constant -/// expression; "#" is defined as x - x mod C. -/// - sum expressions between single variable/stripe expressions and constant -/// expressions; -/// - difference expressions between single variable/stripe expressions. -/// `SDBMExpr` class hierarchy provides a type-safe interface to constructing -/// and operating on SDBM expressions. For example, it requires the LHS of a -/// sum expression to be a single variable or a stripe expression. These -/// restrictions are intended to force the caller to perform the necessary -/// simplifications to stay within the SDBM domain, because SDBM expressions do -/// not combine in more cases than they do. This choice may be reconsidered in -/// the future. -/// -/// SDBM expressions are grouped into the following structure -/// - expression -/// - varying -/// - direct -/// - sum <- (term, constant) -/// - term -/// - symbol -/// - dimension -/// - stripe <- (direct, constant) -/// - negation <- (direct) -/// - difference <- (direct, term) -/// - constant -/// The notation <- (...) denotes the types of subexpressions a compound -/// expression can combine. The tree of subexpressions essentially imposes the -/// following canonicalization rules: -/// - constants are always folded; -/// - constants can only appear on the RHS of an expression; -/// - double negation must be elided; -/// - an additive constant term is only allowed in a sum expression, and -/// should be sunk into the nearest such expression in the tree; -/// - zero constant expression can only appear at the top level. -/// -/// `SDBMExpr` and derived classes are thin wrappers around a pointer owned by -/// an MLIRContext, and should be used by-value. They are uniqued in the -/// MLIRContext and immortal. -class SDBMExpr { -public: - using ImplType = detail::SDBMExprStorage; - SDBMExpr() : impl(nullptr) {} - /* implicit */ SDBMExpr(ImplType *expr) : impl(expr) {} - - /// SDBM expressions are thin wrappers around a unique'ed immutable pointer, - /// which makes them trivially assignable and trivially copyable. - SDBMExpr(const SDBMExpr &) = default; - SDBMExpr &operator=(const SDBMExpr &) = default; - - /// SDBM expressions can be compared straight-forwardly. - bool operator==(const SDBMExpr &other) const { return impl == other.impl; } - bool operator!=(const SDBMExpr &other) const { return !(*this == other); } - - /// SDBM expressions are convertible to `bool`: null expressions are converted - /// to false, non-null expressions are converted to true. - explicit operator bool() const { return impl != nullptr; } - bool operator!() const { return !static_cast(*this); } - - /// Negate the given SDBM expression. - SDBMExpr operator-(); - - /// Prints the SDBM expression. - void print(raw_ostream &os) const; - void dump() const; - - /// LLVM-style casts. - template bool isa() const { return U::isClassFor(*this); } - template U dyn_cast() const { - if (!isa()) - return {}; - return U(const_cast(this)->impl); - } - template U cast() const { - assert(isa() && "cast to incorrect subtype"); - return U(const_cast(this)->impl); - } - - /// Support for LLVM hashing. - ::llvm::hash_code hash_value() const { return ::llvm::hash_value(impl); } - - /// Returns the kind of the SDBM expression. - SDBMExprKind getKind() const; - - /// Returns the MLIR context in which this expression lives. - MLIRContext *getContext() const; - - /// Returns the SDBM dialect instance. - SDBMDialect *getDialect() const; - - /// Convert the SDBM expression into an Affine expression. This always - /// succeeds because SDBM are a subset of affine. - AffineExpr getAsAffineExpr() const; - - /// Try constructing an SDBM expression from the given affine expression. - /// This may fail if the affine expression is not representable as SDBM, in - /// which case llvm::None is returned. The conversion procedure recognizes - /// (nested) multiplicative ((x floordiv B) * B) and additive (x - x mod B) - /// patterns for the stripe expression. - static Optional tryConvertAffineExpr(AffineExpr affine); - -protected: - ImplType *impl; -}; - -/// SDBM constant expression, wraps a 64-bit integer. -class SDBMConstantExpr : public SDBMExpr { -public: - using ImplType = detail::SDBMConstantExprStorage; - - using SDBMExpr::SDBMExpr; - - /// Obtain or create a constant expression unique'ed in the given dialect - /// (which belongs to a context). - static SDBMConstantExpr get(SDBMDialect *dialect, int64_t value); - - static bool isClassFor(const SDBMExpr &expr) { - return expr.getKind() == SDBMExprKind::Constant; - } - - int64_t getValue() const; -}; - -/// SDBM varying expression can be one of: -/// - input variable expression; -/// - stripe expression; -/// - negation (product with -1) of either of the above. -/// - sum of a varying and a constant expression -/// - difference between varying expressions -class SDBMVaryingExpr : public SDBMExpr { -public: - using ImplType = detail::SDBMExprStorage; - using SDBMExpr::SDBMExpr; - - static bool isClassFor(const SDBMExpr &expr) { - return expr.getKind() == SDBMExprKind::DimId || - expr.getKind() == SDBMExprKind::SymbolId || - expr.getKind() == SDBMExprKind::Neg || - expr.getKind() == SDBMExprKind::Stripe || - expr.getKind() == SDBMExprKind::Add || - expr.getKind() == SDBMExprKind::Diff; - } -}; - -/// SDBM direct expression includes exactly one variable (symbol or dimension), -/// which is not negated in the expression. It can be one of: -/// - term expression; -/// - sum expression. -class SDBMDirectExpr : public SDBMVaryingExpr { -public: - using SDBMVaryingExpr::SDBMVaryingExpr; - - /// If this is a sum expression, return its variable part, otherwise return - /// self. - SDBMTermExpr getTerm(); - - /// If this is a sum expression, return its constant part, otherwise return 0. - int64_t getConstant(); - - static bool isClassFor(const SDBMExpr &expr) { - return expr.getKind() == SDBMExprKind::DimId || - expr.getKind() == SDBMExprKind::SymbolId || - expr.getKind() == SDBMExprKind::Stripe || - expr.getKind() == SDBMExprKind::Add; - } -}; - -/// SDBM term expression can be one of: -/// - single variable expression; -/// - stripe expression. -/// Stripe expressions are treated as terms since, in the SDBM domain, they are -/// attached to temporary variables and can appear anywhere a variable can. -class SDBMTermExpr : public SDBMDirectExpr { -public: - using SDBMDirectExpr::SDBMDirectExpr; - - static bool isClassFor(const SDBMExpr &expr) { - return expr.getKind() == SDBMExprKind::DimId || - expr.getKind() == SDBMExprKind::SymbolId || - expr.getKind() == SDBMExprKind::Stripe; - } -}; - -/// SDBM sum expression. LHS is a term expression and RHS is a constant. -class SDBMSumExpr : public SDBMDirectExpr { -public: - using ImplType = detail::SDBMBinaryExprStorage; - using SDBMDirectExpr::SDBMDirectExpr; - - /// Obtain or create a sum expression unique'ed in the given context. - static SDBMSumExpr get(SDBMTermExpr lhs, SDBMConstantExpr rhs); - - static bool isClassFor(const SDBMExpr &expr) { - SDBMExprKind kind = expr.getKind(); - return kind == SDBMExprKind::Add; - } - - SDBMTermExpr getLHS() const; - SDBMConstantExpr getRHS() const; -}; - -/// SDBM difference expression. LHS is a direct expression, i.e. it may be a -/// sum of a term and a constant. RHS is a term expression. Thus the -/// expression (t1 - t2 + C) with term expressions t1,t2 is represented as -/// diff(sum(t1, C), t2) -/// and it is possible to extract the constant factor without negating it. -class SDBMDiffExpr : public SDBMVaryingExpr { -public: - using ImplType = detail::SDBMDiffExprStorage; - using SDBMVaryingExpr::SDBMVaryingExpr; - - /// Obtain or create a difference expression unique'ed in the given context. - static SDBMDiffExpr get(SDBMDirectExpr lhs, SDBMTermExpr rhs); - - static bool isClassFor(const SDBMExpr &expr) { - return expr.getKind() == SDBMExprKind::Diff; - } - - SDBMDirectExpr getLHS() const; - SDBMTermExpr getRHS() const; -}; - -/// SDBM stripe expression "x # C" where "x" is a term expression, "C" is a -/// constant expression and "#" is the stripe operator defined as: -/// x # C = x - x mod C. -class SDBMStripeExpr : public SDBMTermExpr { -public: - using ImplType = detail::SDBMBinaryExprStorage; - using SDBMTermExpr::SDBMTermExpr; - - static bool isClassFor(const SDBMExpr &expr) { - return expr.getKind() == SDBMExprKind::Stripe; - } - - static SDBMStripeExpr get(SDBMDirectExpr var, SDBMConstantExpr stripeFactor); - - SDBMDirectExpr getLHS() const; - SDBMConstantExpr getStripeFactor() const; -}; - -/// SDBM "input" variable expression can be either a dimension identifier or -/// a symbol identifier. When used to define SDBM functions, dimensions are -/// interpreted as function arguments while symbols are treated as unknown but -/// constant values, hence the name. -class SDBMInputExpr : public SDBMTermExpr { -public: - using ImplType = detail::SDBMTermExprStorage; - using SDBMTermExpr::SDBMTermExpr; - - static bool isClassFor(const SDBMExpr &expr) { - return expr.getKind() == SDBMExprKind::DimId || - expr.getKind() == SDBMExprKind::SymbolId; - } - - unsigned getPosition() const; -}; - -/// SDBM dimension expression. Dimensions correspond to function arguments -/// when defining functions using SDBM expressions. -class SDBMDimExpr : public SDBMInputExpr { -public: - using ImplType = detail::SDBMTermExprStorage; - using SDBMInputExpr::SDBMInputExpr; - - /// Obtain or create a dimension expression unique'ed in the given dialect - /// (which belongs to a context). - static SDBMDimExpr get(SDBMDialect *dialect, unsigned position); - - static bool isClassFor(const SDBMExpr &expr) { - return expr.getKind() == SDBMExprKind::DimId; - } -}; - -/// SDBM symbol expression. Symbols correspond to symbolic constants when -/// defining functions using SDBM expressions. -class SDBMSymbolExpr : public SDBMInputExpr { -public: - using ImplType = detail::SDBMTermExprStorage; - using SDBMInputExpr::SDBMInputExpr; - - /// Obtain or create a symbol expression unique'ed in the given dialect (which - /// belongs to a context). - static SDBMSymbolExpr get(SDBMDialect *dialect, unsigned position); - - static bool isClassFor(const SDBMExpr &expr) { - return expr.getKind() == SDBMExprKind::SymbolId; - } -}; - -/// Negation of an SDBM variable expression. Equivalent to multiplying the -/// expression with -1 (SDBM does not support other coefficients that 1 and -1). -class SDBMNegExpr : public SDBMVaryingExpr { -public: - using ImplType = detail::SDBMNegExprStorage; - using SDBMVaryingExpr::SDBMVaryingExpr; - - /// Obtain or create a negation expression unique'ed in the given context. - static SDBMNegExpr get(SDBMDirectExpr var); - - static bool isClassFor(const SDBMExpr &expr) { - return expr.getKind() == SDBMExprKind::Neg; - } - - SDBMDirectExpr getVar() const; -}; - -/// A visitor class for SDBM expressions. Calls the kind-specific function -/// depending on the kind of expression it visits. -template class SDBMVisitor { -public: - /// Visit the given SDBM expression, dispatching to kind-specific functions. - Result visit(SDBMExpr expr) { - auto *derived = static_cast(this); - switch (expr.getKind()) { - case SDBMExprKind::Add: - case SDBMExprKind::Diff: - case SDBMExprKind::DimId: - case SDBMExprKind::SymbolId: - case SDBMExprKind::Neg: - case SDBMExprKind::Stripe: - return derived->visitVarying(expr.cast()); - case SDBMExprKind::Constant: - return derived->visitConstant(expr.cast()); - } - - llvm_unreachable("unsupported SDBM expression kind"); - } - - /// Traverse the SDBM expression tree calling `visit` on each node - /// in depth-first preorder. - void walkPreorder(SDBMExpr expr) { return walk(expr); } - - /// Traverse the SDBM expression tree calling `visit` on each node in - /// depth-first postorder. - void walkPostorder(SDBMExpr expr) { return walk(expr); } - -protected: - /// Default visitors do nothing. - void visitSum(SDBMSumExpr) {} - void visitDiff(SDBMDiffExpr) {} - void visitStripe(SDBMStripeExpr) {} - void visitDim(SDBMDimExpr) {} - void visitSymbol(SDBMSymbolExpr) {} - void visitNeg(SDBMNegExpr) {} - void visitConstant(SDBMConstantExpr) {} - - /// Default implementation of visitDirect dispatches to the dedicated for sums - /// or delegates to visitTerm for the other expression kinds. Concrete - /// visitors can overload it. - Result visitDirect(SDBMDirectExpr expr) { - auto *derived = static_cast(this); - if (auto sum = expr.dyn_cast()) - return derived->visitSum(sum); - else - return derived->visitTerm(expr.cast()); - } - - /// Default implementation of visitTerm dispatches to the special functions - /// for stripes and other variables. Concrete visitors can override it. - Result visitTerm(SDBMTermExpr expr) { - auto *derived = static_cast(this); - if (expr.getKind() == SDBMExprKind::Stripe) - return derived->visitStripe(expr.cast()); - else - return derived->visitInput(expr.cast()); - } - - /// Default implementation of visitInput dispatches to the special - /// functions for dimensions or symbols. Concrete visitors can override it to - /// visit all variables instead. - Result visitInput(SDBMInputExpr expr) { - auto *derived = static_cast(this); - if (expr.getKind() == SDBMExprKind::DimId) - return derived->visitDim(expr.cast()); - else - return derived->visitSymbol(expr.cast()); - } - - /// Default implementation of visitVarying dispatches to the special - /// functions for variables and negations thereof. Concrete visitors can - /// override it to visit all variables and negations instead. - Result visitVarying(SDBMVaryingExpr expr) { - auto *derived = static_cast(this); - if (auto var = expr.dyn_cast()) - return derived->visitDirect(var); - else if (auto neg = expr.dyn_cast()) - return derived->visitNeg(neg); - else if (auto diff = expr.dyn_cast()) - return derived->visitDiff(diff); - - llvm_unreachable("unhandled subtype of varying SDBM expression"); - } - - template void walk(SDBMExpr expr) { - if (isPreorder) - visit(expr); - if (auto sumExpr = expr.dyn_cast()) { - walk(sumExpr.getLHS()); - walk(sumExpr.getRHS()); - } else if (auto diffExpr = expr.dyn_cast()) { - walk(diffExpr.getLHS()); - walk(diffExpr.getRHS()); - } else if (auto stripeExpr = expr.dyn_cast()) { - walk(stripeExpr.getLHS()); - walk(stripeExpr.getStripeFactor()); - } else if (auto negExpr = expr.dyn_cast()) { - walk(negExpr.getVar()); - } - if (!isPreorder) - visit(expr); - } -}; - -/// Overloaded arithmetic operators for SDBM expressions asserting that their -/// arguments have the proper SDBM expression subtype. Perform canonicalization -/// and constant folding on these expressions. -namespace ops_assertions { - -/// Add two SDBM expressions. At least one of the expressions must be a -/// constant or a negation, but both expressions cannot be negations -/// simultaneously. -SDBMExpr operator+(SDBMExpr lhs, SDBMExpr rhs); -inline SDBMExpr operator+(SDBMExpr lhs, int64_t rhs) { - return lhs + SDBMConstantExpr::get(lhs.getDialect(), rhs); -} -inline SDBMExpr operator+(int64_t lhs, SDBMExpr rhs) { - return SDBMConstantExpr::get(rhs.getDialect(), lhs) + rhs; -} - -/// Subtract an SDBM expression from another SDBM expression. Both expressions -/// must not be difference expressions. -SDBMExpr operator-(SDBMExpr lhs, SDBMExpr rhs); -inline SDBMExpr operator-(SDBMExpr lhs, int64_t rhs) { - return lhs - SDBMConstantExpr::get(lhs.getDialect(), rhs); -} -inline SDBMExpr operator-(int64_t lhs, SDBMExpr rhs) { - return SDBMConstantExpr::get(rhs.getDialect(), lhs) - rhs; -} - -/// Construct a stripe expression from a positive expression and a positive -/// constant stripe factor. -SDBMExpr stripe(SDBMExpr expr, SDBMExpr factor); -inline SDBMExpr stripe(SDBMExpr expr, int64_t factor) { - return stripe(expr, SDBMConstantExpr::get(expr.getDialect(), factor)); -} -} // namespace ops_assertions - -} // end namespace mlir - -namespace llvm { -// SDBMExpr hash just like pointers. -template <> struct DenseMapInfo { - static mlir::SDBMExpr getEmptyKey() { - auto *pointer = llvm::DenseMapInfo::getEmptyKey(); - return mlir::SDBMExpr(static_cast(pointer)); - } - static mlir::SDBMExpr getTombstoneKey() { - auto *pointer = llvm::DenseMapInfo::getTombstoneKey(); - return mlir::SDBMExpr(static_cast(pointer)); - } - static unsigned getHashValue(mlir::SDBMExpr expr) { - return expr.hash_value(); - } - static bool isEqual(mlir::SDBMExpr lhs, mlir::SDBMExpr rhs) { - return lhs == rhs; - } -}; - -// SDBMDirectExpr hash just like pointers. -template <> struct DenseMapInfo { - static mlir::SDBMDirectExpr getEmptyKey() { - auto *pointer = llvm::DenseMapInfo::getEmptyKey(); - return mlir::SDBMDirectExpr( - static_cast(pointer)); - } - static mlir::SDBMDirectExpr getTombstoneKey() { - auto *pointer = llvm::DenseMapInfo::getTombstoneKey(); - return mlir::SDBMDirectExpr( - static_cast(pointer)); - } - static unsigned getHashValue(mlir::SDBMDirectExpr expr) { - return expr.hash_value(); - } - static bool isEqual(mlir::SDBMDirectExpr lhs, mlir::SDBMDirectExpr rhs) { - return lhs == rhs; - } -}; - -// SDBMTermExpr hash just like pointers. -template <> struct DenseMapInfo { - static mlir::SDBMTermExpr getEmptyKey() { - auto *pointer = llvm::DenseMapInfo::getEmptyKey(); - return mlir::SDBMTermExpr(static_cast(pointer)); - } - static mlir::SDBMTermExpr getTombstoneKey() { - auto *pointer = llvm::DenseMapInfo::getTombstoneKey(); - return mlir::SDBMTermExpr(static_cast(pointer)); - } - static unsigned getHashValue(mlir::SDBMTermExpr expr) { - return expr.hash_value(); - } - static bool isEqual(mlir::SDBMTermExpr lhs, mlir::SDBMTermExpr rhs) { - return lhs == rhs; - } -}; - -// SDBMConstantExpr hash just like pointers. -template <> struct DenseMapInfo { - static mlir::SDBMConstantExpr getEmptyKey() { - auto *pointer = llvm::DenseMapInfo::getEmptyKey(); - return mlir::SDBMConstantExpr( - static_cast(pointer)); - } - static mlir::SDBMConstantExpr getTombstoneKey() { - auto *pointer = llvm::DenseMapInfo::getTombstoneKey(); - return mlir::SDBMConstantExpr( - static_cast(pointer)); - } - static unsigned getHashValue(mlir::SDBMConstantExpr expr) { - return expr.hash_value(); - } - static bool isEqual(mlir::SDBMConstantExpr lhs, mlir::SDBMConstantExpr rhs) { - return lhs == rhs; - } -}; -} // namespace llvm - -#endif // MLIR_DIALECT_SDBM_SDBMEXPR_H diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -35,7 +35,6 @@ #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" #include "mlir/Dialect/Quant/QuantOps.h" #include "mlir/Dialect/SCF/SCF.h" -#include "mlir/Dialect/SDBM/SDBMDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" @@ -75,7 +74,6 @@ vector::VectorDialect, NVVM::NVVMDialect, ROCDL::ROCDLDialect, - SDBMDialect, shape::ShapeDialect, sparse_tensor::SparseTensorDialect, tensor::TensorDialect, diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt --- a/mlir/lib/Dialect/CMakeLists.txt +++ b/mlir/lib/Dialect/CMakeLists.txt @@ -17,7 +17,6 @@ add_subdirectory(PDLInterp) add_subdirectory(Quant) add_subdirectory(SCF) -add_subdirectory(SDBM) add_subdirectory(Shape) add_subdirectory(SparseTensor) add_subdirectory(SPIRV) diff --git a/mlir/lib/Dialect/SDBM/CMakeLists.txt b/mlir/lib/Dialect/SDBM/CMakeLists.txt deleted file mode 100644 --- a/mlir/lib/Dialect/SDBM/CMakeLists.txt +++ /dev/null @@ -1,11 +0,0 @@ -add_mlir_dialect_library(MLIRSDBM - SDBM.cpp - SDBMDialect.cpp - SDBMExpr.cpp - - ADDITIONAL_HEADER_DIRS - ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SDBM - - LINK_LIBS PUBLIC - MLIRIR - ) diff --git a/mlir/lib/Dialect/SDBM/SDBM.cpp b/mlir/lib/Dialect/SDBM/SDBM.cpp deleted file mode 100644 --- a/mlir/lib/Dialect/SDBM/SDBM.cpp +++ /dev/null @@ -1,551 +0,0 @@ -//===- SDBM.cpp - MLIR SDBM implementation --------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// A striped difference-bound matrix (SDBM) is a set in Z^N (or R^N) defined -// as {(x_1, ... x_n) | f(x_1, ... x_n) >= 0} where f is an SDBM expression. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/SDBM/SDBM.h" -#include "mlir/Dialect/SDBM/SDBMExpr.h" - -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/SetVector.h" -#include "llvm/Support/FormatVariadic.h" -#include "llvm/Support/raw_ostream.h" - -using namespace mlir; - -// Helper function for SDBM construction that collects information necessary to -// start building an SDBM in one sweep. In particular, it records the largest -// position of a dimension in `dim`, that of a symbol in `symbol` as well as -// collects all unique stripe expressions in `stripes`. Uses SetVector to -// ensure these expressions always have the same order. -static void collectSDBMBuildInfo(SDBMExpr expr, int &dim, int &symbol, - llvm::SmallSetVector &stripes) { - struct Visitor : public SDBMVisitor { - void visitDim(SDBMDimExpr dimExpr) { - int p = dimExpr.getPosition(); - if (p > maxDimPosition) - maxDimPosition = p; - } - void visitSymbol(SDBMSymbolExpr symbExpr) { - int p = symbExpr.getPosition(); - if (p > maxSymbPosition) - maxSymbPosition = p; - } - void visitStripe(SDBMStripeExpr stripeExpr) { stripes.insert(stripeExpr); } - - Visitor(llvm::SmallSetVector &stripes) : stripes(stripes) {} - - int maxDimPosition = -1; - int maxSymbPosition = -1; - llvm::SmallSetVector &stripes; - }; - - Visitor visitor(stripes); - visitor.walkPostorder(expr); - dim = std::max(dim, visitor.maxDimPosition); - symbol = std::max(symbol, visitor.maxSymbPosition); -} - -namespace { -// Utility class for SDBMBuilder. Represents a value that can be inserted in -// the SDB matrix that corresponds to "v0 - v1 + C <= 0", where v0 and v1 is -// any combination of the positive and negative positions. Since multiple -// variables can be declared equal to the same stripe expression, the -// constraints on this expression must be reflected to all these variables. For -// example, if -// d0 = s0 # 42 -// d1 = s0 # 42 -// d2 = s1 # 2 -// d3 = s1 # 2 -// the constraint -// s0 # 42 - s1 # 2 <= C -// should be reflected in the DB matrix as -// d0 - d2 <= C -// d1 - d2 <= C -// d0 - d3 <= C -// d1 - d3 <= C -// since the DB matrix has no knowledge of the transitive equality between d0, -// d1 and s0 # 42 as well as between d2, d3 and s1 # 2. This knowledge can be -// obtained by computing a transitive closure, which is impossible until the -// DBM is actually built. -struct SDBMBuilderResult { - // Positions in the matrix of the variables taken with the "+" sign in the - // difference expression, 0 if it is a constant rather than a variable. - SmallVector positivePos; - - // Positions in the matrix of the variables taken with the "-" sign in the - // difference expression, 0 if it is a constant rather than a variable. - SmallVector negativePos; - - // Constant value in the difference expression. - int64_t value = 0; -}; - -// Visitor for building an SDBM from SDBM expressions. After traversing an SDBM -// expression, produces an update to the SDB matrix specifying the positions in -// the matrix and the negated value that should be stored. Both the positive -// and the negative positions may be lists of indices in cases where multiple -// variables are equal to the same stripe expression. In such cases, the update -// applies to the cross product of positions because elements involved in the -// update are (transitively) equal and should have the same constraints, but we -// may not have an explicit equality for them. -struct SDBMBuilder : public SDBMVisitor { -public: - // A difference expression produces both the positive and the negative - // coordinate in the matrix, recursively traversing the LHS and the RHS. The - // value is the difference between values obtained from LHS and RHS. - SDBMBuilderResult visitDiff(SDBMDiffExpr diffExpr) { - auto lhs = visit(diffExpr.getLHS()); - auto rhs = visit(diffExpr.getRHS()); - assert(lhs.negativePos.size() == 1 && lhs.negativePos[0] == 0 && - "unexpected negative expression in a difference expression"); - assert(rhs.negativePos.size() == 1 && lhs.negativePos[0] == 0 && - "unexpected negative expression in a difference expression"); - - SDBMBuilderResult result; - result.positivePos = lhs.positivePos; - result.negativePos = rhs.positivePos; - result.value = lhs.value - rhs.value; - return result; - } - - // An input expression is always taken with the "+" sign and therefore - // produces a positive coordinate keeping the negative coordinate zero for an - // eventual constant. - SDBMBuilderResult visitInput(SDBMInputExpr expr) { - SDBMBuilderResult r; - r.positivePos.push_back(linearPosition(expr)); - r.negativePos.push_back(0); - return r; - } - - // A stripe expression is always equal to one or more variables, which may be - // temporaries, and appears with a "+" sign in the SDBM expression tree. Take - // the positions of the corresponding variables as positive coordinates. - SDBMBuilderResult visitStripe(SDBMStripeExpr expr) { - SDBMBuilderResult r; - assert(pointExprToStripe.count(expr)); - r.positivePos = pointExprToStripe[expr]; - r.negativePos.push_back(0); - return r; - } - - // A constant expression has both coordinates at zero. - SDBMBuilderResult visitConstant(SDBMConstantExpr expr) { - SDBMBuilderResult r; - r.positivePos.push_back(0); - r.negativePos.push_back(0); - r.value = expr.getValue(); - return r; - } - - // A negation expression swaps the positive and the negative coordinates - // and also negates the constant value. - SDBMBuilderResult visitNeg(SDBMNegExpr expr) { - SDBMBuilderResult result = visit(expr.getVar()); - std::swap(result.positivePos, result.negativePos); - result.value = -result.value; - return result; - } - - // The RHS of a sum expression must be a constant and therefore must have both - // positive and negative coordinates at zero. Take the sum of the values - // between LHS and RHS and keep LHS coordinates. - SDBMBuilderResult visitSum(SDBMSumExpr expr) { - auto lhs = visit(expr.getLHS()); - auto rhs = visit(expr.getRHS()); - for (auto pos : rhs.negativePos) { - (void)pos; - assert(pos == 0 && "unexpected variable on the RHS of SDBM sum"); - } - for (auto pos : rhs.positivePos) { - (void)pos; - assert(pos == 0 && "unexpected variable on the RHS of SDBM sum"); - } - - lhs.value += rhs.value; - return lhs; - } - - SDBMBuilder(DenseMap> &pointExprToStripe, - function_ref callback) - : pointExprToStripe(pointExprToStripe), linearPosition(callback) {} - - DenseMap> &pointExprToStripe; - function_ref linearPosition; -}; -} // namespace - -SDBM SDBM::get(ArrayRef inequalities, ArrayRef equalities) { - SDBM result; - - // TODO: consider detecting equalities in the list of inequalities. - // This is potentially expensive and requires to - // - create a list of negated inequalities (may allocate under lock); - // - perform a pairwise comparison of direct and negated inequalities; - // - copy the lists of equalities and inequalities, and move entries between - // them; - // only for the purpose of sparing a temporary variable in cases where an - // implicit equality between a variable and a stripe expression is present in - // the input. - - // Do the first sweep over (in)equalities to collect the information necessary - // to allocate the SDB matrix (number of dimensions, symbol and temporary - // variables required for stripe expressions). - llvm::SmallSetVector stripes; - int maxDim = -1; - int maxSymbol = -1; - for (auto expr : inequalities) - collectSDBMBuildInfo(expr, maxDim, maxSymbol, stripes); - for (auto expr : equalities) - collectSDBMBuildInfo(expr, maxDim, maxSymbol, stripes); - // Indexing of dimensions starts with 0, obtain the number of dimensions by - // incrementing the maximal position of the dimension seen in expressions. - result.numDims = maxDim + 1; - result.numSymbols = maxSymbol + 1; - result.numTemporaries = 0; - - // Helper function that returns the position of the variable represented by - // an SDBM input expression. - auto linearPosition = [result](SDBMInputExpr expr) { - if (expr.isa()) - return result.getDimPosition(expr.getPosition()); - return result.getSymbolPosition(expr.getPosition()); - }; - - // Check if some stripe expressions are equal to another variable. In - // particular, look for the equalities of the form - // d0 - stripe-expression = 0, or - // stripe-expression - d0 = 0. - // There may be multiple variables that are equal to the same stripe - // expression. Keep track of those in pointExprToStripe. - // There may also be multiple stripe expressions equal to the same variable. - // Introduce a temporary variable for each of those. - DenseMap> pointExprToStripe; - unsigned numTemporaries = 0; - - auto updateStripePointMaps = [&numTemporaries, &result, &pointExprToStripe, - linearPosition](SDBMInputExpr input, - SDBMExpr expr) { - unsigned position = linearPosition(input); - if (result.stripeToPoint.count(position) && - result.stripeToPoint[position] != expr) { - position = result.getNumVariables() + numTemporaries++; - } - pointExprToStripe[expr].push_back(position); - result.stripeToPoint.insert(std::make_pair(position, expr)); - }; - - for (auto eq : equalities) { - auto diffExpr = eq.dyn_cast(); - if (!diffExpr) - continue; - - auto lhs = diffExpr.getLHS(); - auto rhs = diffExpr.getRHS(); - auto lhsInput = lhs.dyn_cast(); - auto rhsInput = rhs.dyn_cast(); - - if (lhsInput && stripes.count(rhs)) - updateStripePointMaps(lhsInput, rhs); - if (rhsInput && stripes.count(lhs)) - updateStripePointMaps(rhsInput, lhs); - } - - // Assign the remaining stripe expressions to temporary variables. These - // expressions are the ones that could not be associated with an existing - // variable in the previous step. - for (auto expr : stripes) { - if (pointExprToStripe.count(expr)) - continue; - unsigned position = result.getNumVariables() + numTemporaries++; - pointExprToStripe[expr].push_back(position); - result.stripeToPoint.insert(std::make_pair(position, expr)); - } - - // Create the DBM matrix, initialized to infinity values for the least tight - // possible bound (x - y <= infinity is always true). - result.numTemporaries = numTemporaries; - result.matrix.resize(result.getNumVariables() * result.getNumVariables(), - IntInfty::infinity()); - - SDBMBuilder builder(pointExprToStripe, linearPosition); - - // Only keep the tightest constraint. Since we transform everything into - // less-than-or-equals-to inequalities, keep the smallest constant. For - // example, if we have d0 - d1 <= 42 and d0 - d1 <= 2, we keep the latter. - // Note that the input expressions are in the shape of d0 - d1 + -42 <= 0 - // so we negate the value before storing it. - // In case where the positive and the negative positions are equal, the - // corresponding expression has the form d0 - d0 + -42 <= 0. If the constant - // value is positive, the set defined by SDBM is trivially empty. We store - // this value anyway and continue processing to maintain the correspondence - // between the matrix form and the list-of-SDBMExpr form. - // TODO: we may want to reconsider this once we have canonicalization - // or simplification in place - auto updateMatrix = [](SDBM &sdbm, const SDBMBuilderResult &r) { - for (auto positivePos : r.positivePos) { - for (auto negativePos : r.negativePos) { - auto &m = sdbm.at(negativePos, positivePos); - m = m < -r.value ? m : -r.value; - } - } - }; - - // Do the second sweep on (in)equalities, updating the SDB matrix to reflect - // the constraints. - for (auto ineq : inequalities) - updateMatrix(result, builder.visit(ineq)); - - // An equality f(x) = 0 is represented as a pair of inequalities {f(x) >= 0; - // f(x) <= 0} or, alternatively, {-f(x) <= 0 and f(x) <= 0}. - for (auto eq : equalities) { - updateMatrix(result, builder.visit(eq)); - updateMatrix(result, builder.visit(-eq)); - } - - // Add the inequalities induced by stripe equalities. - // t = x # C => t <= x <= t + C - 1 - // which is equivalent to - // {t - x <= 0; - // x - t - (C - 1) <= 0}. - for (const auto &pair : result.stripeToPoint) { - auto stripe = pair.second.cast(); - SDBMBuilderResult update = builder.visit(stripe.getLHS()); - assert(update.negativePos.size() == 1 && update.negativePos[0] == 0 && - "unexpected negated variable in stripe expression"); - assert(update.value == 0 && - "unexpected non-zero value in stripe expression"); - update.negativePos.clear(); - update.negativePos.push_back(pair.first); - update.value = -(stripe.getStripeFactor().getValue() - 1); - updateMatrix(result, update); - - std::swap(update.negativePos, update.positivePos); - update.value = 0; - updateMatrix(result, update); - } - - return result; -} - -// Given a row and a column position in the square DBM, insert one equality -// or up to two inequalities that correspond the entries (col, row) and (row, -// col) in the DBM. `rowExpr` and `colExpr` contain the expressions such that -// colExpr - rowExpr <= V where V is the value at (row, col) in the DBM. -// If one of the expressions is derived from another using a stripe operation, -// check if the inequalities induced by the stripe operation subsume the -// inequalities defined in the DBM and if so, elide these inequalities. -void SDBM::convertDBMElement(unsigned row, unsigned col, SDBMTermExpr rowExpr, - SDBMTermExpr colExpr, - SmallVectorImpl &inequalities, - SmallVectorImpl &equalities) { - using ops_assertions::operator+; - using ops_assertions::operator-; - - auto diffIJValue = at(col, row); - auto diffJIValue = at(row, col); - - // If symmetric entries are opposite, the corresponding expressions are equal. - if (diffIJValue.isFinite() && - diffIJValue.getValue() == -diffJIValue.getValue()) { - equalities.push_back(rowExpr - colExpr - diffIJValue.getValue()); - return; - } - - // Given an inequality x0 - x1 <= A, check if x0 is a stripe variable derived - // from x1: x0 = x1 # B. If so, it would imply the constraints - // x0 <= x1 <= x0 + (B - 1) <=> x0 - x1 <= 0 and x1 - x0 <= (B - 1). - // Therefore, if A >= 0, this inequality is subsumed by that implied - // by the stripe equality and thus can be elided. - // Similarly, check if x1 is a stripe variable derived from x0: x1 = x0 # C. - // If so, it would imply the constraints x1 <= x0 <= x1 + (C - 1) <=> - // <=> x1 - x0 <= 0 and x0 - x1 <= (C - 1). Therefore, if A >= (C - 1), this - // inequality can be elided. - // - // Note: x0 and x1 may be a stripe expressions themselves, we rely on stripe - // expressions being stored without temporaries on the RHS and being passed - // into this function as is. - auto canElide = [this](unsigned x0, unsigned x1, SDBMExpr x0Expr, - SDBMExpr x1Expr, int64_t value) { - if (stripeToPoint.count(x0)) { - auto stripe = stripeToPoint[x0].cast(); - SDBMDirectExpr var = stripe.getLHS(); - if (x1Expr == var && value >= 0) - return true; - } - if (stripeToPoint.count(x1)) { - auto stripe = stripeToPoint[x1].cast(); - SDBMDirectExpr var = stripe.getLHS(); - if (x0Expr == var && value >= stripe.getStripeFactor().getValue() - 1) - return true; - } - return false; - }; - - // Check row - col. - if (diffIJValue.isFinite() && - !canElide(row, col, rowExpr, colExpr, diffIJValue.getValue())) { - inequalities.push_back(rowExpr - colExpr - diffIJValue.getValue()); - } - // Check col - row. - if (diffJIValue.isFinite() && - !canElide(col, row, colExpr, rowExpr, diffJIValue.getValue())) { - inequalities.push_back(colExpr - rowExpr - diffJIValue.getValue()); - } -} - -// The values on the main diagonal correspond to the upper bound on the -// difference between a variable and itself: d0 - d0 <= C, or alternatively -// to -C <= 0. Only construct the inequalities when C is negative, which -// are trivially false but necessary for the returned system of inequalities -// to indicate that the set it defines is empty. -void SDBM::convertDBMDiagonalElement(unsigned pos, SDBMTermExpr expr, - SmallVectorImpl &inequalities) { - auto selfDifference = at(pos, pos); - if (selfDifference.isFinite() && selfDifference < 0) { - auto selfDifferenceValueExpr = - SDBMConstantExpr::get(expr.getDialect(), -selfDifference.getValue()); - inequalities.push_back(selfDifferenceValueExpr); - } -} - -void SDBM::getSDBMExpressions(SDBMDialect *dialect, - SmallVectorImpl &inequalities, - SmallVectorImpl &equalities) { - using ops_assertions::operator-; - using ops_assertions::operator+; - - // Helper function that creates an SDBMInputExpr given the linearized position - // of variable in the DBM. - auto getInput = [dialect, this](unsigned matrixPos) -> SDBMInputExpr { - if (matrixPos < numDims) - return SDBMDimExpr::get(dialect, matrixPos); - return SDBMSymbolExpr::get(dialect, matrixPos - numDims); - }; - - // The top-left value corresponds to inequality 0 <= C. If C is negative, the - // set defined by SDBM is trivially empty and we add the constraint -C <= 0 to - // the list of inequalities. Otherwise, the constraint is trivially true and - // we ignore it. - auto difference = at(0, 0); - if (difference.isFinite() && difference < 0) { - inequalities.push_back( - SDBMConstantExpr::get(dialect, -difference.getValue())); - } - - // Traverse the segment of the matrix that involves non-temporary variables. - unsigned numTrueVariables = numDims + numSymbols; - for (unsigned i = 0; i < numTrueVariables; ++i) { - // The first row and column represent numerical upper and lower bound on - // each variable. Transform them into inequalities if they are finite. - auto upperBound = at(0, 1 + i); - auto lowerBound = at(1 + i, 0); - auto inputExpr = getInput(i); - if (upperBound.isFinite() && - upperBound.getValue() == -lowerBound.getValue()) { - equalities.push_back(inputExpr - upperBound.getValue()); - } else if (upperBound.isFinite()) { - inequalities.push_back(inputExpr - upperBound.getValue()); - } else if (lowerBound.isFinite()) { - inequalities.push_back(-inputExpr - lowerBound.getValue()); - } - - // Introduce trivially false inequalities if required by diagonal elements. - convertDBMDiagonalElement(1 + i, inputExpr, inequalities); - - // Introduce equalities or inequalities between non-temporary variables. - for (unsigned j = 0; j < i; ++j) { - convertDBMElement(1 + i, 1 + j, getInput(i), getInput(j), inequalities, - equalities); - } - } - - // Add equalities for stripe expressions that define non-temporary - // variables. Temporary variables will be substituted into their uses and - // should not appear in the resulting equalities. - for (const auto &stripePair : stripeToPoint) { - unsigned position = stripePair.first; - if (position < 1 + numTrueVariables) { - equalities.push_back(getInput(position - 1) - stripePair.second); - } - } - - // Add equalities / inequalities involving temporaries by replacing the - // temporaries with stripe expressions that define them. - for (unsigned i = 1 + numTrueVariables, e = getNumVariables(); i < e; ++i) { - // Mixed constraints involving one temporary (j) and one non-temporary (i) - // variable. - for (unsigned j = 0; j < numTrueVariables; ++j) { - convertDBMElement(i, 1 + j, stripeToPoint[i].cast(), - getInput(j), inequalities, equalities); - } - - // Constraints involving only temporary variables. - for (unsigned j = 1 + numTrueVariables; j < i; ++j) { - convertDBMElement(i, j, stripeToPoint[i].cast(), - stripeToPoint[j].cast(), inequalities, - equalities); - } - - // Introduce trivially false inequalities if required by diagonal elements. - convertDBMDiagonalElement(i, stripeToPoint[i].cast(), - inequalities); - } -} - -void SDBM::print(raw_ostream &os) { - unsigned numVariables = getNumVariables(); - - // Helper function that prints the name of the variable given its linearized - // position in the DBM. - auto getVarName = [this](unsigned matrixPos) -> std::string { - if (matrixPos == 0) - return "cst"; - matrixPos -= 1; - if (matrixPos < numDims) - return std::string(llvm::formatv("d{0}", matrixPos)); - matrixPos -= numDims; - if (matrixPos < numSymbols) - return std::string(llvm::formatv("s{0}", matrixPos)); - matrixPos -= numSymbols; - return std::string(llvm::formatv("t{0}", matrixPos)); - }; - - // Header row. - os << " cst"; - for (unsigned i = 1; i < numVariables; ++i) { - os << llvm::formatv(" {0,4}", getVarName(i)); - } - os << '\n'; - - // Data rows. - for (unsigned i = 0; i < numVariables; ++i) { - os << llvm::formatv("{0,-4}", getVarName(i)); - for (unsigned j = 0; j < numVariables; ++j) { - IntInfty value = operator()(i, j); - if (!value.isFinite()) - os << " inf"; - else - os << llvm::formatv(" {0,4}", value.getValue()); - } - os << '\n'; - } - - // Explanation of temporaries. - for (const auto &pair : stripeToPoint) { - os << getVarName(pair.first) << " = "; - pair.second.print(os); - os << '\n'; - } -} - -void SDBM::dump() { print(llvm::errs()); } diff --git a/mlir/lib/Dialect/SDBM/SDBMDialect.cpp b/mlir/lib/Dialect/SDBM/SDBMDialect.cpp deleted file mode 100644 --- a/mlir/lib/Dialect/SDBM/SDBMDialect.cpp +++ /dev/null @@ -1,23 +0,0 @@ -//===- SDBMDialect.cpp - MLIR SDBM Dialect --------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/SDBM/SDBMDialect.h" -#include "SDBMExprDetail.h" - -using namespace mlir; - -SDBMDialect::SDBMDialect(MLIRContext *context) - : Dialect(getDialectNamespace(), context, TypeID::get()) { - uniquer.registerParametricStorageType(); - uniquer.registerParametricStorageType(); - uniquer.registerParametricStorageType(); - uniquer.registerParametricStorageType(); - uniquer.registerParametricStorageType(); -} - -SDBMDialect::~SDBMDialect() = default; diff --git a/mlir/lib/Dialect/SDBM/SDBMExpr.cpp b/mlir/lib/Dialect/SDBM/SDBMExpr.cpp deleted file mode 100644 --- a/mlir/lib/Dialect/SDBM/SDBMExpr.cpp +++ /dev/null @@ -1,732 +0,0 @@ -//===- SDBMExpr.cpp - MLIR SDBM Expression implementation -----------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// A striped difference-bound matrix (SDBM) expression is a constant expression, -// an identifier, a binary expression with constant RHS and +, stripe operators -// or a difference expression between two identifiers. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/SDBM/SDBMExpr.h" -#include "SDBMExprDetail.h" -#include "mlir/Dialect/SDBM/SDBMDialect.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/AffineExprVisitor.h" - -#include "llvm/Support/raw_ostream.h" - -using namespace mlir; - -namespace { -/// A simple compositional matcher for AffineExpr -/// -/// Example usage: -/// -/// ```c++ -/// AffineExprMatcher x, C, m; -/// AffineExprMatcher pattern1 = ((x % C) * m) + x; -/// AffineExprMatcher pattern2 = x + ((x % C) * m); -/// if (pattern1.match(expr) || pattern2.match(expr)) { -/// ... -/// } -/// ``` -class AffineExprMatcherStorage; -class AffineExprMatcher { -public: - AffineExprMatcher(); - AffineExprMatcher(const AffineExprMatcher &other); - - AffineExprMatcher operator+(AffineExprMatcher other) { - return AffineExprMatcher(AffineExprKind::Add, *this, other); - } - AffineExprMatcher operator*(AffineExprMatcher other) { - return AffineExprMatcher(AffineExprKind::Mul, *this, other); - } - AffineExprMatcher floorDiv(AffineExprMatcher other) { - return AffineExprMatcher(AffineExprKind::FloorDiv, *this, other); - } - AffineExprMatcher ceilDiv(AffineExprMatcher other) { - return AffineExprMatcher(AffineExprKind::CeilDiv, *this, other); - } - AffineExprMatcher operator%(AffineExprMatcher other) { - return AffineExprMatcher(AffineExprKind::Mod, *this, other); - } - - AffineExpr match(AffineExpr expr); - AffineExpr matched(); - Optional getMatchedConstantValue(); - -private: - AffineExprMatcher(AffineExprKind k, AffineExprMatcher a, AffineExprMatcher b); - AffineExprKind kind; // only used to match in binary op cases. - // A shared_ptr allows multiple references to same matcher storage without - // worrying about ownership or dealing with an arena. To be cleaned up if we - // go with this. - std::shared_ptr storage; -}; - -class AffineExprMatcherStorage { -public: - AffineExprMatcherStorage() {} - AffineExprMatcherStorage(const AffineExprMatcherStorage &other) - : subExprs(other.subExprs.begin(), other.subExprs.end()), - matched(other.matched) {} - AffineExprMatcherStorage(ArrayRef exprs) - : subExprs(exprs.begin(), exprs.end()) {} - AffineExprMatcherStorage(AffineExprMatcher &a, AffineExprMatcher &b) - : subExprs({a, b}) {} - SmallVector subExprs; - AffineExpr matched; -}; -} // namespace - -AffineExprMatcher::AffineExprMatcher() - : kind(AffineExprKind::Constant), storage(new AffineExprMatcherStorage()) {} - -AffineExprMatcher::AffineExprMatcher(const AffineExprMatcher &other) - : kind(other.kind), storage(other.storage) {} - -Optional AffineExprMatcher::getMatchedConstantValue() { - if (auto cst = storage->matched.dyn_cast()) - return cst.getValue(); - return None; -} - -AffineExpr AffineExprMatcher::match(AffineExpr expr) { - if (kind > AffineExprKind::LAST_AFFINE_BINARY_OP) { - if (storage->matched) - if (storage->matched != expr) - return AffineExpr(); - storage->matched = expr; - return storage->matched; - } - if (kind != expr.getKind()) { - return AffineExpr(); - } - if (auto bin = expr.dyn_cast()) { - if (!storage->subExprs.empty() && - !storage->subExprs[0].match(bin.getLHS())) { - return AffineExpr(); - } - if (!storage->subExprs.empty() && - !storage->subExprs[1].match(bin.getRHS())) { - return AffineExpr(); - } - if (storage->matched) - if (storage->matched != expr) - return AffineExpr(); - storage->matched = expr; - return storage->matched; - } - llvm_unreachable("binary expected"); -} - -AffineExpr AffineExprMatcher::matched() { return storage->matched; } - -AffineExprMatcher::AffineExprMatcher(AffineExprKind k, AffineExprMatcher a, - AffineExprMatcher b) - : kind(k), storage(new AffineExprMatcherStorage(a, b)) { - storage->subExprs.push_back(a); - storage->subExprs.push_back(b); -} - -//===----------------------------------------------------------------------===// -// SDBMExpr -//===----------------------------------------------------------------------===// - -SDBMExprKind SDBMExpr::getKind() const { return impl->getKind(); } - -MLIRContext *SDBMExpr::getContext() const { - return impl->dialect->getContext(); -} - -SDBMDialect *SDBMExpr::getDialect() const { return impl->dialect; } - -void SDBMExpr::print(raw_ostream &os) const { - struct Printer : public SDBMVisitor { - Printer(raw_ostream &ostream) : prn(ostream) {} - - void visitSum(SDBMSumExpr expr) { - visit(expr.getLHS()); - prn << " + "; - visit(expr.getRHS()); - } - void visitDiff(SDBMDiffExpr expr) { - visit(expr.getLHS()); - prn << " - "; - visit(expr.getRHS()); - } - void visitDim(SDBMDimExpr expr) { prn << 'd' << expr.getPosition(); } - void visitSymbol(SDBMSymbolExpr expr) { prn << 's' << expr.getPosition(); } - void visitStripe(SDBMStripeExpr expr) { - SDBMDirectExpr lhs = expr.getLHS(); - bool isTerm = lhs.isa(); - if (!isTerm) - prn << '('; - visit(lhs); - if (!isTerm) - prn << ')'; - prn << " # "; - visitConstant(expr.getStripeFactor()); - } - void visitNeg(SDBMNegExpr expr) { - bool isSum = expr.getVar().isa(); - prn << '-'; - if (isSum) - prn << '('; - visit(expr.getVar()); - if (isSum) - prn << ')'; - } - void visitConstant(SDBMConstantExpr expr) { prn << expr.getValue(); } - - raw_ostream &prn; - }; - Printer printer(os); - printer.visit(*this); -} - -void SDBMExpr::dump() const { - print(llvm::errs()); - llvm::errs() << '\n'; -} - -namespace { -// Helper class to perform negation of an SDBM expression. -struct SDBMNegator : public SDBMVisitor { - // Any term expression is wrapped into a negation expression. - // -(x) = -x - SDBMExpr visitDirect(SDBMDirectExpr expr) { return SDBMNegExpr::get(expr); } - // A negation expression is unwrapped. - // -(-x) = x - SDBMExpr visitNeg(SDBMNegExpr expr) { return expr.getVar(); } - // The value of the constant is negated. - SDBMExpr visitConstant(SDBMConstantExpr expr) { - return SDBMConstantExpr::get(expr.getDialect(), -expr.getValue()); - } - - // Terms of a difference are interchanged. Since only the LHS of a diff - // expression is allowed to be a sum with a constant, we need to recreate the - // sum with the negated value: - // -((x + C) - y) = (y - C) - x. - SDBMExpr visitDiff(SDBMDiffExpr expr) { - // If the LHS is just a term, we can do straightforward interchange. - if (auto term = expr.getLHS().dyn_cast()) - return SDBMDiffExpr::get(expr.getRHS(), term); - - auto sum = expr.getLHS().cast(); - auto cst = visitConstant(sum.getRHS()).cast(); - return SDBMDiffExpr::get(SDBMSumExpr::get(expr.getRHS(), cst), - sum.getLHS()); - } -}; -} // namespace - -SDBMExpr SDBMExpr::operator-() { return SDBMNegator().visit(*this); } - -//===----------------------------------------------------------------------===// -// SDBMSumExpr -//===----------------------------------------------------------------------===// - -SDBMSumExpr SDBMSumExpr::get(SDBMTermExpr lhs, SDBMConstantExpr rhs) { - assert(lhs && "expected SDBM variable expression"); - assert(rhs && "expected SDBM constant"); - - // If LHS of a sum is another sum, fold the constant RHS parts. - if (auto lhsSum = lhs.dyn_cast()) { - lhs = lhsSum.getLHS(); - rhs = SDBMConstantExpr::get(rhs.getDialect(), - rhs.getValue() + lhsSum.getRHS().getValue()); - } - - StorageUniquer &uniquer = lhs.getDialect()->getUniquer(); - return uniquer.get( - /*initFn=*/{}, static_cast(SDBMExprKind::Add), lhs, rhs); -} - -SDBMTermExpr SDBMSumExpr::getLHS() const { - return static_cast(impl)->lhs.cast(); -} - -SDBMConstantExpr SDBMSumExpr::getRHS() const { - return static_cast(impl)->rhs; -} - -AffineExpr SDBMExpr::getAsAffineExpr() const { - struct Converter : public SDBMVisitor { - AffineExpr visitSum(SDBMSumExpr expr) { - AffineExpr lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS()); - return lhs + rhs; - } - - AffineExpr visitStripe(SDBMStripeExpr expr) { - AffineExpr lhs = visit(expr.getLHS()), - rhs = visit(expr.getStripeFactor()); - return lhs - (lhs % rhs); - } - - AffineExpr visitDiff(SDBMDiffExpr expr) { - AffineExpr lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS()); - return lhs - rhs; - } - - AffineExpr visitDim(SDBMDimExpr expr) { - return getAffineDimExpr(expr.getPosition(), expr.getContext()); - } - - AffineExpr visitSymbol(SDBMSymbolExpr expr) { - return getAffineSymbolExpr(expr.getPosition(), expr.getContext()); - } - - AffineExpr visitNeg(SDBMNegExpr expr) { - return getAffineBinaryOpExpr(AffineExprKind::Mul, - getAffineConstantExpr(-1, expr.getContext()), - visit(expr.getVar())); - } - - AffineExpr visitConstant(SDBMConstantExpr expr) { - return getAffineConstantExpr(expr.getValue(), expr.getContext()); - } - } converter; - return converter.visit(*this); -} - -// Given a direct expression `expr`, add the given constant to it and pass the -// resulting expression to `builder` before returning its result. If the -// expression is already a sum expression, update its constant and extract the -// LHS if the constant becomes zero. Otherwise, construct a sum expression. -template -static Result addConstantAndSink(SDBMDirectExpr expr, int64_t constant, - bool negated, - function_ref builder) { - SDBMDialect *dialect = expr.getDialect(); - if (auto sumExpr = expr.dyn_cast()) { - if (negated) - constant = sumExpr.getRHS().getValue() - constant; - else - constant += sumExpr.getRHS().getValue(); - - if (constant != 0) { - auto sum = SDBMSumExpr::get(sumExpr.getLHS(), - SDBMConstantExpr::get(dialect, constant)); - return builder(sum); - } else { - return builder(sumExpr.getLHS()); - } - } - if (constant != 0) - return builder(SDBMSumExpr::get( - expr.cast(), - SDBMConstantExpr::get(dialect, negated ? -constant : constant))); - return expr; -} - -// Construct an expression lhs + constant while maintaining the canonical form -// of the SDBM expressions, in particular sink the constant expression to the -// nearest sum expression in the left subtree of the expression tree. -static SDBMExpr addConstant(SDBMVaryingExpr lhs, int64_t constant) { - if (auto lhsDiff = lhs.dyn_cast()) - return addConstantAndSink( - lhsDiff.getLHS(), constant, /*negated=*/false, - [lhsDiff](SDBMDirectExpr e) { - return SDBMDiffExpr::get(e, lhsDiff.getRHS()); - }); - if (auto lhsNeg = lhs.dyn_cast()) - return addConstantAndSink( - lhsNeg.getVar(), constant, /*negated=*/true, - [](SDBMDirectExpr e) { return SDBMNegExpr::get(e); }); - if (auto lhsSum = lhs.dyn_cast()) - return addConstantAndSink(lhsSum, constant, /*negated=*/false, - [](SDBMDirectExpr e) { return e; }); - if (constant != 0) - return SDBMSumExpr::get(lhs.cast(), - SDBMConstantExpr::get(lhs.getDialect(), constant)); - return lhs; -} - -// Build a difference expression given a direct expression and a negation -// expression. -static SDBMExpr buildDiffExpr(SDBMDirectExpr lhs, SDBMNegExpr rhs) { - // Fold (x + C) - (x + D) = C - D. - if (lhs.getTerm() == rhs.getVar().getTerm()) - return SDBMConstantExpr::get( - lhs.getDialect(), lhs.getConstant() - rhs.getVar().getConstant()); - - return SDBMDiffExpr::get( - addConstantAndSink(lhs, -rhs.getVar().getConstant(), - /*negated=*/false, - [](SDBMDirectExpr e) { return e; }), - rhs.getVar().getTerm()); -} - -// Try folding an expression (lhs + rhs) where at least one of the operands -// contains a negated variable, i.e. is a negation or a difference expression. -static SDBMExpr foldSumDiff(SDBMExpr lhs, SDBMExpr rhs) { - // If exactly one of LHS, RHS is a negation expression, we can construct - // a difference expression, which is a special kind in SDBM. - auto lhsDirect = lhs.dyn_cast(); - auto rhsDirect = rhs.dyn_cast(); - auto lhsNeg = lhs.dyn_cast(); - auto rhsNeg = rhs.dyn_cast(); - - if (lhsDirect && rhsNeg) - return buildDiffExpr(lhsDirect, rhsNeg); - if (lhsNeg && rhsDirect) - return buildDiffExpr(rhsDirect, lhsNeg); - - // If a subexpression appears in a diff expression on the LHS(RHS) of a - // sum expression where it also appears on the RHS(LHS) with the opposite - // sign, we can simplify it away and obtain the SDBM form. - auto lhsDiff = lhs.dyn_cast(); - auto rhsDiff = rhs.dyn_cast(); - - // -(x + A) + ((x + B) - y) = -(y + (A - B)) - if (lhsNeg && rhsDiff && - lhsNeg.getVar().getTerm() == rhsDiff.getLHS().getTerm()) { - int64_t constant = - lhsNeg.getVar().getConstant() - rhsDiff.getLHS().getConstant(); - // RHS of the diff is a term expression, its sum with a constant is a direct - // expression. - return SDBMNegExpr::get( - addConstant(rhsDiff.getRHS(), constant).cast()); - } - - // (x + A) + ((y + B) - x) = (y + B) + A. - if (lhsDirect && rhsDiff && lhsDirect.getTerm() == rhsDiff.getRHS()) - return addConstant(rhsDiff.getLHS(), lhsDirect.getConstant()); - - // ((x + A) - y) + (-(x + B)) = -(y + (B - A)). - if (lhsDiff && rhsNeg && - lhsDiff.getLHS().getTerm() == rhsNeg.getVar().getTerm()) { - int64_t constant = - rhsNeg.getVar().getConstant() - lhsDiff.getLHS().getConstant(); - // RHS of the diff is a term expression, its sum with a constant is a direct - // expression. - return SDBMNegExpr::get( - addConstant(lhsDiff.getRHS(), constant).cast()); - } - - // ((x + A) - y) + (y + B) = (x + A) + B. - if (rhsDirect && lhsDiff && rhsDirect.getTerm() == lhsDiff.getRHS()) - return addConstant(lhsDiff.getLHS(), rhsDirect.getConstant()); - - return {}; -} - -Optional SDBMExpr::tryConvertAffineExpr(AffineExpr affine) { - struct Converter : public AffineExprVisitor { - SDBMExpr visitAddExpr(AffineBinaryOpExpr expr) { - auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS()); - if (!lhs || !rhs) - return {}; - - // In a "add" AffineExpr, the constant always appears on the right. If - // there were two constants, they would have been folded away. - assert(!lhs.isa() && "non-canonical affine expression"); - - // If RHS is a constant, we can always extend the SDBM expression to - // include it by sinking the constant into the nearest sum expression. - if (auto rhsConstant = rhs.dyn_cast()) { - int64_t constant = rhsConstant.getValue(); - auto varying = lhs.dyn_cast(); - assert(varying && "unexpected uncanonicalized sum of constants"); - return addConstant(varying, constant); - } - - // Try building a difference expression if one of the values is negated, - // or check if a difference on either hand side cancels out the outer term - // so as to remain correct within SDBM. Return null otherwise. - return foldSumDiff(lhs, rhs); - } - - SDBMExpr visitMulExpr(AffineBinaryOpExpr expr) { - // Attempt to recover a stripe expression "x # C = (x floordiv C) * C". - AffineExprMatcher x, C; - AffineExprMatcher pattern = (x.floorDiv(C)) * C; - if (pattern.match(expr)) { - if (SDBMExpr converted = visit(x.matched())) { - if (auto varConverted = converted.dyn_cast()) - // TODO: return varConverted.stripe(C.getConstantValue()); - return SDBMStripeExpr::get( - varConverted, - SDBMConstantExpr::get(dialect, - C.getMatchedConstantValue().getValue())); - } - } - - auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS()); - if (!lhs || !rhs) - return {}; - - // In a "mul" AffineExpr, the constant always appears on the right. If - // there were two constants, they would have been folded away. - assert(!lhs.isa() && "non-canonical affine expression"); - auto rhsConstant = rhs.dyn_cast(); - if (!rhsConstant) - return {}; - - // The only supported "multiplication" expression is an SDBM is dimension - // negation, that is a product of dimension and constant -1. - if (rhsConstant.getValue() != -1) - return {}; - - if (auto lhsVar = lhs.dyn_cast()) - return SDBMNegExpr::get(lhsVar); - if (auto lhsDiff = lhs.dyn_cast()) - return SDBMNegator().visitDiff(lhsDiff); - - // Other multiplications are not allowed in SDBM. - return {}; - } - - SDBMExpr visitModExpr(AffineBinaryOpExpr expr) { - auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS()); - if (!lhs || !rhs) - return {}; - - // 'mod' can only be converted to SDBM if its LHS is a direct expression - // and its RHS is a constant. Then it `x mod c = x - x stripe c`. - auto rhsConstant = rhs.dyn_cast(); - auto lhsVar = lhs.dyn_cast(); - if (!lhsVar || !rhsConstant) - return {}; - return SDBMDiffExpr::get(lhsVar, - SDBMStripeExpr::get(lhsVar, rhsConstant)); - } - - // `a floordiv b = (a stripe b) / b`, but we have no division in SDBM - SDBMExpr visitFloorDivExpr(AffineBinaryOpExpr expr) { return {}; } - SDBMExpr visitCeilDivExpr(AffineBinaryOpExpr expr) { return {}; } - - // Dimensions, symbols and constants are converted trivially. - SDBMExpr visitConstantExpr(AffineConstantExpr expr) { - return SDBMConstantExpr::get(dialect, expr.getValue()); - } - SDBMExpr visitDimExpr(AffineDimExpr expr) { - return SDBMDimExpr::get(dialect, expr.getPosition()); - } - SDBMExpr visitSymbolExpr(AffineSymbolExpr expr) { - return SDBMSymbolExpr::get(dialect, expr.getPosition()); - } - - SDBMDialect *dialect; - } converter; - converter.dialect = affine.getContext()->getOrLoadDialect(); - - if (auto result = converter.visit(affine)) - return result; - return None; -} - -//===----------------------------------------------------------------------===// -// SDBMDiffExpr -//===----------------------------------------------------------------------===// - -SDBMDiffExpr SDBMDiffExpr::get(SDBMDirectExpr lhs, SDBMTermExpr rhs) { - assert(lhs && "expected SDBM dimension"); - assert(rhs && "expected SDBM dimension"); - - StorageUniquer &uniquer = lhs.getDialect()->getUniquer(); - return uniquer.get(/*initFn=*/{}, lhs, rhs); -} - -SDBMDirectExpr SDBMDiffExpr::getLHS() const { - return static_cast(impl)->lhs; -} - -SDBMTermExpr SDBMDiffExpr::getRHS() const { - return static_cast(impl)->rhs; -} - -//===----------------------------------------------------------------------===// -// SDBMDirectExpr -//===----------------------------------------------------------------------===// - -SDBMTermExpr SDBMDirectExpr::getTerm() { - if (auto sum = dyn_cast()) - return sum.getLHS(); - return cast(); -} - -int64_t SDBMDirectExpr::getConstant() { - if (auto sum = dyn_cast()) - return sum.getRHS().getValue(); - return 0; -} - -//===----------------------------------------------------------------------===// -// SDBMStripeExpr -//===----------------------------------------------------------------------===// - -SDBMStripeExpr SDBMStripeExpr::get(SDBMDirectExpr var, - SDBMConstantExpr stripeFactor) { - assert(var && "expected SDBM variable expression"); - assert(stripeFactor && "expected non-null stripe factor"); - if (stripeFactor.getValue() <= 0) - llvm::report_fatal_error("non-positive stripe factor"); - - StorageUniquer &uniquer = var.getDialect()->getUniquer(); - return uniquer.get( - /*initFn=*/{}, static_cast(SDBMExprKind::Stripe), var, - stripeFactor); -} - -SDBMDirectExpr SDBMStripeExpr::getLHS() const { - if (SDBMVaryingExpr lhs = static_cast(impl)->lhs) - return lhs.cast(); - return {}; -} - -SDBMConstantExpr SDBMStripeExpr::getStripeFactor() const { - return static_cast(impl)->rhs; -} - -//===----------------------------------------------------------------------===// -// SDBMInputExpr -//===----------------------------------------------------------------------===// - -unsigned SDBMInputExpr::getPosition() const { - return static_cast(impl)->position; -} - -//===----------------------------------------------------------------------===// -// SDBMDimExpr -//===----------------------------------------------------------------------===// - -SDBMDimExpr SDBMDimExpr::get(SDBMDialect *dialect, unsigned position) { - assert(dialect && "expected non-null dialect"); - - auto assignDialect = [dialect](detail::SDBMTermExprStorage *storage) { - storage->dialect = dialect; - }; - - StorageUniquer &uniquer = dialect->getUniquer(); - return uniquer.get( - assignDialect, static_cast(SDBMExprKind::DimId), position); -} - -//===----------------------------------------------------------------------===// -// SDBMSymbolExpr -//===----------------------------------------------------------------------===// - -SDBMSymbolExpr SDBMSymbolExpr::get(SDBMDialect *dialect, unsigned position) { - assert(dialect && "expected non-null dialect"); - - auto assignDialect = [dialect](detail::SDBMTermExprStorage *storage) { - storage->dialect = dialect; - }; - - StorageUniquer &uniquer = dialect->getUniquer(); - return uniquer.get( - assignDialect, static_cast(SDBMExprKind::SymbolId), position); -} - -//===----------------------------------------------------------------------===// -// SDBMConstantExpr -//===----------------------------------------------------------------------===// - -SDBMConstantExpr SDBMConstantExpr::get(SDBMDialect *dialect, int64_t value) { - assert(dialect && "expected non-null dialect"); - - auto assignCtx = [dialect](detail::SDBMConstantExprStorage *storage) { - storage->dialect = dialect; - }; - - StorageUniquer &uniquer = dialect->getUniquer(); - return uniquer.get(assignCtx, value); -} - -int64_t SDBMConstantExpr::getValue() const { - return static_cast(impl)->constant; -} - -//===----------------------------------------------------------------------===// -// SDBMNegExpr -//===----------------------------------------------------------------------===// - -SDBMNegExpr SDBMNegExpr::get(SDBMDirectExpr var) { - assert(var && "expected non-null SDBM direct expression"); - - StorageUniquer &uniquer = var.getDialect()->getUniquer(); - return uniquer.get(/*initFn=*/{}, var); -} - -SDBMDirectExpr SDBMNegExpr::getVar() const { - return static_cast(impl)->expr; -} - -SDBMExpr mlir::ops_assertions::operator+(SDBMExpr lhs, SDBMExpr rhs) { - if (auto folded = foldSumDiff(lhs, rhs)) - return folded; - assert(!(lhs.isa() && rhs.isa()) && - "a sum of negated expressions is a negation of a sum of variables and " - "not a correct SDBM"); - - // Fold (x - y) + (y - x) = 0. - auto lhsDiff = lhs.dyn_cast(); - auto rhsDiff = rhs.dyn_cast(); - if (lhsDiff && rhsDiff) { - if (lhsDiff.getLHS() == rhsDiff.getRHS() && - lhsDiff.getRHS() == rhsDiff.getLHS()) - return SDBMConstantExpr::get(lhs.getDialect(), 0); - } - - // If LHS is a constant and RHS is not, swap the order to get into a supported - // sum case. From now on, RHS must be a constant. - auto lhsConstant = lhs.dyn_cast(); - auto rhsConstant = rhs.dyn_cast(); - if (!rhsConstant && lhsConstant) { - std::swap(lhs, rhs); - std::swap(lhsConstant, rhsConstant); - } - assert(rhsConstant && "at least one operand must be a constant"); - - // Constant-fold if LHS is also a constant. - if (lhsConstant) - return SDBMConstantExpr::get(lhs.getDialect(), lhsConstant.getValue() + - rhsConstant.getValue()); - return addConstant(lhs.cast(), rhsConstant.getValue()); -} - -SDBMExpr mlir::ops_assertions::operator-(SDBMExpr lhs, SDBMExpr rhs) { - // Fold x - x == 0. - if (lhs == rhs) - return SDBMConstantExpr::get(lhs.getDialect(), 0); - - // LHS and RHS may be constants. - auto lhsConstant = lhs.dyn_cast(); - auto rhsConstant = rhs.dyn_cast(); - - // Constant fold if both LHS and RHS are constants. - if (lhsConstant && rhsConstant) - return SDBMConstantExpr::get(lhs.getDialect(), lhsConstant.getValue() - - rhsConstant.getValue()); - - // Replace a difference with a sum with a negated value if one of LHS and RHS - // is a constant: - // x - C == x + (-C); - // C - x == -x + C. - // This calls into operator+ for further simplification. - if (rhsConstant) - return lhs + (-rhsConstant); - if (lhsConstant) - return -rhs + lhsConstant; - - return buildDiffExpr(lhs.cast(), (-rhs).cast()); -} - -SDBMExpr mlir::ops_assertions::stripe(SDBMExpr expr, SDBMExpr factor) { - auto constantFactor = factor.cast(); - assert(constantFactor.getValue() > 0 && "non-positive stripe"); - - // Fold x # 1 = x. - if (constantFactor.getValue() == 1) - return expr; - - return SDBMStripeExpr::get(expr.cast(), constantFactor); -} diff --git a/mlir/lib/Dialect/SDBM/SDBMExprDetail.h b/mlir/lib/Dialect/SDBM/SDBMExprDetail.h deleted file mode 100644 --- a/mlir/lib/Dialect/SDBM/SDBMExprDetail.h +++ /dev/null @@ -1,137 +0,0 @@ -//===- SDBMExprDetail.h - MLIR SDBM Expression storage details --*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This holds implementation details of SDBMExpr, in particular underlying -// storage types. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_IR_SDBMEXPRDETAIL_H -#define MLIR_IR_SDBMEXPRDETAIL_H - -#include "mlir/Dialect/SDBM/SDBMExpr.h" -#include "mlir/Support/StorageUniquer.h" - -namespace mlir { - -class SDBMDialect; - -namespace detail { - -// Base storage class for SDBMExpr. -struct SDBMExprStorage : public StorageUniquer::BaseStorage { - SDBMExprKind getKind() { return kind; } - - SDBMDialect *dialect; - SDBMExprKind kind; -}; - -// Storage class for SDBM sum and stripe expressions. -struct SDBMBinaryExprStorage : public SDBMExprStorage { - using KeyTy = std::tuple; - - bool operator==(const KeyTy &key) const { - return static_cast(std::get<0>(key)) == kind && - std::get<1>(key) == lhs && std::get<2>(key) == rhs; - } - - static SDBMBinaryExprStorage * - construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) { - auto *result = allocator.allocate(); - result->lhs = std::get<1>(key); - result->rhs = std::get<2>(key); - result->dialect = result->lhs.getDialect(); - result->kind = static_cast(std::get<0>(key)); - return result; - } - - SDBMDirectExpr lhs; - SDBMConstantExpr rhs; -}; - -// Storage class for SDBM difference expressions. -struct SDBMDiffExprStorage : public SDBMExprStorage { - using KeyTy = std::pair; - - bool operator==(const KeyTy &key) const { - return std::get<0>(key) == lhs && std::get<1>(key) == rhs; - } - - static SDBMDiffExprStorage * - construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) { - auto *result = allocator.allocate(); - result->lhs = std::get<0>(key); - result->rhs = std::get<1>(key); - result->dialect = result->lhs.getDialect(); - result->kind = SDBMExprKind::Diff; - return result; - } - - SDBMDirectExpr lhs; - SDBMTermExpr rhs; -}; - -// Storage class for SDBM constant expressions. -struct SDBMConstantExprStorage : public SDBMExprStorage { - using KeyTy = int64_t; - - bool operator==(const KeyTy &key) const { return constant == key; } - - static SDBMConstantExprStorage * - construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) { - auto *result = allocator.allocate(); - result->constant = key; - result->kind = SDBMExprKind::Constant; - return result; - } - - int64_t constant; -}; - -// Storage class for SDBM dimension and symbol expressions. -struct SDBMTermExprStorage : public SDBMExprStorage { - using KeyTy = std::pair; - - bool operator==(const KeyTy &key) const { - return kind == static_cast(key.first) && - position == key.second; - } - - static SDBMTermExprStorage * - construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) { - auto *result = allocator.allocate(); - result->kind = static_cast(key.first); - result->position = key.second; - return result; - } - - unsigned position; -}; - -// Storage class for SDBM negation expressions. -struct SDBMNegExprStorage : public SDBMExprStorage { - using KeyTy = SDBMDirectExpr; - - bool operator==(const KeyTy &key) const { return key == expr; } - - static SDBMNegExprStorage * - construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) { - auto *result = allocator.allocate(); - result->expr = key; - result->dialect = key.getDialect(); - result->kind = SDBMExprKind::Neg; - return result; - } - - SDBMDirectExpr expr; -}; - -} // end namespace detail -} // end namespace mlir - -#endif // MLIR_IR_SDBMEXPRDETAIL_H diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt --- a/mlir/test/CMakeLists.txt +++ b/mlir/test/CMakeLists.txt @@ -1,5 +1,4 @@ add_subdirectory(CAPI) -add_subdirectory(SDBM) add_subdirectory(lib) if(MLIR_ENABLE_BINDINGS_PYTHON) @@ -75,7 +74,6 @@ mlir-lsp-server mlir-opt mlir-reduce - mlir-sdbm-api-test mlir-tblgen mlir-translate mlir_runner_utils diff --git a/mlir/test/SDBM/CMakeLists.txt b/mlir/test/SDBM/CMakeLists.txt deleted file mode 100644 --- a/mlir/test/SDBM/CMakeLists.txt +++ /dev/null @@ -1,19 +0,0 @@ -set(LLVM_LINK_COMPONENTS - Core - Support - ) - -add_llvm_executable(mlir-sdbm-api-test - sdbm-api-test.cpp -) - -llvm_update_compile_flags(mlir-sdbm-api-test) - -target_link_libraries(mlir-sdbm-api-test - PRIVATE - MLIRIR - MLIRSDBM - MLIRSupport -) - -target_include_directories(mlir-sdbm-api-test PRIVATE ..) diff --git a/mlir/test/SDBM/lit.local.cfg b/mlir/test/SDBM/lit.local.cfg deleted file mode 100644 --- a/mlir/test/SDBM/lit.local.cfg +++ /dev/null @@ -1 +0,0 @@ -config.suffixes.add('.cpp') diff --git a/mlir/test/SDBM/sdbm-api-test.cpp b/mlir/test/SDBM/sdbm-api-test.cpp deleted file mode 100644 --- a/mlir/test/SDBM/sdbm-api-test.cpp +++ /dev/null @@ -1,201 +0,0 @@ -//===- sdbm-api-test.cpp - Tests for SDBM expression APIs -----------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -// RUN: mlir-sdbm-api-test | FileCheck %s - -#include "mlir/Dialect/SDBM/SDBM.h" -#include "mlir/Dialect/SDBM/SDBMDialect.h" -#include "mlir/Dialect/SDBM/SDBMExpr.h" -#include "mlir/IR/MLIRContext.h" - -#include "llvm/Support/raw_ostream.h" - -#include "APITest.h" - -using namespace mlir; - - -static MLIRContext *ctx() { - static thread_local MLIRContext context; - static thread_local bool once = - (context.getOrLoadDialect(), true); - (void)once; - return &context; -} - -static SDBMDialect *dialect() { - static thread_local SDBMDialect *d = nullptr; - if (!d) { - d = ctx()->getOrLoadDialect(); - } - return d; -} - -static SDBMExpr dim(unsigned pos) { return SDBMDimExpr::get(dialect(), pos); } - -static SDBMExpr symb(unsigned pos) { - return SDBMSymbolExpr::get(dialect(), pos); -} - -namespace { - -using namespace mlir::ops_assertions; - -TEST_FUNC(SDBM_SingleConstraint) { - // Build an SDBM defined by - // d0 - 3 <= 0 <=> d0 <= 3. - auto sdbm = SDBM::get(dim(0) - 3, llvm::None); - - // CHECK: cst d0 - // CHECK-NEXT: cst inf 3 - // CHECK-NEXT: d0 inf inf - sdbm.print(llvm::outs()); -} - -TEST_FUNC(SDBM_Equality) { - // Build an SDBM defined by - // - // d0 - d1 - 3 = 0 - // <=> {d0 - d1 - 3 <= 0 and d0 - d1 - 3 >= 0} - // <=> {d0 - d1 <= 3 and d1 - d0 <= -3}. - auto sdbm = SDBM::get(llvm::None, dim(0) - dim(1) - 3); - - // CHECK: cst d0 d1 - // CHECK-NEXT: cst inf inf inf - // CHECK-NEXT: d0 inf inf -3 - // CHECK-NEXT: d1 inf 3 inf - sdbm.print(llvm::outs()); -} - -TEST_FUNC(SDBM_TrivialSimplification) { - // Build an SDBM defined by - // - // d0 - 3 <= 0 <=> d0 <= 3 - // d0 - 5 <= 0 <=> d0 <= 5 - // - // which should get simplified on construction to only the former. - auto sdbm = SDBM::get({dim(0) - 3, dim(0) - 5}, llvm::None); - - // CHECK: cst d0 - // CHECK-NEXT: cst inf 3 - // CHECK-NEXT: d0 inf inf - sdbm.print(llvm::outs()); -} - -TEST_FUNC(SDBM_StripeInducedIneqs) { - // Build an SDBM defined by d1 = d0 # 3, which induces the constraints - // - // d1 - d0 <= 0 - // d0 - d1 <= 3 - 1 = 2 - auto sdbm = SDBM::get(llvm::None, dim(1) - stripe(dim(0), 3)); - - // CHECK: cst d0 d1 - // CHECK-NEXT: cst inf inf inf - // CHECK-NEXT: d0 inf inf 0 - // CHECK-NEXT: d1 inf 2 0 - // CHECK-NEXT: d1 = d0 # 3 - sdbm.print(llvm::outs()); -} - -TEST_FUNC(SDBM_StripeTemporaries) { - // Build an SDBM defined by d0 # 3 <= 0, which creates a temporary - // t0 = d0 # 3 leading to a constraint t0 <= 0 and the stripe-induced - // constraints - // - // t0 - d0 <= 0 - // d0 - t0 <= 3 - 1 = 2 - auto sdbm = SDBM::get(stripe(dim(0), 3), llvm::None); - - // CHECK: cst d0 t0 - // CHECK-NEXT: cst inf inf 0 - // CHECK-NEXT: d0 inf inf 0 - // CHECK-NEXT: t0 inf 2 inf - // CHECK-NEXT: t0 = d0 # 3 - sdbm.print(llvm::outs()); -} - -TEST_FUNC(SDBM_ElideInducedInequalities) { - // Build an SDBM defined by a single stripe equality d0 = s0 # 3 and make sure - // the induced inequalities are not present after converting the SDBM back - // into lists of expressions. - auto sdbm = SDBM::get(llvm::None, {dim(0) - stripe(symb(0), 3)}); - - SmallVector eqs, ineqs; - sdbm.getSDBMExpressions(dialect(), ineqs, eqs); - // CHECK-EMPTY: - for (auto ineq : ineqs) - ineq.print(llvm::outs() << '\n'); - llvm::outs() << "\n"; - - // CHECK: d0 - s0 # 3 - // CHECK-EMPTY: - for (auto eq : eqs) - eq.print(llvm::outs() << '\n'); - llvm::outs() << "\n\n"; -} - -TEST_FUNC(SDBM_StripeTightening) { - // Build an SDBM defined by - // - // d0 = s0 # 3 # 5 - // s0 # 3 # 5 - d1 + 42 = 0 - // s0 # 3 - d0 <= 2 - // - // where the last inequality is tighter than that induced by the first stripe - // equality (s0 # 3 - d0 <= 5 - 1 = 4). Check that the conversion from SDBM - // back to the lists of constraints conserves both the stripe equality and the - // tighter inequality. - auto s = stripe(stripe(symb(0), 3), 5); - auto tight = stripe(symb(0), 3) - dim(0) - 2; - auto sdbm = SDBM::get({tight}, {s - dim(0), s - dim(1) + 42}); - - SmallVector eqs, ineqs; - sdbm.getSDBMExpressions(dialect(), ineqs, eqs); - // CHECK: s0 # 3 + -2 - d0 - // CHECK-EMPTY: - for (auto ineq : ineqs) - ineq.print(llvm::outs() << '\n'); - llvm::outs() << "\n"; - - // CHECK-DAG: d1 + -42 - d0 - // CHECK-DAG: d0 - s0 # 3 # 5 - for (auto eq : eqs) - eq.print(llvm::outs() << '\n'); - llvm::outs() << "\n\n"; -} - -TEST_FUNC(SDBM_StripeTransitive) { - // Build an SDBM defined by - // - // d0 = d1 # 3 - // d0 = d2 # 7 - // - // where the same dimension is declared equal to two stripe expressions over - // different variables. This is practically handled by introducing a - // temporary variable for the second stripe expression and adding an equality - // constraint between this variable and the original dimension variable. - auto sdbm = SDBM::get( - llvm::None, {stripe(dim(1), 3) - dim(0), stripe(dim(2), 7) - dim(0)}); - - // CHECK: cst d0 d1 d2 t0 - // CHECK-NEXT: cst inf inf inf inf inf - // CHECK-NEXT: d0 inf 0 2 inf 0 - // CHECK-NEXT: d1 inf 0 inf inf inf - // CHECK-NEXT: d2 inf inf inf inf 0 - // CHECK-NEXT: t0 inf 0 inf 6 inf - // CHECK-NEXT: t0 = d2 # 7 - // CHECK-NEXT: d0 = d1 # 3 - sdbm.print(llvm::outs()); -} - -} // end namespace - -int main() { - RUN_TESTS(); - return 0; -} diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py --- a/mlir/test/lit.cfg.py +++ b/mlir/test/lit.cfg.py @@ -65,7 +65,6 @@ 'mlir-linalg-ods-gen', 'mlir-linalg-ods-yaml-gen', 'mlir-reduce', - 'mlir-sdbm-api-test', ] # The following tools are optional diff --git a/mlir/test/mlir-opt/commandline.mlir b/mlir/test/mlir-opt/commandline.mlir --- a/mlir/test/mlir-opt/commandline.mlir +++ b/mlir/test/mlir-opt/commandline.mlir @@ -21,7 +21,6 @@ // CHECK-NEXT: quant // CHECK-NEXT: rocdl // CHECK-NEXT: scf -// CHECK-NEXT: sdbm // CHECK-NEXT: shape // CHECK-NEXT: sparse_tensor // CHECK-NEXT: spv diff --git a/mlir/unittests/CMakeLists.txt b/mlir/unittests/CMakeLists.txt --- a/mlir/unittests/CMakeLists.txt +++ b/mlir/unittests/CMakeLists.txt @@ -11,5 +11,4 @@ add_subdirectory(IR) add_subdirectory(Pass) add_subdirectory(Rewrite) -add_subdirectory(SDBM) add_subdirectory(TableGen) diff --git a/mlir/unittests/SDBM/CMakeLists.txt b/mlir/unittests/SDBM/CMakeLists.txt deleted file mode 100644 --- a/mlir/unittests/SDBM/CMakeLists.txt +++ /dev/null @@ -1,7 +0,0 @@ -add_mlir_unittest(MLIRSDBMTests - SDBMTest.cpp -) -target_link_libraries(MLIRSDBMTests - PRIVATE - MLIRSDBM -) diff --git a/mlir/unittests/SDBM/SDBMTest.cpp b/mlir/unittests/SDBM/SDBMTest.cpp deleted file mode 100644 --- a/mlir/unittests/SDBM/SDBMTest.cpp +++ /dev/null @@ -1,449 +0,0 @@ -//===- SDBMTest.cpp - SDBM expression unit tests --------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/SDBM/SDBM.h" -#include "mlir/Dialect/SDBM/SDBMDialect.h" -#include "mlir/Dialect/SDBM/SDBMExpr.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/MLIRContext.h" -#include "gtest/gtest.h" - -#include "llvm/ADT/DenseSet.h" - -using namespace mlir; - - -static MLIRContext *ctx() { - static thread_local MLIRContext context; - context.getOrLoadDialect(); - return &context; -} - -static SDBMDialect *dialect() { - static thread_local SDBMDialect *d = nullptr; - if (!d) { - d = ctx()->getOrLoadDialect(); - } - return d; -} - -static SDBMExpr dim(unsigned pos) { return SDBMDimExpr::get(dialect(), pos); } - -static SDBMExpr symb(unsigned pos) { - return SDBMSymbolExpr::get(dialect(), pos); -} - -namespace { - -using namespace mlir::ops_assertions; - -TEST(SDBMOperators, Add) { - auto expr = dim(0) + 42; - auto sumExpr = expr.dyn_cast(); - ASSERT_TRUE(sumExpr); - EXPECT_EQ(sumExpr.getLHS(), dim(0)); - EXPECT_EQ(sumExpr.getRHS().getValue(), 42); -} - -TEST(SDBMOperators, AddFolding) { - auto constant = SDBMConstantExpr::get(dialect(), 2) + 42; - auto constantExpr = constant.dyn_cast(); - ASSERT_TRUE(constantExpr); - EXPECT_EQ(constantExpr.getValue(), 44); - - auto expr = (dim(0) + 10) + 32; - auto sumExpr = expr.dyn_cast(); - ASSERT_TRUE(sumExpr); - EXPECT_EQ(sumExpr.getRHS().getValue(), 42); - - expr = dim(0) + SDBMNegExpr::get(SDBMDimExpr::get(dialect(), 1)); - auto diffExpr = expr.dyn_cast(); - ASSERT_TRUE(diffExpr); - EXPECT_EQ(diffExpr.getLHS(), dim(0)); - EXPECT_EQ(diffExpr.getRHS(), dim(1)); - - auto inverted = SDBMNegExpr::get(SDBMDimExpr::get(dialect(), 1)) + dim(0); - EXPECT_EQ(inverted, expr); - - // Check that opposite values cancel each other, and that we elide the zero - // constant. - expr = dim(0) + 42; - auto onlyDim = expr - 42; - EXPECT_EQ(onlyDim, dim(0)); - - // Check that we can sink a constant under a negation. - expr = -(dim(0) + 2); - auto negatedSum = (expr + 10).dyn_cast(); - ASSERT_TRUE(negatedSum); - auto sum = negatedSum.getVar().dyn_cast(); - ASSERT_TRUE(sum); - EXPECT_EQ(sum.getRHS().getValue(), -8); - - // Sum with zero is the same as the original expression. - EXPECT_EQ(dim(0) + 0, dim(0)); - - // Sum of opposite differences is zero. - auto diffOfDiffs = - ((dim(0) - dim(1)) + (dim(1) - dim(0))).dyn_cast(); - EXPECT_EQ(diffOfDiffs.getValue(), 0); -} - -TEST(SDBMOperators, AddNegativeTerms) { - const int64_t A = 7; - const int64_t B = -5; - auto x = SDBMDimExpr::get(dialect(), 0); - auto y = SDBMDimExpr::get(dialect(), 1); - - // Check the simplification patterns in addition where one of the variables is - // cancelled out and the result remains an SDBM. - EXPECT_EQ(-(x + A) + ((x + B) - y), -(y + (A - B))); - EXPECT_EQ((x + A) + ((y + B) - x), (y + B) + A); - EXPECT_EQ(((x + A) - y) + (-(x + B)), -(y + (B - A))); - EXPECT_EQ(((x + A) - y) + (y + B), (x + A) + B); -} - -TEST(SDBMOperators, Diff) { - auto expr = dim(0) - dim(1); - auto diffExpr = expr.dyn_cast(); - ASSERT_TRUE(diffExpr); - EXPECT_EQ(diffExpr.getLHS(), dim(0)); - EXPECT_EQ(diffExpr.getRHS(), dim(1)); -} - -TEST(SDBMOperators, DiffFolding) { - auto constant = SDBMConstantExpr::get(dialect(), 10) - 3; - auto constantExpr = constant.dyn_cast(); - ASSERT_TRUE(constantExpr); - EXPECT_EQ(constantExpr.getValue(), 7); - - auto expr = dim(0) - 3; - auto sumExpr = expr.dyn_cast(); - ASSERT_TRUE(sumExpr); - EXPECT_EQ(sumExpr.getRHS().getValue(), -3); - - auto zero = dim(0) - dim(0); - constantExpr = zero.dyn_cast(); - ASSERT_TRUE(constantExpr); - EXPECT_EQ(constantExpr.getValue(), 0); - - // Check that the constant terms in difference-of-sums are folded. - // (d0 - 3) - (d1 - 5) = (d0 + 2) - d1 - auto diffOfSums = ((dim(0) - 3) - (dim(1) - 5)).dyn_cast(); - ASSERT_TRUE(diffOfSums); - auto lhs = diffOfSums.getLHS().dyn_cast(); - ASSERT_TRUE(lhs); - EXPECT_EQ(lhs.getLHS(), dim(0)); - EXPECT_EQ(lhs.getRHS().getValue(), 2); - EXPECT_EQ(diffOfSums.getRHS(), dim(1)); - - // Check that identical dimensions with opposite signs cancel each other. - auto cstOnly = ((dim(0) + 42) - dim(0)).dyn_cast(); - ASSERT_TRUE(cstOnly); - EXPECT_EQ(cstOnly.getValue(), 42); - - // Check that identical terms in sum of diffs cancel out. - auto dimOnly = (-dim(0) + (dim(0) - dim(1))); - EXPECT_EQ(dimOnly, -dim(1)); - dimOnly = (dim(0) - dim(1)) + (-dim(0)); - EXPECT_EQ(dimOnly, -dim(1)); - dimOnly = (dim(0) - dim(1)) + dim(1); - EXPECT_EQ(dimOnly, dim(0)); - dimOnly = dim(0) + (dim(1) - dim(0)); - EXPECT_EQ(dimOnly, dim(1)); - - // Top-level zero constant is fine. - cstOnly = (-symb(1) + symb(1)).dyn_cast(); - ASSERT_TRUE(cstOnly); - EXPECT_EQ(cstOnly.getValue(), 0); -} - -TEST(SDBMOperators, Negate) { - auto sum = dim(0) + 3; - auto negated = (-sum).dyn_cast(); - ASSERT_TRUE(negated); - EXPECT_EQ(negated.getVar(), sum); -} - -TEST(SDBMOperators, Stripe) { - auto expr = stripe(dim(0), 3); - auto stripeExpr = expr.dyn_cast(); - ASSERT_TRUE(stripeExpr); - EXPECT_EQ(stripeExpr.getLHS(), dim(0)); - EXPECT_EQ(stripeExpr.getStripeFactor().getValue(), 3); -} - -TEST(SDBM, RoundTripEqs) { - // Build an SDBM defined by - // - // d0 = s0 # 3 # 5 - // s0 # 3 # 5 - d1 + 42 = 0 - // - // and perform a double round-trip between the "list of equalities" and SDBM - // representation. After the first round-trip, the equalities may be - // different due to simplification or equivalent substitutions (e.g., the - // second equality may become d0 - d1 + 42 = 0). However, there should not - // be any further simplification after the second round-trip, - - // Build the SDBM from a pair of equalities and extract back the lists of - // inequalities and equalities. Check that all equalities are properly - // detected and none of them decayed into inequalities. - auto s = stripe(stripe(symb(0), 3), 5); - auto sdbm = SDBM::get(llvm::None, {s - dim(0), s - dim(1) + 42}); - SmallVector eqs, ineqs; - sdbm.getSDBMExpressions(dialect(), ineqs, eqs); - ASSERT_TRUE(ineqs.empty()); - - // Do the second round-trip. - auto sdbm2 = SDBM::get(llvm::None, eqs); - SmallVector eqs2, ineqs2; - sdbm2.getSDBMExpressions(dialect(), ineqs2, eqs2); - ASSERT_EQ(eqs.size(), eqs2.size()); - - // Check that the sets of equalities are equal, their order is not relevant. - llvm::DenseSet eqSet, eq2Set; - eqSet.insert(eqs.begin(), eqs.end()); - eq2Set.insert(eqs2.begin(), eqs2.end()); - EXPECT_EQ(eqSet, eq2Set); -} - -TEST(SDBMExpr, Constant) { - // We can create constants and query them. - auto expr = SDBMConstantExpr::get(dialect(), 42); - EXPECT_EQ(expr.getValue(), 42); - - // Two separately created constants with identical values are trivially equal. - auto expr2 = SDBMConstantExpr::get(dialect(), 42); - EXPECT_EQ(expr, expr2); - - // Hierarchy is okay. - auto generic = static_cast(expr); - EXPECT_TRUE(generic.isa()); -} - -TEST(SDBMExpr, Dim) { - // We can create dimension expressions and query them. - auto expr = SDBMDimExpr::get(dialect(), 0); - EXPECT_EQ(expr.getPosition(), 0u); - - // Two separately created dimensions with the same position are trivially - // equal. - auto expr2 = SDBMDimExpr::get(dialect(), 0); - EXPECT_EQ(expr, expr2); - - // Hierarchy is okay. - auto generic = static_cast(expr); - EXPECT_TRUE(generic.isa()); - EXPECT_TRUE(generic.isa()); - EXPECT_TRUE(generic.isa()); - EXPECT_TRUE(generic.isa()); - EXPECT_TRUE(generic.isa()); - - // Dimensions are not Symbols. - auto symbol = SDBMSymbolExpr::get(dialect(), 0); - EXPECT_NE(expr, symbol); - EXPECT_FALSE(expr.isa()); -} - -TEST(SDBMExpr, Symbol) { - // We can create symbol expressions and query them. - auto expr = SDBMSymbolExpr::get(dialect(), 0); - EXPECT_EQ(expr.getPosition(), 0u); - - // Two separately created symbols with the same position are trivially equal. - auto expr2 = SDBMSymbolExpr::get(dialect(), 0); - EXPECT_EQ(expr, expr2); - - // Hierarchy is okay. - auto generic = static_cast(expr); - EXPECT_TRUE(generic.isa()); - EXPECT_TRUE(generic.isa()); - EXPECT_TRUE(generic.isa()); - EXPECT_TRUE(generic.isa()); - EXPECT_TRUE(generic.isa()); - - // Dimensions are not Symbols. - auto symbol = SDBMDimExpr::get(dialect(), 0); - EXPECT_NE(expr, symbol); - EXPECT_FALSE(expr.isa()); -} - -TEST(SDBMExpr, Stripe) { - auto cst2 = SDBMConstantExpr::get(dialect(), 2); - auto cst0 = SDBMConstantExpr::get(dialect(), 0); - auto var = SDBMSymbolExpr::get(dialect(), 0); - - // We can create stripe expressions and query them. - auto expr = SDBMStripeExpr::get(var, cst2); - EXPECT_EQ(expr.getLHS(), var); - EXPECT_EQ(expr.getStripeFactor(), cst2); - - // Two separately created stripe expressions with the same LHS and RHS are - // trivially equal. - auto expr2 = SDBMStripeExpr::get(SDBMSymbolExpr::get(dialect(), 0), cst2); - EXPECT_EQ(expr, expr2); - - // Stripes can be nested. - SDBMStripeExpr::get(expr, SDBMConstantExpr::get(dialect(), 4)); - - // Non-positive stripe factors are not allowed. - EXPECT_DEATH(SDBMStripeExpr::get(var, cst0), "non-positive"); - - // Stripes can have sums on the LHS. - SDBMStripeExpr::get(SDBMSumExpr::get(var, cst2), cst2); - - // Hierarchy is okay. - auto generic = static_cast(expr); - EXPECT_TRUE(generic.isa()); - EXPECT_TRUE(generic.isa()); - EXPECT_TRUE(generic.isa()); - EXPECT_TRUE(generic.isa()); -} - -TEST(SDBMExpr, Neg) { - auto cst2 = SDBMConstantExpr::get(dialect(), 2); - auto var = SDBMSymbolExpr::get(dialect(), 0); - auto stripe = SDBMStripeExpr::get(var, cst2); - - // We can create negation expressions and query them. - auto expr = SDBMNegExpr::get(var); - EXPECT_EQ(expr.getVar(), var); - auto expr2 = SDBMNegExpr::get(stripe); - EXPECT_EQ(expr2.getVar(), stripe); - - // Neg expressions are trivially comparable. - EXPECT_EQ(expr, SDBMNegExpr::get(var)); - - // Hierarchy is okay. - auto generic = static_cast(expr); - EXPECT_TRUE(generic.isa()); - EXPECT_TRUE(generic.isa()); -} - -TEST(SDBMExpr, Sum) { - auto cst2 = SDBMConstantExpr::get(dialect(), 2); - auto var = SDBMSymbolExpr::get(dialect(), 0); - auto stripe = SDBMStripeExpr::get(var, cst2); - - // We can create sum expressions and query them. - auto expr = SDBMSumExpr::get(var, cst2); - EXPECT_EQ(expr.getLHS(), var); - EXPECT_EQ(expr.getRHS(), cst2); - auto expr2 = SDBMSumExpr::get(stripe, cst2); - EXPECT_EQ(expr2.getLHS(), stripe); - EXPECT_EQ(expr2.getRHS(), cst2); - - // Sum expressions are trivially comparable. - EXPECT_EQ(expr, SDBMSumExpr::get(var, cst2)); - - // Hierarchy is okay. - auto generic = static_cast(expr); - EXPECT_TRUE(generic.isa()); - EXPECT_TRUE(generic.isa()); - EXPECT_TRUE(generic.isa()); -} - -TEST(SDBMExpr, Diff) { - auto cst2 = SDBMConstantExpr::get(dialect(), 2); - auto var = SDBMSymbolExpr::get(dialect(), 0); - auto stripe = SDBMStripeExpr::get(var, cst2); - - // We can create sum expressions and query them. - auto expr = SDBMDiffExpr::get(var, stripe); - EXPECT_EQ(expr.getLHS(), var); - EXPECT_EQ(expr.getRHS(), stripe); - auto expr2 = SDBMDiffExpr::get(stripe, var); - EXPECT_EQ(expr2.getLHS(), stripe); - EXPECT_EQ(expr2.getRHS(), var); - - // Sum expressions are trivially comparable. - EXPECT_EQ(expr, SDBMDiffExpr::get(var, stripe)); - - // Hierarchy is okay. - auto generic = static_cast(expr); - EXPECT_TRUE(generic.isa()); - EXPECT_TRUE(generic.isa()); -} - -TEST(SDBMExpr, AffineRoundTrip) { - // Build an expression (s0 - s0 # 2) - auto cst2 = SDBMConstantExpr::get(dialect(), 2); - auto var = SDBMSymbolExpr::get(dialect(), 0); - auto stripe = SDBMStripeExpr::get(var, cst2); - auto expr = SDBMDiffExpr::get(var, stripe); - - // Check that it can be converted to AffineExpr and back, i.e. stripe - // detection works correctly. - Optional roundtripped = - SDBMExpr::tryConvertAffineExpr(expr.getAsAffineExpr()); - ASSERT_TRUE(roundtripped.hasValue()); - EXPECT_EQ(roundtripped, static_cast(expr)); - - // Check that (s0 # 2 # 5) can be converted to AffineExpr, i.e. stripe - // detection supports nested expressions. - auto cst5 = SDBMConstantExpr::get(dialect(), 5); - auto outerStripe = SDBMStripeExpr::get(stripe, cst5); - roundtripped = SDBMExpr::tryConvertAffineExpr(outerStripe.getAsAffineExpr()); - ASSERT_TRUE(roundtripped.hasValue()); - EXPECT_EQ(roundtripped, static_cast(outerStripe)); - - // Check that ((s0 + 2) # 5) can be round-tripped through AffineExpr, i.e. - // stripe detection supports sum expressions. - auto inner = SDBMSumExpr::get(var, cst2); - auto stripeSum = SDBMStripeExpr::get(inner, cst5); - roundtripped = SDBMExpr::tryConvertAffineExpr(stripeSum.getAsAffineExpr()); - ASSERT_TRUE(roundtripped.hasValue()); - EXPECT_EQ(roundtripped, static_cast(stripeSum)); - - // Check that (s0 # 2 # 5 - s0 # 2) + 2 can be converted as an example of a - // deeper expression tree. - auto sum = SDBMSumExpr::get(outerStripe, cst2); - auto diff = SDBMDiffExpr::get(sum, stripe); - roundtripped = SDBMExpr::tryConvertAffineExpr(diff.getAsAffineExpr()); - ASSERT_TRUE(roundtripped.hasValue()); - EXPECT_EQ(roundtripped, static_cast(diff)); - - // Check a nested stripe-sum combination. - auto cst7 = SDBMConstantExpr::get(dialect(), 7); - auto nestedStripe = - SDBMStripeExpr::get(SDBMSumExpr::get(stripeSum, cst2), cst7); - diff = SDBMDiffExpr::get(nestedStripe, stripe); - roundtripped = SDBMExpr::tryConvertAffineExpr(diff.getAsAffineExpr()); - ASSERT_TRUE(roundtripped.hasValue()); - EXPECT_EQ(roundtripped, static_cast(diff)); -} - -TEST(SDBMExpr, MatchStripeMulPattern) { - // Make sure conversion from AffineExpr recognizes multiplicative stripe - // pattern (x floordiv B) * B == x # B. - auto cst = getAffineConstantExpr(42, ctx()); - auto dim = getAffineDimExpr(0, ctx()); - auto floor = dim.floorDiv(cst); - auto mul = cst * floor; - Optional converted = SDBMStripeExpr::tryConvertAffineExpr(mul); - ASSERT_TRUE(converted.hasValue()); - EXPECT_TRUE(converted->isa()); -} - -TEST(SDBMExpr, NonSDBM) { - auto d0 = getAffineDimExpr(0, ctx()); - auto d1 = getAffineDimExpr(1, ctx()); - auto sum = d0 + d1; - auto c2 = getAffineConstantExpr(2, ctx()); - auto prod = d0 * c2; - auto ceildiv = d1.ceilDiv(c2); - - // The following are not valid SDBM expressions: - // - a sum of two variables - EXPECT_FALSE(SDBMExpr::tryConvertAffineExpr(sum).hasValue()); - // - a variable with coefficient other than 1 or -1 - EXPECT_FALSE(SDBMExpr::tryConvertAffineExpr(prod).hasValue()); - // - a ceildiv expression - EXPECT_FALSE(SDBMExpr::tryConvertAffineExpr(ceildiv).hasValue()); -} - -} // end namespace