diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerPolyhedron.h b/mlir/include/mlir/Analysis/Presburger/IntegerPolyhedron.h --- a/mlir/include/mlir/Analysis/Presburger/IntegerPolyhedron.h +++ b/mlir/include/mlir/Analysis/Presburger/IntegerPolyhedron.h @@ -56,6 +56,7 @@ enum class Kind { FlatAffineConstraints, FlatAffineValueConstraints, + MultiAffineFunction, IntegerPolyhedron }; @@ -194,6 +195,11 @@ /// Adds an equality from the coefficients specified in `eq`. void addEquality(ArrayRef eq); + /// Eliminate the `posB^th` local identifier, replacing every instance of it + /// with the `posA^th` local identifier. This should be used when the two + /// local variables are known to always take the same values. + virtual void eliminateRedundantLocalId(unsigned posA, unsigned posB); + /// Removes identifiers of the specified kind with the specified pos (or /// within the specified range) from the system. The specified location is /// relative to the first identifier of the specified kind. @@ -273,6 +279,9 @@ /// Returns true if the given point satisfies the constraints, or false /// otherwise. + /// + /// Note: currently, if the polyhedron contains local ids, the values of + /// the local ids must also be provided. bool containsPoint(ArrayRef point) const; /// Find equality and pairs of inequality contraints identified by their diff --git a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h @@ -0,0 +1,222 @@ +//===- PWMAFunction.h - MLIR PWMAFunction Class------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Support for piece-wise multi-affine functions. These are functions that are
+// defined on a domain that is a union of IntegerPolyhedrons, and on each domain
+// the value of the function is a tuple of integers, with each value in the
+// tuple being an affine expression in the ids of the IntegerPolyhedron.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_PRESBURGER_PWMAFUNCTION_H
+#define MLIR_ANALYSIS_PRESBURGER_PWMAFUNCTION_H
+
+#include "mlir/Analysis/Presburger/IntegerPolyhedron.h"
+#include "mlir/Analysis/Presburger/PresburgerSet.h"
+
+namespace mlir {
+
+/// This class represents a multi-affine function whose domain is given by an
+/// IntegerPolyhedron. This can be thought of as an IntegerPolyhedron with a
+/// tuple of integer values attached to every point in the polyhedron, with the
+/// value of each element of the tuple given by an affine expression in the ids
+/// of the polyhedron. For example, we could have the domain
+///
+///   (x, y) : (x >= 5, y >= x)
+///
+/// and a tuple of three integers defined at every point in the polyhedron:
+///
+///   (x, y) -> (x + 2, 2*x - 3*y + 5, 2*x + y).
+///
+/// In this way every point in the polyhedron has a tuple of integers associated
+/// with it. If the integer polyhedron has local ids, then the output
+/// expressions can use them as well. The output expressions are represented as
+/// a matrix with one row for every element in the output vector, one column
+/// for each id, and an extra column at the end for the constant term.
+///
+/// Checking equality of two such functions is supported, as well as finding the
+/// value of the function at a specified point. Note that local ids in the
+/// domain are not yet supported for finding the value at a point.
+class MultiAffineFunction : protected IntegerPolyhedron {
+public:
+  /// Construct a function defined on `domain` whose outputs are given by the
+  /// rows of `output`.
+  MultiAffineFunction(const IntegerPolyhedron &domain, const Matrix &output)
+      : IntegerPolyhedron(domain), output(output) {}
+  /// Construct a function with the given `output` matrix whose domain has the
+  /// specified number of ids of each kind.
+  MultiAffineFunction(const Matrix &output, unsigned numDims,
+                      unsigned numSymbols = 0, unsigned numLocals = 0)
+      : IntegerPolyhedron(numDims, numSymbols, numLocals), output(output) {}
+
+  ~MultiAffineFunction() override = default;
+  Kind getKind() const override { return Kind::MultiAffineFunction; }
+  /// Support for LLVM-style RTTI. This has to be a static member so that it
+  /// can be invoked as `MultiAffineFunction::classof(poly)` without an
+  /// instance, which is how isa<>/dyn_cast<> call it.
+  static bool classof(const IntegerPolyhedron *poly) {
+    return poly->getKind() == Kind::MultiAffineFunction;
+  }
+
+  // Because of the protected inheritance, the relevant accessors of the
+  // domain are re-exposed explicitly below.
+  unsigned getNumIds() const { return IntegerPolyhedron::getNumIds(); }
+  unsigned getNumDimIds() const { return IntegerPolyhedron::getNumDimIds(); }
+  unsigned getNumSymbolIds() const {
+    return IntegerPolyhedron::getNumSymbolIds();
+  }
+  unsigned getNumDimAndSymbolIds() const {
+    return getNumDimIds() + getNumSymbolIds();
+  }
+  /// The inputs of the function are the dimension and symbol ids.
+  unsigned getNumInputs() const { return getNumDimAndSymbolIds(); }
+  unsigned getNumLocalIds() const {
+    return IntegerPolyhedron::getNumLocalIds();
+  }
+  /// Each row of the output matrix defines one element of the output tuple.
+  unsigned getNumOutputs() const { return output.getNumRows(); }
+  /// Return whether the output matrix has one column per id plus one column
+  /// for the constant term.
+  bool isConsistent() const { return output.getNumColumns() == numIds + 1; }
+  const IntegerPolyhedron &getDomain() const { return *this; }
+
+  /// Return whether `f` has the same number of dimension ids, symbol ids and
+  /// outputs as this function.
+  bool hasCompatibleDimensions(const MultiAffineFunction &f) const {
+    return getNumDimIds() == f.getNumDimIds() &&
+           getNumSymbolIds() == f.getNumSymbolIds() &&
+           getNumOutputs() == f.getNumOutputs();
+  }
+
+  /// Insert `num` identifiers of the specified kind at position `pos`.
+  /// Positions are relative to the kind of identifier. The coefficient columns
+  /// corresponding to the added identifiers are initialized to zero. Return the
+  /// absolute column position (i.e., not relative to the kind of identifier)
+  /// of the first added identifier.
+  unsigned insertId(IdKind kind, unsigned pos, unsigned num = 1) override {
+    // Keep the columns of the output matrix in sync with the domain's ids.
+    unsigned absolutePos = getIdKindOffset(kind) + pos;
+    output.insertColumns(absolutePos, num);
+    return IntegerPolyhedron::insertId(kind, pos, num);
+  }
+
+  /// Swap the posA^th identifier with the posB^th identifier.
+  void swapId(unsigned posA, unsigned posB) override {
+    output.swapColumns(posA, posB);
+    IntegerPolyhedron::swapId(posA, posB);
+  }
+
+  /// Eliminate the `posB^th` local identifier, replacing every instance of it
+  /// with the `posA^th` local identifier. This should be used when the two
+  /// local variables are known to always take the same values.
+  void eliminateRedundantLocalId(unsigned posA, unsigned posB) override {
+    // `posA` and `posB` are relative to the first local id, whereas the
+    // columns of `output` are indexed by absolute id position (as in
+    // insertId and swapId), so the positions must be offset by the number of
+    // dimension and symbol ids before touching the output matrix.
+    unsigned localOffset = getNumDimAndSymbolIds();
+    output.addToColumn(localOffset + posB, localOffset + posA, /*scale=*/1);
+    output.removeColumn(localOffset + posB);
+    IntegerPolyhedron::eliminateRedundantLocalId(posA, posB);
+  }
+
+  /// Return whether the outputs of the functions agree wherever both functions
+  /// are defined, i.e., the outputs should be equal for all points in the
+  /// intersection of the domains.
+  bool isEqualWhereDomainsOverlap(MultiAffineFunction fB) const;
+
+  /// Return whether the two MultiAffineFunctions are equal. This is the case if
+  /// they lie in the same space, i.e. have the same dimensions, and their
+  /// domains are identical and their outputs are equal on their domain.
+  bool isEqual(const MultiAffineFunction &fB) const {
+    return hasCompatibleDimensions(fB) && getDomain().isEqual(fB.getDomain()) &&
+           isEqualWhereDomainsOverlap(fB);
+  }
+
+  /// Get the value of the function at the specified point. If the point lies
+  /// outside the domain, an empty optional is returned.
+  ///
+  /// Note: domains with local ids are not yet supported, and will assert-fail.
+  Optional<SmallVector<int64_t, 8>> valueAt(ArrayRef<int64_t> point) const;
+
+  /// Print the domain followed by the output matrix to `os`.
+  void print(raw_ostream &os) const {
+    os << "Domain:";
+    IntegerPolyhedron::print(os);
+    os << "Output:\n";
+    output.print(os);
+    os << "\n";
+  }
+
+  void dump() const { print(llvm::errs()); }
+
+private:
+  /// The function's output is a tuple of integers, with the ith element of the
+  /// tuple defined by the affine expression given by the ith row of this output
+  /// matrix.
+  Matrix output;
+};
+
+/// This class represents a piece-wise MultiAffineFunction. This can be thought
+/// of as a list of MultiAffineFunction with disjoint domains, with each having
+/// their own affine expressions for their output tuples. For example, we could
+/// have a function with two input variables (x, y), defined as
+///
+/// f(x, y) = (2*x + y, y - 4)  if x >= 0, y >= 0
+///         = (-2*x + y, y + 4) if x < 0,  y < 0
+///         = (4, 1)            if x < 0,  y >= 0
+///
+/// Note that the domains all have to be *disjoint*. Otherwise, the behaviour of
+/// this class is undefined. The domains need not cover all possible points;
+/// this represents a partial function and so could be undefined at some points.
+///
+/// As in PresburgerSets, the input ids are partitioned into dimension ids and
+/// symbolic ids.
+///
+/// Support is provided to compare equality of two such functions as well as
+/// finding the value of the function at a point. Note that local ids in the
+/// piece are not supported for the latter.
+class PWMAFunction {
+public:
+  /// Construct a function with no pieces, i.e. one that is undefined
+  /// everywhere, taking `numDims` dimension ids and `numSymbols` symbol ids as
+  /// input and producing `numOutputs` output values.
+  PWMAFunction(unsigned numDims, unsigned numSymbols, unsigned numOutputs)
+      : numDims(numDims), numSymbols(numSymbols), numOutputs(numOutputs) {
+    assert(numOutputs >= 1 && "The function must output something!");
+  }
+
+  /// Append a piece to the function. The piece must agree with this function
+  /// on the number of inputs and outputs; see assertPieceIsConsistent.
+  void addPiece(const MultiAffineFunction &f);
+  void addPiece(const IntegerPolyhedron &domain, const Matrix &output);
+
+  const MultiAffineFunction &getPiece(unsigned i) const { return pieces[i]; }
+  MultiAffineFunction &getPiece(unsigned i) { return pieces[i]; }
+  unsigned getNumPieces() const { return pieces.size(); }
+  unsigned getNumOutputs() const { return numOutputs; }
+  unsigned getNumInputs() const { return numDims + numSymbols; }
+  unsigned getNumDimIds() const { return numDims; }
+  unsigned getNumSymbolIds() const { return numSymbols; }
+
+  // NOTE(review): no definition of assertOutputMatrixConsistent is visible in
+  // this patch; confirm it is defined (or drop the declaration).
+  void assertOutputMatrixConsistent(const Matrix &output) const;
+
+  /// Return the union of the domains of all the pieces.
+  PresburgerSet getDomain() const;
+
+  /// Assert that `piece` has the same number of inputs and outputs as this
+  /// function and that its output matrix matches its number of ids.
+  void assertPieceIsConsistent(const MultiAffineFunction &piece) const {
+    assert(piece.getNumInputs() == getNumInputs() &&
+           "Piece has a different number of inputs than the function!");
+    assert(piece.getNumOutputs() == getNumOutputs() &&
+           "Piece has a different number of outputs than the function!");
+    assert(piece.isConsistent() &&
+           "Piece's output matrix is inconsistent with its number of ids!");
+  }
+
+  /// Return whether `f` has the same number of dimension ids, symbol ids and
+  /// outputs as this function.
+  bool hasCompatibleDimensions(const PWMAFunction &f) const {
+    return getNumDimIds() == f.getNumDimIds() &&
+           getNumSymbolIds() == f.getNumSymbolIds() &&
+           getNumOutputs() == f.getNumOutputs();
+  }
+
+  /// Return the value at the specified point and an empty optional if the
+  /// point does not lie in the domain.
+  ///
+  /// Note: domains with local ids are not yet supported, and will assert-fail.
+  Optional<SmallVector<int64_t, 8>> valueAt(ArrayRef<int64_t> point) const;
+
+  /// Return whether this and fB are equal as PWMAFunctions, i.e. whether they
+  /// have the same dimensions, the same domain and they take the same value at
+  /// every point in the domain.
+  bool isEqual(const PWMAFunction &fB) const;
+
+  void print(raw_ostream &os) const;
+  void dump() const;
+
+private:
+  /// The list of pieces in this piece-wise MultiAffineFunction.
+  SmallVector<MultiAffineFunction, 4> pieces;
+
+  /// The number of dimension ids in the domains.
+  unsigned numDims;
+  /// The number of symbol ids in the domains.
+  unsigned numSymbols;
+  /// The number of output ids.
+  unsigned numOutputs;
+};
+
+} // namespace mlir
+
+#endif // MLIR_ANALYSIS_PRESBURGER_PWMAFUNCTION_H
diff --git a/mlir/lib/Analysis/Presburger/CMakeLists.txt b/mlir/lib/Analysis/Presburger/CMakeLists.txt
--- a/mlir/lib/Analysis/Presburger/CMakeLists.txt
+++ b/mlir/lib/Analysis/Presburger/CMakeLists.txt
@@ -3,6 +3,7 @@
   LinearTransform.cpp
   Matrix.cpp
   PresburgerSet.cpp
