[mlir] Make AffineMap::getPermutedPosition return Optional<unsigned>

getPermutedPosition now accepts projected permutations (it asserts
isProjectedPermutation() instead of isPermutation()) and returns llvm::None
when no result of the map references the requested input dimension, instead
of reaching llvm_unreachable. The SparseTensor caller operates on a full
permutation (it asserts isPermutation() first), so it asserts that a position
was found and unwraps it.

diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -166,9 +166,10 @@
   /// when the caller knows it is safe to do so.
   unsigned getDimPosition(unsigned idx) const;
 
-  /// Extracts the permuted position where given input index resides.
-  /// Fails when called on a non-permutation.
-  unsigned getPermutedPosition(unsigned input) const;
+  /// Extracts the permuted position where the given input index resides.
+  /// Returns `llvm::None` if the input index is projected out. Asserts if the
+  /// map is not a projected permutation.
+  Optional<unsigned> getPermutedPosition(unsigned input) const;
 
   /// Return true if any affine expression involves AffineDimExpr `position`.
   bool isFunctionOfDim(unsigned position) const {
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -300,7 +300,9 @@
     auto order = enc.getDimOrdering();
     if (order) {
       assert(order.isPermutation());
-      return order.getPermutedPosition(d);
+      auto maybePos = order.getPermutedPosition(d);
+      assert(maybePos.has_value() && "a permutation must contain every dim");
+      return *maybePos;
     }
   }
   return d;
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -328,12 +328,12 @@
   return getResult(idx).cast<AffineDimExpr>().getPosition();
 }
 
-unsigned AffineMap::getPermutedPosition(unsigned input) const {
-  assert(isPermutation() && "invalid permutation request");
+Optional<unsigned> AffineMap::getPermutedPosition(unsigned input) const {
+  assert(isProjectedPermutation() && "invalid projected permutation request");
   for (unsigned i = 0, numResults = getNumResults(); i < numResults; i++)
     if (getDimPosition(i) == input)
       return i;
-  llvm_unreachable("incorrect permutation request");
+  return llvm::None;
 }
 
 /// Folds the results of the application of an affine map on the provided