diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -162,6 +162,12 @@
   /// when the caller knows it is safe to do so.
   unsigned getDimPosition(unsigned idx) const;
 
+  /// Returns the index `i` of the result for which `getDimPosition(i)` equals
+  /// `input`, i.e. the inverse lookup of `getDimPosition`.
+  /// The map must be a permutation: this is checked with an assertion, so
+  /// calling it on a non-permutation is a programming error.
+  unsigned getPermutedPosition(unsigned input) const;
+
   /// Return true if any affine expression involves AffineDimExpr `position`.
   bool isFunctionOfDim(unsigned position) const {
     return llvm::any_of(getResults(), [&](AffineExpr e) {
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -336,6 +336,16 @@
   return getResult(idx).cast<AffineDimExpr>().getPosition();
 }
 
+unsigned AffineMap::getPermutedPosition(unsigned input) const {
+  assert(isPermutation() && "invalid permutation request");
+  // Linear scan over the results for the one whose dim position is `input`.
+  for (unsigned i = 0, numResults = getNumResults(); i < numResults; i++)
+    if (getDimPosition(i) == input)
+      return i;
+  // No result matched `input`; treated as a caller error, not a return value.
+  llvm_unreachable("incorrect permutation request");
+}
+
 /// Folds the results of the application of an affine map on the provided
 /// operands to a constant if possible. Returns false if the folding happens,
 /// true otherwise.