diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -546,6 +546,24 @@
   return result;
 }
 
+/// Raises `maxDim`/`maxSym` to the largest dimension/symbol position found in
+/// any (sub)expression of `exprsList`. The outputs are only ever increased, so
+/// callers should seed them first (e.g. with -1 to mean "none found").
+template <typename AffineExprContainer>
+void getMaxDimAndSymbol(ArrayRef<AffineExprContainer> exprsList,
+                        int64_t &maxDim, int64_t &maxSym) {
+  for (const auto &exprs : exprsList) {
+    for (auto expr : exprs) {
+      expr.walk([&maxDim, &maxSym](AffineExpr e) {
+        if (auto d = e.dyn_cast<AffineDimExpr>())
+          maxDim = std::max(maxDim, static_cast<int64_t>(d.getPosition()));
+        if (auto s = e.dyn_cast<AffineSymbolExpr>())
+          maxSym = std::max(maxSym, static_cast<int64_t>(s.getPosition()));
+      });
+    }
+  }
+}
+
 inline raw_ostream &operator<<(raw_ostream &os, AffineMap map) {
   map.print(os);
   return os;
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -215,21 +215,6 @@
   return permutationMap;
 }
 
-template <typename AffineExprContainer>
-static void getMaxDimAndSymbol(ArrayRef<AffineExprContainer> exprsList,
-                               int64_t &maxDim, int64_t &maxSym) {
-  for (const auto &exprs : exprsList) {
-    for (auto expr : exprs) {
-      expr.walk([&maxDim, &maxSym](AffineExpr e) {
-        if (auto d = e.dyn_cast<AffineDimExpr>())
-          maxDim = std::max(maxDim, static_cast<int64_t>(d.getPosition()));
-        if (auto s = e.dyn_cast<AffineSymbolExpr>())
-          maxSym = std::max(maxSym, static_cast<int64_t>(s.getPosition()));
-      });
-    }
-  }
-}
-
 template <typename AffineExprContainer>
 static SmallVector<AffineMap, 4>
 inferFromExprList(ArrayRef<AffineExprContainer> exprsList) {
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -1006,6 +1006,29 @@
       });
 }
 
+/// Check whether the arguments passed to AffineMap::get() are consistent:
+/// returns true iff every dimension identifier occurring in `results` has a
+/// position less than `dimCount` and every symbol identifier has a position
+/// less than `symbolCount`. On failure it returns false and emits the reason
+/// through LLVM_DEBUG.
+[[nodiscard, maybe_unused]] static bool // Only referenced from assert().
+willBeValidAffineMap(unsigned dimCount, unsigned symbolCount,
+                     ArrayRef<AffineExpr> results) {
+  int64_t maxDimPosition = -1;
+  int64_t maxSymbolPosition = -1;
+  getMaxDimAndSymbol(ArrayRef<ArrayRef<AffineExpr>>(results), maxDimPosition,
+                     maxSymbolPosition);
+  if ((maxDimPosition >= dimCount) || (maxSymbolPosition >= symbolCount)) {
+    LLVM_DEBUG(
+        llvm::dbgs()
+        << "maximum dimensional identifier position in result expression must "
+           "be less than `dimCount` and maximum symbolic identifier position "
+           "in result expression must be less than `symbolCount`\n");
+    return false;
+  }
+  return true;
+}
+
 AffineMap AffineMap::get(MLIRContext *context) {
   return getImpl(/*dimCount=*/0, /*symbolCount=*/0, /*results=*/{}, context);
 }
@@ -1017,11 +1040,13 @@
 
 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
                          AffineExpr result) {
+  assert(willBeValidAffineMap(dimCount, symbolCount, {result}));
   return getImpl(dimCount, symbolCount, {result}, result.getContext());
 }
 
 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
                          ArrayRef<AffineExpr> results, MLIRContext *context) {
+  assert(willBeValidAffineMap(dimCount, symbolCount, results));
   return getImpl(dimCount, symbolCount, results, context);
 }
 