+  PWMAFunction.cpp
   Simplex.cpp
   Utils.cpp
diff --git a/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp b/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp
--- a/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp
@@ -1057,24 +1057,17 @@
   equalities.resizeVertically(pos);
 }
 
-/// Eliminate `pos2^th` local identifier, replacing its every instance with
-/// `pos1^th` local identifier. This function is intended to be used to remove
-/// redundancy when local variables at position `pos1` and `pos2` are restricted
-/// to have the same value.
-static void eliminateRedundantLocalId(IntegerPolyhedron &poly, unsigned pos1,
-                                      unsigned pos2) {
-
-  assert(pos1 < poly.getNumLocalIds() && "Invalid local id position");
-  assert(pos2 < poly.getNumLocalIds() && "Invalid local id position");
-
-  unsigned localOffset = poly.getNumDimAndSymbolIds();
-  pos1 += localOffset;
-  pos2 += localOffset;
-  for (unsigned i = 0, e = poly.getNumInequalities(); i < e; ++i)
-    poly.atIneq(i, pos1) += poly.atIneq(i, pos2);
-  for (unsigned i = 0, e = poly.getNumEqualities(); i < e; ++i)
-    poly.atEq(i, pos1) += poly.atEq(i, pos2);
-  poly.removeId(pos2);
+void IntegerPolyhedron::eliminateRedundantLocalId(unsigned posA,
+                                                  unsigned posB) {
+  assert(posA < getNumLocalIds() && "Invalid local id position");
+  assert(posB < getNumLocalIds() && "Invalid local id position");
+  // Local id columns come after the dimension and symbol id columns.
+  const unsigned localOffset = getNumDimAndSymbolIds();
+  const unsigned keptCol = localOffset + posA;
+  const unsigned droppedCol = localOffset + posB;
+  inequalities.addToColumn(droppedCol, keptCol, /*scale=*/1);
+  equalities.addToColumn(droppedCol, keptCol, /*scale=*/1);
+  removeId(droppedCol);
 }
 
 /// Adds additional local ids to the sets such that they both have the union
@@ -1121,8 +1114,8 @@
   // Merge function that merges the local variables in both sets by treating
   // them as the same identifier.
   auto merge = [&polyA, &polyB](unsigned i, unsigned j) -> bool {
-    eliminateRedundantLocalId(polyA, i, j);
-    eliminateRedundantLocalId(polyB, i, j);
+    polyA.eliminateRedundantLocalId(i, j);
+    polyB.eliminateRedundantLocalId(i, j);
     return true;
   };
 
diff --git a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
@@ -0,0 +1,126 @@
+//===- PWMAFunction.cpp - MLIR PWMAFunction Class -------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Presburger/PWMAFunction.h"
+#include "mlir/Analysis/Presburger/Simplex.h"
+
+using namespace mlir;
+
+/// Return the element-wise difference `vecA - vecB`. Both vectors must have
+/// the same length. This helper is local to this file, so it gets internal
+/// linkage to avoid clashing with other global-namespace symbols.
+static SmallVector<int64_t, 8> subtract(ArrayRef<int64_t> vecA,
+                                        ArrayRef<int64_t> vecB) {
+  assert(vecA.size() == vecB.size() &&
+         "Cannot subtract vectors of differing lengths!");
+  SmallVector<int64_t, 8> result;
+  result.reserve(vecA.size());
+  for (unsigned i = 0, e = vecA.size(); i < e; ++i)
+    result.push_back(vecA[i] - vecB[i]);
+  return result;
+}
+
+/// The domain of the function is the union of the domains of its pieces.
+PresburgerSet PWMAFunction::getDomain() const {
+  PresburgerSet domain =
+      PresburgerSet::getEmptySet(getNumDimIds(), getNumSymbolIds());
+  for (const MultiAffineFunction &piece : pieces)
+    domain.unionPolyInPlace(piece.getDomain());
+  return domain;
+}
+
+Optional<SmallVector<int64_t, 8>>
+MultiAffineFunction::valueAt(ArrayRef<int64_t> point) const {
+  assert(getNumLocalIds() == 0 && "Local ids are not yet supported!");
+  assert(point.size() == getNumIds() &&
+         "Point has incorrect dimensionality!");
+
+  if (!getDomain().containsPoint(point))
+    return {};
+
+  // The point lies in the domain, so we need to compute the output value.
+  // The matrix `output` has an affine expression in the ith row, corresponding
+  // to the expression for the ith value in the output vector. The last column
+  // of the matrix contains the constant term. Let v be the input point with
+  // a 1 appended at the end. We can see that output * v gives the desired
+  // output vector.
+  SmallVector<int64_t, 8> pointHomogeneous{llvm::to_vector(point)};
+  pointHomogeneous.push_back(1);
+  SmallVector<int64_t, 8> result =
+      output.postMultiplyWithColumn(pointHomogeneous);
+  assert(result.size() == getNumOutputs() &&
+         "Output vector has an unexpected size!");
+  return result;
+}
+
+Optional<SmallVector<int64_t, 8>>
+PWMAFunction::valueAt(ArrayRef<int64_t> point) const {
+  assert(point.size() == getNumInputs() &&
+         "Point has incorrect dimensionality!");
+  // Return the value given by the first piece whose domain contains the point.
+  for (const MultiAffineFunction &piece : pieces)
+    if (Optional<SmallVector<int64_t, 8>> output = piece.valueAt(point))
+      return output;
+  return {};
+}
+
+bool MultiAffineFunction::isEqualWhereDomainsOverlap(
+    MultiAffineFunction fB) const {
+  if (!hasCompatibleDimensions(fB))
+    return false;
+
+  // commonFunc's output will be the output of this.
+  MultiAffineFunction commonFunc = *this;
+  // After this merge, commonFunc and fB have the same local ids; they are
+  // merged.
+  commonFunc.mergeLocalIds(fB);
+  // After this, commonFunc's domain will be the intersection of the domains of
+  // this and fB.
+  commonFunc.IntegerPolyhedron::append(fB);
+
+  // commonDomainMatching contains the subset of the common domain
+  // where the outputs of this and fB match.
+  IntegerPolyhedron commonDomainMatching = commonFunc.getDomain();
+
+  // We can't use this->output below because its locals aren't merged with fB's.
+  for (unsigned row = 0, e = getNumOutputs(); row < e; ++row)
+    commonDomainMatching.addEquality(
+        subtract(commonFunc.output.getRow(row), fB.output.getRow(row)));
+
+  // If the whole common domain is a subset of commonDomainMatching, then the
+  // two sets are equal and the two functions match on the whole common domain.
+  return PresburgerSet(commonFunc.getDomain())
+      .isSubsetOf(PresburgerSet(commonDomainMatching));
+}
+
+/// Two PWMAFunctions are equal if they have the same dimensionalities,
+/// the same domain, and take the same value at every point in the domain.
+bool PWMAFunction::isEqual(const PWMAFunction &fB) const {
+  if (!hasCompatibleDimensions(fB))
+    return false;
+
+  // Every piece of this function has to agree with every piece of fB wherever
+  // their domains intersect. Together with the equality of the overall
+  // domains, which is checked at the end, this guarantees that the two
+  // functions take the same value at every point of their domain.
+  for (const MultiAffineFunction &pieceOfThis : pieces) {
+    for (const MultiAffineFunction &pieceOfB : fB.pieces) {
+      if (!pieceOfThis.isEqualWhereDomainsOverlap(pieceOfB))
+        return false;
+    }
+  }
+
+  return getDomain().isEqual(fB.getDomain());
+}
+
+void PWMAFunction::addPiece(const MultiAffineFunction &f) {
+  // Reject pieces whose number of inputs or outputs doesn't match ours.
+  assertPieceIsConsistent(f);
+  pieces.push_back(f);
+}
+
+void PWMAFunction::addPiece(const IntegerPolyhedron &domain,
+                            const Matrix &output) {
+  MultiAffineFunction piece(domain, output);
+  addPiece(piece);
+}
+
+void PWMAFunction::print(raw_ostream &os) const {
+  // Print the pieces one after the other, in the order they were added.
+  for (unsigned i = 0, e = getNumPieces(); i < e; ++i)
+    getPiece(i).print(os);
+}
+
+void PWMAFunction::dump() const { print(llvm::errs()); }
diff --git a/mlir/unittests/Analysis/Presburger/CMakeLists.txt b/mlir/unittests/Analysis/Presburger/CMakeLists.txt
--- a/mlir/unittests/Analysis/Presburger/CMakeLists.txt
+++ b/mlir/unittests/Analysis/Presburger/CMakeLists.txt
@@ -3,6 +3,7 @@
   LinearTransformTest.cpp
   MatrixTest.cpp
   PresburgerSetTest.cpp
+  PWMAFunctionTest.cpp
   SimplexTest.cpp
   ../../Dialect/Affine/Analysis/AffineStructuresParser.cpp
 )
