diff --git a/mlir/include/mlir/Analysis/Presburger/LinearTransform.h b/mlir/include/mlir/Analysis/Presburger/LinearTransform.h
--- a/mlir/include/mlir/Analysis/Presburger/LinearTransform.h
+++ b/mlir/include/mlir/Analysis/Presburger/LinearTransform.h
@@ -39,12 +39,16 @@
 
   // The given vector is interpreted as a row vector v. Post-multiply v with
   // this transform, say T, and return vT.
-  SmallVector<int64_t, 8> preMultiplyWithRow(ArrayRef<int64_t> rowVec) const;
+  SmallVector<int64_t, 8> preMultiplyWithRow(ArrayRef<int64_t> rowVec) const {
+    return matrix.preMultiplyWithRow(rowVec);
+  }
 
   // The given vector is interpreted as a column vector v. Pre-multiply v with
   // this transform, say T, and return Tv.
   SmallVector<int64_t, 8>
-  postMultiplyWithColumn(ArrayRef<int64_t> colVec) const;
+  postMultiplyWithColumn(ArrayRef<int64_t> colVec) const {
+    return matrix.postMultiplyWithColumn(colVec);
+  }
 
 private:
   Matrix matrix;
diff --git a/mlir/include/mlir/Analysis/Presburger/Matrix.h b/mlir/include/mlir/Analysis/Presburger/Matrix.h
--- a/mlir/include/mlir/Analysis/Presburger/Matrix.h
+++ b/mlir/include/mlir/Analysis/Presburger/Matrix.h
@@ -117,6 +117,17 @@
   /// Negate the specified column.
   void negateColumn(unsigned column);
 
+  /// The given vector is interpreted as a row vector v. Post-multiply v with
+  /// this matrix, say M, and return vM. The size of rowVec must equal
+  /// getNumRows(); the result has getNumColumns() entries.
+  SmallVector<int64_t, 8> preMultiplyWithRow(ArrayRef<int64_t> rowVec) const;
+
+  /// The given vector is interpreted as a column vector v. Pre-multiply v with
+  /// this matrix, say M, and return Mv. The size of colVec must equal
+  /// getNumColumns(); the result has getNumRows() entries.
+  SmallVector<int64_t, 8>
+  postMultiplyWithColumn(ArrayRef<int64_t> colVec) const;
+
   /// Resize the matrix to the specified dimensions. If a dimension is smaller,
   /// the values are truncated; if it is bigger, the new values are initialized
   /// to zero.
diff --git a/mlir/lib/Analysis/Presburger/LinearTransform.cpp b/mlir/lib/Analysis/Presburger/LinearTransform.cpp
--- a/mlir/lib/Analysis/Presburger/LinearTransform.cpp
+++ b/mlir/lib/Analysis/Presburger/LinearTransform.cpp
@@ -111,30 +111,6 @@
   return {echelonCol, LinearTransform(std::move(resultMatrix))};
 }
 
-SmallVector<int64_t, 8>
-LinearTransform::preMultiplyWithRow(ArrayRef<int64_t> rowVec) const {
-  assert(rowVec.size() == matrix.getNumRows() &&
-         "row vector dimension should match transform output dimension");
-
-  SmallVector<int64_t, 8> result(matrix.getNumColumns(), 0);
-  for (unsigned col = 0, e = matrix.getNumColumns(); col < e; ++col)
-    for (unsigned i = 0, e = matrix.getNumRows(); i < e; ++i)
-      result[col] += rowVec[i] * matrix(i, col);
-  return result;
-}
-
-SmallVector<int64_t, 8>
-LinearTransform::postMultiplyWithColumn(ArrayRef<int64_t> colVec) const {
-  assert(matrix.getNumColumns() == colVec.size() &&
-         "column vector dimension should match transform input dimension");
-
-  SmallVector<int64_t, 8> result(matrix.getNumRows(), 0);
-  for (unsigned row = 0, e = matrix.getNumRows(); row < e; row++)
-    for (unsigned i = 0, e = matrix.getNumColumns(); i < e; i++)
-      result[row] += matrix(row, i) * colVec[i];
-  return result;
-}
-
 IntegerPolyhedron
 LinearTransform::applyTo(const IntegerPolyhedron &poly) const {
   IntegerPolyhedron result(poly.getNumIds());
diff --git a/mlir/lib/Analysis/Presburger/Matrix.cpp b/mlir/lib/Analysis/Presburger/Matrix.cpp
--- a/mlir/lib/Analysis/Presburger/Matrix.cpp
+++ b/mlir/lib/Analysis/Presburger/Matrix.cpp
@@ -203,6 +203,31 @@
     at(row, column) = -at(row, column);
 }
 
+SmallVector<int64_t, 8>
+Matrix::preMultiplyWithRow(ArrayRef<int64_t> rowVec) const {
+  assert(rowVec.size() == getNumRows() && "Invalid row vector dimension!");
+
+  SmallVector<int64_t, 8> result(getNumColumns(), 0);
+  // result[col] = sum over i of rowVec[i] * M(i, col).
+  for (unsigned col = 0, e = getNumColumns(); col < e; ++col)
+    for (unsigned i = 0, f = getNumRows(); i < f; ++i)
+      result[col] += rowVec[i] * at(i, col);
+  return result;
+}
+
+SmallVector<int64_t, 8>
+Matrix::postMultiplyWithColumn(ArrayRef<int64_t> colVec) const {
+  assert(getNumColumns() == colVec.size() &&
+         "Invalid column vector dimension!");
+
+  SmallVector<int64_t, 8> result(getNumRows(), 0);
+  // result[row] = sum over i of M(row, i) * colVec[i].
+  for (unsigned row = 0, e = getNumRows(); row < e; ++row)
+    for (unsigned i = 0, f = getNumColumns(); i < f; ++i)
+      result[row] += at(row, i) * colVec[i];
+  return result;
+}
+
 void Matrix::print(raw_ostream &os) const {
   for (unsigned row = 0; row < nRows; ++row) {
     for (unsigned column = 0; column < nColumns; ++column)