diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerPolyhedron.h b/mlir/include/mlir/Analysis/Presburger/IntegerPolyhedron.h --- a/mlir/include/mlir/Analysis/Presburger/IntegerPolyhedron.h +++ b/mlir/include/mlir/Analysis/Presburger/IntegerPolyhedron.h @@ -56,6 +56,7 @@ enum class Kind { FlatAffineConstraints, FlatAffineValueConstraints, + MultiAffineFunction, IntegerPolyhedron }; @@ -194,6 +195,11 @@ /// Adds an equality from the coefficients specified in `eq`. void addEquality(ArrayRef eq); + /// Eliminate the `posB^th` local identifier, replacing every instance of it + /// with the `posA^th` local identifier. This should be used when the two + /// local variables are known to always take the same values. + virtual void eliminateRedundantLocalId(unsigned posA, unsigned posB); + /// Removes identifiers of the specified kind with the specified pos (or /// within the specified range) from the system. The specified location is /// relative to the first identifier of the specified kind. @@ -273,6 +279,9 @@ /// Returns true if the given point satisfies the constraints, or false /// otherwise. + /// + /// Note: currently, if the polyhedron contains local ids, the values of + /// the local ids must also be provided. bool containsPoint(ArrayRef point) const; /// Find equality and pairs of inequality contraints identified by their diff --git a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h @@ -0,0 +1,193 @@ +//===- PWMAFunction.h - MLIR PWMAFunction Class------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Support for piece-wise multi-affine functions. These are functions that are +// defined on a domain that is a union of IntegerPolyhedrons, and on each domain +// the value of the function is a tuple of integers, with each value in the +// tuple being an affine expression in the ids of the IntegerPolyhedron. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_PRESBURGER_PWMAFUNCTION_H +#define MLIR_ANALYSIS_PRESBURGER_PWMAFUNCTION_H + +#include "mlir/Analysis/Presburger/IntegerPolyhedron.h" +#include "mlir/Analysis/Presburger/PresburgerSet.h" + +namespace mlir { + +/// This class represents a multi-affine function whose domain is given by an +/// IntegerPolyhedron. This can be thought of as an IntegerPolyhedron with a +/// tuple of integer values attached to every point in the polyhedron, with the +/// value of each element of the tuple given by an affine expression in the ids +/// of the polyhedron. For example we could have the domain +/// +/// (x, y) : (x >= 5, y >= x) +/// +/// and a tuple of three integers defined at every point in the polyhedron: +/// +/// (x, y) -> (x + 2, 2*x - 3y + 5, 2*x + y). +/// +/// In this way every point in the polyhedron has a tuple of integers associated +/// with it. If the integer polyhedron has local ids, then the output +/// expressions can use them as well. The output expressions are represented as +/// a matrix with one row for every element in the output vector, one column for +/// each id, and an extra column at the end for the constant term. +/// +/// Checking equality of two such functions is supported, as well as finding the +/// value of the function at a specified point. Note that local ids in the +/// domain are not yet supported for finding the value at a point.
+class MultiAffineFunction : protected IntegerPolyhedron { +public: + /// We use protected inheritance to avoid inheriting the whole public + /// interface of IntegerPolyhedron. These using declarations explicitly make + /// only the relevant functions part of the public interface. + using IntegerPolyhedron::getNumDimAndSymbolIds; + using IntegerPolyhedron::getNumDimIds; + using IntegerPolyhedron::getNumIds; + using IntegerPolyhedron::getNumLocalIds; + using IntegerPolyhedron::getNumSymbolIds; + + MultiAffineFunction(const IntegerPolyhedron &domain, const Matrix &output) + : IntegerPolyhedron(domain), output(output) {} + MultiAffineFunction(const Matrix &output, unsigned numDims, + unsigned numSymbols = 0, unsigned numLocals = 0) + : IntegerPolyhedron(numDims, numSymbols, numLocals), output(output) {} + + ~MultiAffineFunction() override = default; + Kind getKind() const override { return Kind::MultiAffineFunction; } + static bool classof(const IntegerPolyhedron *poly) { + return poly->getKind() == Kind::MultiAffineFunction; + } + + unsigned getNumInputs() const { return getNumDimAndSymbolIds(); } + unsigned getNumOutputs() const { return output.getNumRows(); } + bool isConsistent() const { return output.getNumColumns() == numIds + 1; } + const IntegerPolyhedron &getDomain() const { return *this; } + + bool hasCompatibleDimensions(const MultiAffineFunction &f) const; + + /// Insert `num` identifiers of the specified kind at position `pos`. + /// Positions are relative to the kind of identifier. The coefficient columns + /// corresponding to the added identifiers are initialized to zero. Return the + /// absolute column position (i.e., not relative to the kind of identifier) + /// of the first added identifier. + unsigned insertId(IdKind kind, unsigned pos, unsigned num = 1) override; + + /// Swap the posA^th identifier with the posB^th identifier. + void swapId(unsigned posA, unsigned posB) override; + + /// Remove the specified range of ids.
+ void removeIdRange(unsigned idStart, unsigned idLimit) override; + + /// Eliminate the `posB^th` local identifier, replacing every instance of it + /// with the `posA^th` local identifier. This should be used when the two + /// local variables are known to always take the same values. + void eliminateRedundantLocalId(unsigned posA, unsigned posB) override; + + /// Return whether the outputs of `this` and `other` agree wherever both + /// functions are defined, i.e., the outputs should be equal for all points in + /// the intersection of the domains. + bool isEqualWhereDomainsOverlap(MultiAffineFunction other) const; + + /// Return whether `this` and `other` are equal. This is the case if + /// they lie in the same space, i.e. have the same dimensions, and their + /// domains are identical and their outputs are equal on their domain. + bool isEqual(const MultiAffineFunction &other) const; + + /// Get the value of the function at the specified point. If the point lies + /// outside the domain, an empty optional is returned. + /// + /// Note: domains with local ids are not yet supported, and will assert-fail. + Optional> valueAt(ArrayRef point) const; + + void print(raw_ostream &os) const; + + void dump() const; + +private: + /// The function's output is a tuple of integers, with the ith element of the + /// tuple defined by the affine expression given by the ith row of this output + /// matrix. + Matrix output; +}; + +/// This class represents a piece-wise MultiAffineFunction. This can be thought +/// of as a list of MultiAffineFunction with disjoint domains, with each having +/// their own affine expressions for their output tuples. For example, we could +/// have a function with two input variables (x, y), defined as +/// +/// f(x, y) = (2*x + y, y - 4) if x >= 0, y >= 0 +/// = (-2*x + y, y + 4) if x < 0, y < 0 +/// = (4, 1) if x < 0, y >= 0 +/// +/// Note that the domains all have to be *disjoint*. Otherwise, the behaviour of +/// this class is undefined.
The domains need not cover all possible points; +/// this represents a partial function and so could be undefined at some points. +/// +/// As in PresburgerSets, the input ids are partitioned into dimension ids and +/// symbolic ids. +/// +/// Support is provided to compare equality of two such functions as well as +/// finding the value of the function at a point. Note that local ids in the +/// piece are not supported for the latter. +class PWMAFunction { +public: + PWMAFunction(unsigned numDims, unsigned numSymbols, unsigned numOutputs) + : numDims(numDims), numSymbols(numSymbols), numOutputs(numOutputs) { + assert(numOutputs >= 1 && "The function must output something!"); + } + + void addPiece(const MultiAffineFunction &piece); + void addPiece(const IntegerPolyhedron &domain, const Matrix &output); + + const MultiAffineFunction &getPiece(unsigned i) const { return pieces[i]; } + unsigned getNumPieces() const { return pieces.size(); } + unsigned getNumOutputs() const { return numOutputs; } + unsigned getNumInputs() const { return numDims + numSymbols; } + unsigned getNumDimIds() const { return numDims; } + unsigned getNumSymbolIds() const { return numSymbols; } + MultiAffineFunction &getPiece(unsigned i) { return pieces[i]; } + + void assertOutputMatrixConsistent(const Matrix &output) const; + PresburgerSet getDomain() const; + + void assertPieceIsConsistent(const MultiAffineFunction &piece) const; + + bool hasCompatibleDimensions(const MultiAffineFunction &f) const; + bool hasCompatibleDimensions(const PWMAFunction &f) const; + + /// Return the value at the specified point and an empty optional if the + /// point does not lie in the domain. + /// + /// Note: domains with local ids are not yet supported, and will assert-fail. + Optional> valueAt(ArrayRef point) const; + + /// Return whether `this` and `other` are equal as PWMAFunctions, i.e. 
whether + /// they have the same dimensions, the same domain and they take the same + /// value at every point in the domain. + bool isEqual(const PWMAFunction &other) const; + + void print(raw_ostream &os) const; + void dump() const; + +private: + /// The list of pieces in this piece-wise MultiAffineFunction. + SmallVector pieces; + + /// The number of dimension ids in the domains. + unsigned numDims; + /// The number of symbol ids in the domains. + unsigned numSymbols; + /// The number of output ids. + unsigned numOutputs; +}; + +} // namespace mlir + +#endif // MLIR_ANALYSIS_PRESBURGER_PWMAFUNCTION_H diff --git a/mlir/lib/Analysis/Presburger/CMakeLists.txt b/mlir/lib/Analysis/Presburger/CMakeLists.txt --- a/mlir/lib/Analysis/Presburger/CMakeLists.txt +++ b/mlir/lib/Analysis/Presburger/CMakeLists.txt @@ -3,6 +3,7 @@ LinearTransform.cpp Matrix.cpp PresburgerSet.cpp + PWMAFunction.cpp Simplex.cpp Utils.cpp diff --git a/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp b/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp --- a/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp @@ -1065,24 +1065,17 @@ equalities.resizeVertically(pos); } -/// Eliminate `pos2^th` local identifier, replacing its every instance with -/// `pos1^th` local identifier. This function is intended to be used to remove -/// redundancy when local variables at position `pos1` and `pos2` are restricted -/// to have the same value.
-static void eliminateRedundantLocalId(IntegerPolyhedron &poly, unsigned pos1, - unsigned pos2) { - - assert(pos1 < poly.getNumLocalIds() && "Invalid local id position"); - assert(pos2 < poly.getNumLocalIds() && "Invalid local id position"); - - unsigned localOffset = poly.getNumDimAndSymbolIds(); - pos1 += localOffset; - pos2 += localOffset; - for (unsigned i = 0, e = poly.getNumInequalities(); i < e; ++i) - poly.atIneq(i, pos1) += poly.atIneq(i, pos2); - for (unsigned i = 0, e = poly.getNumEqualities(); i < e; ++i) - poly.atEq(i, pos1) += poly.atEq(i, pos2); - poly.removeId(pos2); +void IntegerPolyhedron::eliminateRedundantLocalId(unsigned posA, + unsigned posB) { + assert(posA < getNumLocalIds() && "Invalid local id position"); + assert(posB < getNumLocalIds() && "Invalid local id position"); + + unsigned localOffset = getIdKindOffset(IdKind::Local); + posA += localOffset; + posB += localOffset; + inequalities.addToColumn(posB, posA, 1); + equalities.addToColumn(posB, posA, 1); + removeId(posB); } /// Adds additional local ids to the sets such that they both have the union @@ -1129,8 +1122,8 @@ // Merge function that merges the local variables in both sets by treating // them as the same identifier. auto merge = [&polyA, &polyB](unsigned i, unsigned j) -> bool { - eliminateRedundantLocalId(polyA, i, j); - eliminateRedundantLocalId(polyB, i, j); + polyA.eliminateRedundantLocalId(i, j); + polyB.eliminateRedundantLocalId(i, j); return true; }; diff --git a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp @@ -0,0 +1,191 @@ +//===- PWMAFunction.cpp - MLIR PWMAFunction Class -------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/Presburger/PWMAFunction.h" +#include "mlir/Analysis/Presburger/Simplex.h" + +using namespace mlir; + +// Return the result of subtracting the two given vectors pointwise. +// The vectors must be of the same size. +// e.g., [3, 4, 6] - [2, 5, 1] = [1, -1, 5]. +static SmallVector subtract(ArrayRef vecA, + ArrayRef vecB) { + assert(vecA.size() == vecB.size() && + "Cannot subtract vectors of differing lengths!"); + SmallVector result; + result.reserve(vecA.size()); + for (unsigned i = 0, e = vecA.size(); i < e; ++i) + result.push_back(vecA[i] - vecB[i]); + return result; +} + +PresburgerSet PWMAFunction::getDomain() const { + PresburgerSet domain = + PresburgerSet::getEmptySet(getNumDimIds(), getNumSymbolIds()); + for (const MultiAffineFunction &piece : pieces) + domain.unionPolyInPlace(piece.getDomain()); + return domain; +} + +Optional> +MultiAffineFunction::valueAt(ArrayRef point) const { + assert(getNumLocalIds() == 0 && "Local ids are not yet supported!"); + assert(point.size() == getNumIds() && "Point has incorrect dimensionality!"); + + if (!getDomain().containsPoint(point)) + return {}; + + // The point lies in the domain, so we need to compute the output value. + // The matrix `output` has an affine expression in the ith row, corresponding + // to the expression for the ith value in the output vector. The last column + // of the matrix contains the constant term. Let v be the input point with + // a 1 appended at the end. We can see that output * v gives the desired + // output vector. 
+ SmallVector pointHomogenous{llvm::to_vector(point)}; + pointHomogenous.push_back(1); + SmallVector result = + output.postMultiplyWithColumn(pointHomogenous); + assert(result.size() == getNumOutputs()); + return result; +} + +Optional> +PWMAFunction::valueAt(ArrayRef point) const { + assert(point.size() == getNumInputs() && + "Point has incorrect dimensionality!"); + for (const MultiAffineFunction &piece : pieces) + if (Optional> output = piece.valueAt(point)) + return output; + return {}; +} + +void MultiAffineFunction::print(raw_ostream &os) const { + os << "Domain:"; + IntegerPolyhedron::print(os); + os << "Output:\n"; + output.print(os); + os << "\n"; +} + +void MultiAffineFunction::dump() const { print(llvm::errs()); } + +bool MultiAffineFunction::isEqual(const MultiAffineFunction &other) const { + return hasCompatibleDimensions(other) && + getDomain().isEqual(other.getDomain()) && + isEqualWhereDomainsOverlap(other); +} + +unsigned MultiAffineFunction::insertId(IdKind kind, unsigned pos, + unsigned num) { + unsigned absolutePos = getIdKindOffset(kind) + pos; + output.insertColumns(absolutePos, num); + return IntegerPolyhedron::insertId(kind, pos, num); +} + +void MultiAffineFunction::swapId(unsigned posA, unsigned posB) { + output.swapColumns(posA, posB); + IntegerPolyhedron::swapId(posA, posB); +} + +void MultiAffineFunction::removeIdRange(unsigned idStart, unsigned idLimit) { + output.removeColumns(idStart, idLimit - idStart); + IntegerPolyhedron::removeIdRange(idStart, idLimit); +} + +void MultiAffineFunction::eliminateRedundantLocalId(unsigned posA, + unsigned posB) { + // `posA` and `posB` are relative to the local ids, whereas the columns of + // `output` are indexed by absolute id position, so offset them first. + unsigned localOffset = getIdKindOffset(IdKind::Local); + output.addToColumn(localOffset + posB, localOffset + posA, /*scale=*/1); + IntegerPolyhedron::eliminateRedundantLocalId(posA, posB); +} + +bool MultiAffineFunction::isEqualWhereDomainsOverlap( + MultiAffineFunction other) const { + if (!hasCompatibleDimensions(other)) + return false; + + // `commonFunc` has the same output as `this`.
+ MultiAffineFunction commonFunc = *this; + // After this merge, `commonFunc` and `other` have the same local ids; they + // are merged. + commonFunc.mergeLocalIds(other); + // After this, the domain of `commonFunc` will be the intersection of the + // domains of `this` and `other`. + commonFunc.IntegerPolyhedron::append(other); + + // `commonDomainMatching` contains the subset of the common domain + // where the outputs of `this` and `other` match. + IntegerPolyhedron commonDomainMatching = commonFunc.getDomain(); + + // We can't use `this->output` below because its locals aren't merged with + // those of `other`. + for (unsigned row = 0, e = getNumOutputs(); row < e; ++row) + commonDomainMatching.addEquality( + subtract(commonFunc.output.getRow(row), other.output.getRow(row))); + + // If the whole common domain is a subset of commonDomainMatching, then they + // are equal and the two functions match on the whole common domain. + return commonFunc.getDomain().isSubsetOf(commonDomainMatching); +} + +/// Two PWMAFunctions are equal if they have the same dimensionalities, +/// the same domain, and take the same value at every point in the domain. +bool PWMAFunction::isEqual(const PWMAFunction &other) const { + if (!hasCompatibleDimensions(other)) + return false; + + if (!this->getDomain().isEqual(other.getDomain())) + return false; + + // Check if, whenever the domains of a piece of `this` and a piece of `other` + // overlap, they take the same output value. If `this` and `other` have the + // same domain (checked above), then this check passes iff the two functions + // have the same output at every point in the domain. 
+ for (const MultiAffineFunction &aPiece : this->pieces) + for (const MultiAffineFunction &bPiece : other.pieces) + if (!aPiece.isEqualWhereDomainsOverlap(bPiece)) + return false; + return true; +} + +void PWMAFunction::addPiece(const MultiAffineFunction &piece) { + assert(hasCompatibleDimensions(piece) && + "Piece to be added is not compatible with this PWMAFunction!"); + assert(piece.isConsistent() && "Piece is internally inconsistent!"); + pieces.push_back(piece); +} + +void PWMAFunction::addPiece(const IntegerPolyhedron &domain, + const Matrix &output) { + addPiece(MultiAffineFunction(domain, output)); +} + +void PWMAFunction::print(raw_ostream &os) const { + os << pieces.size() << " pieces:\n"; + for (const MultiAffineFunction &piece : pieces) + piece.print(os); +} + +void PWMAFunction::dump() const { print(llvm::errs()); } + +/// The hasCompatibleDimensions functions don't check the number of local ids; +/// functions are still compatible if they have differing number of locals. +bool MultiAffineFunction::hasCompatibleDimensions( + const MultiAffineFunction &f) const { + return getNumDimIds() == f.getNumDimIds() && + getNumSymbolIds() == f.getNumSymbolIds() && + getNumOutputs() == f.getNumOutputs(); +} +bool PWMAFunction::hasCompatibleDimensions(const MultiAffineFunction &f) const { + return getNumDimIds() == f.getNumDimIds() && + getNumSymbolIds() == f.getNumSymbolIds() && + getNumOutputs() == f.getNumOutputs(); +} +bool PWMAFunction::hasCompatibleDimensions(const PWMAFunction &f) const { + return getNumDimIds() == f.getNumDimIds() && + getNumSymbolIds() == f.getNumSymbolIds() && + getNumOutputs() == f.getNumOutputs(); +} diff --git a/mlir/unittests/Analysis/Presburger/CMakeLists.txt b/mlir/unittests/Analysis/Presburger/CMakeLists.txt --- a/mlir/unittests/Analysis/Presburger/CMakeLists.txt +++ b/mlir/unittests/Analysis/Presburger/CMakeLists.txt @@ -3,6 +3,7 @@ LinearTransformTest.cpp MatrixTest.cpp PresburgerSetTest.cpp + PWMAFunctionTest.cpp SimplexTest.cpp ../../Dialect/Affine/Analysis/AffineStructuresParser.cpp )
diff --git a/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp b/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp new file mode 100644 --- /dev/null +++ b/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp @@ -0,0 +1,184 @@ +//===- PWMAFunctionTest.cpp - Tests for PWMAFunction ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains tests for PWMAFunction. The test for isEqual checks, for +// several pairs of functions, that the result matches the expected result. +// The test for valueAt evaluates a function at a set of points and checks +// that the output matches the expected value, or that no value is returned +// when the point lies outside the domain. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/Presburger/PWMAFunction.h" +#include "../../Dialect/Affine/Analysis/AffineStructuresParser.h" +#include "mlir/Analysis/Presburger/PresburgerSet.h" + +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +namespace mlir { +using testing::ElementsAre; + +/// Parses an IntegerPolyhedron from a StringRef. It is expected that the +/// string represents a valid IntegerSet, otherwise it will violate a gtest +/// assertion.
+static IntegerPolyhedron parsePoly(StringRef str, MLIRContext *context) { + FailureOr poly = parseIntegerSetToFAC(str, context); + EXPECT_TRUE(succeeded(poly)); + return *poly; +} + +static Matrix makeMatrix(unsigned numRow, unsigned numColumns, + ArrayRef> matrix) { + Matrix results(numRow, numColumns); + assert(matrix.size() == numRow); + for (unsigned i = 0; i < numRow; ++i) { + assert(matrix[i].size() == numColumns && + "Output expression has incorrect dimensionality!"); + for (unsigned j = 0; j < numColumns; ++j) + results(i, j) = matrix[i][j]; + } + return results; +} + +/// Construct a PWMAFunction given the dimensionalities and an array describing +/// the list of pieces. Each piece is given by a string describing the domain +/// and a 2D array that represents the output. +static PWMAFunction parsePWMAF( + unsigned numInputs, unsigned numOutputs, + ArrayRef, 8>>> + data, + unsigned numSymbols = 0) { + static MLIRContext context; + + PWMAFunction result(numInputs - numSymbols, numSymbols, numOutputs); + for (const auto &pair : data) { + IntegerPolyhedron domain = parsePoly(pair.first, &context); + result.addPiece( + domain, makeMatrix(numOutputs, domain.getNumIds() + 1, pair.second)); + } + return result; +} + +TEST(PWAFunctionTest, isEqual) { + MLIRContext context; + + // The output expressions are different but it doesn't matter because they are + // equal in this domain. + PWMAFunction idAtZeros = + parsePWMAF(/*numInputs=*/2, /*numOutputs=*/2, + { + {"(x, y) : (y == 0)", {{1, 0, 0}, {0, 1, 0}}}, // (x, y). + {"(x, y) : (x == 0)", {{1, 0, 0}, {0, 1, 0}}} // (x, y). + }); + PWMAFunction idAtZeros2 = parsePWMAF( + /*numInputs=*/2, /*numOutputs=*/2, + { + {"(x, y) : (y == 0)", {{1, 0, 0}, {0, 20, 0}}}, // (x, 20y). + {"(x, y) : (x == 0)", {{30, 0, 0}, {0, 1, 0}}} // (30x, y). 
+ }); + EXPECT_TRUE(idAtZeros.isEqual(idAtZeros2)); + + PWMAFunction notIdAtZeros = + parsePWMAF(/*numInputs=*/2, /*numOutputs=*/2, + { + {"(x, y) : (y == 0)", {{1, 0, 0}, {0, 1, 0}}}, // (x, y). + {"(x, y) : (x == 0)", {{1, 0, 0}, {0, 2, 0}}} // (x, 2y). + }); + EXPECT_FALSE(idAtZeros.isEqual(notIdAtZeros)); + + // These match at their intersection but one has a bigger domain. + PWMAFunction idNoNegNegQuadrant = + parsePWMAF(/*numInputs=*/2, /*numOutputs=*/2, + { + {"(x, y) : (x >= 0)", {{1, 0, 0}, {0, 1, 0}}}, // (x, y). + {"(x, y) : (y >= 0)", {{1, 0, 0}, {0, 1, 0}}} // (x, y). + }); + PWMAFunction idOnlyPosX = + parsePWMAF(/*numInputs=*/2, /*numOutputs=*/2, + { + {"(x, y) : (x >= 0)", {{1, 0, 0}, {0, 1, 0}}}, // (x, y). + }); + EXPECT_FALSE(idNoNegNegQuadrant.isEqual(idOnlyPosX)); + + // Different representations of the same domain. + PWMAFunction sumPlusOne = parsePWMAF( + /*numInputs=*/2, /*numOutputs=*/1, + { + {"(x, y) : (x >= 0)", {{1, 1, 1}}}, // x + y + 1. + {"(x, y) : (-x - 1 >= 0, -y - 1 >= 0)", {{1, 1, 1}}}, // x + y + 1. + {"(x, y) : (y >= 0)", {{1, 1, 1}}} // x + y + 1. + }); + PWMAFunction sumPlusOne2 = + parsePWMAF(/*numInputs=*/2, /*numOutputs=*/1, + { + {"(x, y) : ()", {{1, 1, 1}}}, // x + y + 1. + }); + EXPECT_TRUE(sumPlusOne.isEqual(sumPlusOne2)); + + // Functions with zero input dimensions. + PWMAFunction noInputs1 = parsePWMAF(/*numInputs=*/0, /*numOutputs=*/1, + { + {"() : ()", {{1}}}, // 1. + }); + PWMAFunction noInputs2 = parsePWMAF(/*numInputs=*/0, /*numOutputs=*/1, + { + {"() : ()", {{2}}}, // 2. + }); + EXPECT_TRUE(noInputs1.isEqual(noInputs1)); + EXPECT_FALSE(noInputs1.isEqual(noInputs2)); + + // Mismatched dimensionalities. + EXPECT_FALSE(noInputs1.isEqual(sumPlusOne)); + EXPECT_FALSE(idOnlyPosX.isEqual(sumPlusOne)); + + // Divisions. + // Domain is only multiples of 6; x = 6k for some k. + // x + 4(x/2) + 4(x/3) == 26k.
+ PWMAFunction mul2AndMul3 = parsePWMAF( + /*numInputs=*/1, /*numOutputs=*/1, + { + {"(x) : (x - 2*(x floordiv 2) == 0, x - 3*(x floordiv 3) == 0)", + {{1, 4, 4, 0}}}, // x + 4(x/2) + 4(x/3). + }); + PWMAFunction mul6 = parsePWMAF( + /*numInputs=*/1, /*numOutputs=*/1, + { + {"(x) : (x - 6*(x floordiv 6) == 0)", {{0, 26, 0}}}, // 26(x/6). + }); + EXPECT_TRUE(mul2AndMul3.isEqual(mul6)); + + // Same domain as `mul6` but a different output, so the functions differ. + PWMAFunction mul6diff = parsePWMAF( + /*numInputs=*/1, /*numOutputs=*/1, + { + {"(x) : (x - 6*(x floordiv 6) == 0)", {{0, 52, 0}}}, // 52(x/6). + }); + EXPECT_FALSE(mul2AndMul3.isEqual(mul6diff)); + + PWMAFunction mul5 = parsePWMAF( + /*numInputs=*/1, /*numOutputs=*/1, + { + {"(x) : (x - 5*(x floordiv 5) == 0)", {{0, 26, 0}}}, // 26(x/5). + }); + EXPECT_FALSE(mul2AndMul3.isEqual(mul5)); +} + +TEST(PWMAFunction, valueAt) { + PWMAFunction nonNegPWAF = parsePWMAF( + /*numInputs=*/2, /*numOutputs=*/2, + { + {"(x, y) : (x >= 0)", {{1, 2, 3}, {3, 4, 5}}}, // (x + 2y + 3, 3x + 4y + 5). + {"(x, y) : (y >= 0, -x - 1 >= 0)", {{-1, 2, 3}, {-3, 4, 5}}} // (-x + 2y + 3, -3x + 4y + 5). + }); + EXPECT_THAT(*nonNegPWAF.valueAt({2, 3}), ElementsAre(11, 23)); + EXPECT_THAT(*nonNegPWAF.valueAt({-2, 3}), ElementsAre(11, 23)); + EXPECT_THAT(*nonNegPWAF.valueAt({2, -3}), ElementsAre(-1, -1)); + EXPECT_FALSE(nonNegPWAF.valueAt({-2, -3}).hasValue()); +} + +} // namespace mlir