Replace AffineMap::getPermutedPosition with AffineMap::getResultPosition.

getResultPosition(AffineExpr input) returns the index of the first result that
equals `input`, or llvm::None when `input` is not an AffineDimExpr or does not
occur among the map's results. Unlike getPermutedPosition it does not assert
that the map is a permutation. The sparse tensor caller keeps its own
isPermutation() assertion and additionally asserts that the dimension was found
before dereferencing the Optional.

diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -166,9 +166,10 @@
   /// when the caller knows it is safe to do so.
   unsigned getDimPosition(unsigned idx) const;
 
-  /// Extracts the permuted position where given input index resides.
-  /// Fails when called on a non-permutation.
-  unsigned getPermutedPosition(unsigned input) const;
+  /// Extracts the first result position where `input` dimension resides.
+  /// Returns `llvm::None` if `input` is not a dimension expression or cannot be
+  /// found in results.
+  Optional<unsigned> getResultPosition(AffineExpr input) const;
 
   /// Return true if any affine expression involves AffineDimExpr `position`.
   bool isFunctionOfDim(unsigned position) const {
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -300,7 +300,10 @@
     auto order = enc.getDimOrdering();
     if (order) {
       assert(order.isPermutation());
-      return order.getPermutedPosition(d);
+      auto maybePos =
+          order.getResultPosition(getAffineDimExpr(d, enc.getContext()));
+      assert(maybePos.has_value() && "dimension missing from permutation");
+      return *maybePos;
     }
   }
   return d;
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -328,12 +328,16 @@
   return getResult(idx).cast<AffineDimExpr>().getPosition();
 }
 
-unsigned AffineMap::getPermutedPosition(unsigned input) const {
-  assert(isPermutation() && "invalid permutation request");
-  for (unsigned i = 0, numResults = getNumResults(); i < numResults; i++)
-    if (getDimPosition(i) == input)
+Optional<unsigned> AffineMap::getResultPosition(AffineExpr input) const {
+  if (!input.isa<AffineDimExpr>())
+    return llvm::None;
+
+  for (unsigned i = 0, numResults = getNumResults(); i < numResults; i++) {
+    if (getResult(i) == input)
       return i;
-  llvm_unreachable("incorrect permutation request");
+  }
+
+  return llvm::None;
 }
 
 /// Folds the results of the application of an affine map on the provided