diff --git a/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp b/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp
@@ -0,0 +1,180 @@
+//===- PWMAFunctionTest.cpp - Tests for PWMAFunction ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains tests for PWMAFunction. The tests for isEqual check that
+// the result for two functions matches the expected result. The tests for
+// valueAt check that the function evaluates to the expected output tuple at a
+// few points, and that it is undefined at a point outside the domain.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Presburger/PWMAFunction.h"
+#include "../../Dialect/Affine/Analysis/AffineStructuresParser.h"
+#include "mlir/Analysis/Presburger/PresburgerSet.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace mlir {
+/// Parses an IntegerPolyhedron from a StringRef. It is expected that the
+/// string represents a valid IntegerSet, otherwise it will violate a gtest
+/// assertion.
+static IntegerPolyhedron parsePoly(StringRef str, MLIRContext *context) {
+  FailureOr<IntegerPolyhedron> poly = parseIntegerSetToFAC(str, context);
+  EXPECT_TRUE(succeeded(poly));
+  return *poly;
+}
+
+/// Builds a `numRow` x `numColumns` Matrix from the given rows. The number of
+/// rows and the length of each row must match the specified dimensions.
+static Matrix readMatrix(unsigned numRow, unsigned numColumns,
+                         ArrayRef<SmallVector<int64_t, 8>> matrix) {
+  Matrix results(numRow, numColumns);
+  assert(matrix.size() == numRow);
+  for (unsigned i = 0; i < numRow; ++i) {
+    assert(matrix[i].size() == numColumns);
+    for (unsigned j = 0; j < numColumns; ++j)
+      results(i, j) = matrix[i][j];
+  }
+  return results;
+}
+
+/// Builds a PWMAFunction from a list of (domain, output matrix) pairs. Each
+/// domain is given as an IntegerSet string and each output matrix has
+/// `numOutputs` rows and one column per id of the parsed domain, plus one
+/// column for the constant term. `numInputs` counts both dimension and symbol
+/// ids.
+static PWMAFunction parsePWAF(
+    unsigned numInputs, unsigned numOutputs,
+    ArrayRef<std::pair<StringRef, SmallVector<SmallVector<int64_t, 8>, 8>>>
+        data,
+    unsigned numSymbols = 0) {
+  // A single context is shared by all the calls to this function.
+  static MLIRContext context;
+
+  PWMAFunction result(numInputs - numSymbols, numSymbols, numOutputs);
+  for (const auto &pair : data) {
+    IntegerPolyhedron domain = parsePoly(pair.first, &context);
+    result.addPiece(
+        domain, readMatrix(numOutputs, domain.getNumIds() + 1, pair.second));
+  }
+  return result;
+}
+
+/// Parse a list of
StringRefs to IntegerPolyhedron and combine them into a
+TEST(PWAFunctionTest, isEqual) {
+  // NOTE(review): this suite is named PWAFunctionTest while the valueAt test
+  // below uses PWMAFunction; consider unifying the two names.
+
+  // The output expressions of the two functions differ, but they take the same
+  // values on the domain (where x == 0 or y == 0), so the functions are equal.
+  // NOTE(review): several functions in this test have pieces whose domains
+  // overlap, e.g. both pieces here contain (0, 0), although PWMAFunction
+  // documents that the domains must be disjoint. The outputs agree on the
+  // overlaps; confirm this is intended.
+  PWMAFunction idAtZeros =
+      parsePWAF(/*numInputs=*/2, /*numOutputs=*/2,
+                {
+                    {"(x, y) : (y == 0)", {{1, 0, 0}, {0, 1, 0}}}, // (x, y).
+                    {"(x, y) : (x == 0)", {{1, 0, 0}, {0, 1, 0}}}  // (x, y).
+                });
+  PWMAFunction idAtZeros2 =
+      parsePWAF(/*numInputs=*/2, /*numOutputs=*/2,
+                {
+                    {"(x, y) : (y == 0)", {{1, 0, 0}, {0, 20, 0}}}, // (x, 20y).
+                    {"(x, y) : (x == 0)", {{30, 0, 0}, {0, 1, 0}}}  // (30x, y).
+                });
+  EXPECT_TRUE(idAtZeros.isEqual(idAtZeros2));
+
+  PWMAFunction notIdAtZeros =
+      parsePWAF(/*numInputs=*/2, /*numOutputs=*/2,
+                {
+                    {"(x, y) : (y == 0)", {{1, 0, 0}, {0, 1, 0}}}, // (x, y).
+                    {"(x, y) : (x == 0)", {{1, 0, 0}, {0, 2, 0}}}  // (x, 2y).
+                });
+  EXPECT_FALSE(idAtZeros.isEqual(notIdAtZeros));
+
+  // These match at their intersection but one has a bigger domain.
+  PWMAFunction idNoNegNegQuadrant =
+      parsePWAF(/*numInputs=*/2, /*numOutputs=*/2,
+                {
+                    {"(x, y) : (x >= 0)", {{1, 0, 0}, {0, 1, 0}}}, // (x, y).
+                    {"(x, y) : (y >= 0)", {{1, 0, 0}, {0, 1, 0}}}  // (x, y).
+                });
+  PWMAFunction idOnlyPosX =
+      parsePWAF(/*numInputs=*/2, /*numOutputs=*/2,
+                {
+                    {"(x, y) : (x >= 0)", {{1, 0, 0}, {0, 1, 0}}}, // (x, y).
+                });
+  EXPECT_FALSE(idNoNegNegQuadrant.isEqual(idOnlyPosX));
+
+  // Different representations of the same domain.
+  PWMAFunction sumPlusOne = parsePWAF(
+      /*numInputs=*/2, /*numOutputs=*/1,
+      {
+          {"(x, y) : (x >= 0)", {{1, 1, 1}}},                   // x + y + 1.
+          {"(x, y) : (-x - 1 >= 0, -y - 1 >= 0)", {{1, 1, 1}}}, // x + y + 1.
+          {"(x, y) : (y >= 0)", {{1, 1, 1}}}                    // x + y + 1.
+      });
+  PWMAFunction sumPlusOne2 =
+      parsePWAF(/*numInputs=*/2, /*numOutputs=*/1,
+                {
+                    {"(x, y) : ()", {{1, 1, 1}}}, // x + y + 1.
+                });
+  EXPECT_TRUE(sumPlusOne.isEqual(sumPlusOne2));
+
+  // Functions with zero input dimensions.
+  PWMAFunction noInputs1 = parsePWAF(/*numInputs=*/0, /*numOutputs=*/1,
+                                     {
+                                         {"() : ()", {{1}}}, // 1.
+                                     });
+  PWMAFunction noInputs2 = parsePWAF(/*numInputs=*/0, /*numOutputs=*/1,
+                                     {
+                                         {"() : ()", {{2}}}, // 2.
+                                     });
+  EXPECT_TRUE(noInputs1.isEqual(noInputs1));
+  EXPECT_FALSE(noInputs1.isEqual(noInputs2));
+
+  // Mismatched dimensionalities.
+  EXPECT_FALSE(noInputs1.isEqual(sumPlusOne));
+  EXPECT_FALSE(idOnlyPosX.isEqual(sumPlusOne));
+
+  // Divisions.
+  // Domain is only multiples of 6; x = 6k for some k.
+  // x + 4(x/2) + 4(x/3) == 26k.
+  PWMAFunction mul2AndMul3 = parsePWAF(
+      /*numInputs=*/1, /*numOutputs=*/1,
+      {
+          {"(x) : (x - 2*(x floordiv 2) == 0, x - 3*(x floordiv 3) == 0)",
+           {{1, 4, 4, 0}}}, // x + 4(x/2) + 4(x/3).
+      });
+  PWMAFunction mul6 = parsePWAF(
+      /*numInputs=*/1, /*numOutputs=*/1,
+      {
+          {"(x) : (x - 6*(x floordiv 6) == 0)", {{0, 26, 0}}}, // 26(x/6).
+      });
+  EXPECT_TRUE(mul2AndMul3.isEqual(mul6));
+
+  // Both the domain (multiples of 5) and the output differ from mul2AndMul3.
+  PWMAFunction mul6diff = parsePWAF(
+      /*numInputs=*/1, /*numOutputs=*/1,
+      {
+          {"(x) : (x - 5*(x floordiv 5) == 0)", {{0, 52, 0}}}, // 52(x/5).
+      });
+  EXPECT_FALSE(mul2AndMul3.isEqual(mul6diff));
+
+  PWMAFunction mul5 = parsePWAF(
+      /*numInputs=*/1, /*numOutputs=*/1,
+      {
+          {"(x) : (x - 5*(x floordiv 5) == 0)", {{0, 26, 0}}}, // 26(x/5).
+      });
+  EXPECT_FALSE(mul2AndMul3.isEqual(mul5));
+}
+
+using testing::ElementsAre;
+
+TEST(PWMAFunction, valueAt) {
+  PWMAFunction nonNegPWAF = parsePWAF(
+      /*numInputs=*/2, /*numOutputs=*/2,
+      {
+          // (x + 2y + 3, 3x + 4y + 5).
+          {"(x, y) : (x >= 0)", {{1, 2, 3}, {3, 4, 5}}},
+          // (-x + 2y + 3, -3x + 4y + 5).
+          {"(x, y) : (y >= 0, -x - 1 >= 0)", {{-1, 2, 3}, {-3, 4, 5}}},
+      });
+  EXPECT_THAT(*nonNegPWAF.valueAt({2, 3}), ElementsAre(11, 23));
+  EXPECT_THAT(*nonNegPWAF.valueAt({-2, 3}), ElementsAre(11, 23));
+  EXPECT_THAT(*nonNegPWAF.valueAt({2, -3}), ElementsAre(-1, -1));
+  // The point (-2, -3) lies in neither piece's domain.
+  EXPECT_FALSE(nonNegPWAF.valueAt({-2, -3}).hasValue());
+}
+
+} // namespace mlir