diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -1006,6 +1006,43 @@
       });
 }
 
+/// Check whether the arguments passed to the AffineMap::get() are consistent.
+/// This method checks whether the highest index of dimensional identifier
+/// present in result expressions is less than `dimCount` and the highest index
+/// of symbolic identifier present in result expressions is less than
+/// `symbolCount`.
+///
+/// Only called from within `assert`, so it is marked `[[maybe_unused]]` to
+/// avoid an unused-function warning in builds where NDEBUG is defined.
+[[nodiscard]] [[maybe_unused]] static bool
+willBeValidAffineMap(unsigned dimCount, unsigned symbolCount,
+                     ArrayRef<AffineExpr> results) {
+  // One past the largest dimension / symbol position seen so far, i.e. the
+  // minimum dimCount / symbolCount that the result expressions require.
+  unsigned numDimensions = 0;
+  unsigned numSymbols = 0;
+  for (AffineExpr result : results) {
+    // walk() visits `result` itself and every sub-expression, so no manual
+    // recursion through binary-op operands is needed here.
+    result.walk([&](AffineExpr expr) {
+      if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
+        numDimensions = std::max(numDimensions, dimExpr.getPosition() + 1);
+      else if (auto symbolExpr = expr.dyn_cast<AffineSymbolExpr>())
+        numSymbols = std::max(numSymbols, symbolExpr.getPosition() + 1);
+    });
+  }
+
+  if (numDimensions > dimCount || numSymbols > symbolCount) {
+    LLVM_DEBUG(
+        llvm::dbgs()
+        << "maximum dimensional identifier position in result expression must "
+           "be less than `dimCount` and maximum symbolic identifier position "
+           "in result expression must be less than `symbolCount`\n");
+    return false;
+  }
+  return true;
+}
+
 AffineMap AffineMap::get(MLIRContext *context) {
   return getImpl(/*dimCount=*/0, /*symbolCount=*/0, /*results=*/{}, context);
 }
@@ -1017,11 +1054,13 @@
 
 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
                          AffineExpr result) {
+  assert(willBeValidAffineMap(dimCount, symbolCount, {result}));
   return getImpl(dimCount, symbolCount, {result}, result.getContext());
 }
 
 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
                          ArrayRef<AffineExpr> results,
                          MLIRContext *context) {
+  assert(willBeValidAffineMap(dimCount, symbolCount, results));
   return getImpl(dimCount, symbolCount, results, context);
 }