diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -1006,6 +1006,48 @@
       });
 }
 
+/// Check whether the arguments passed to the AffineMap::get() are consistent.
+/// This method checks whether the highest index of dimensional identifier
+/// present in result expressions is less than `dimCount` and the highest index
+/// of symbolic identifier present in result expressions is less than
+/// `symbolCount`.
+///
+/// This function is only referenced from `assert`s, so it is marked
+/// `LLVM_ATTRIBUTE_UNUSED` to avoid -Wunused-function when NDEBUG is defined.
+LLVM_ATTRIBUTE_UNUSED static bool
+willBeValidAffineMap(unsigned dimCount, unsigned symbolCount,
+                     ArrayRef<AffineExpr> results) {
+  bool hasEncounteredDim = false;
+  bool hasEncounteredSymbol = false;
+  unsigned maxDimPosition = 0;
+  unsigned maxSymbolPosition = 0;
+  for (AffineExpr expr : results) {
+    // Visit every sub-expression of `expr`; only dimension and symbol leaves
+    // carry positions.
+    expr.walk([&](AffineExpr subExpr) {
+      if (auto dimExpr = subExpr.dyn_cast<AffineDimExpr>()) {
+        hasEncounteredDim = true;
+        maxDimPosition = std::max(maxDimPosition, dimExpr.getPosition());
+      } else if (auto symbolExpr = subExpr.dyn_cast<AffineSymbolExpr>()) {
+        hasEncounteredSymbol = true;
+        maxSymbolPosition =
+            std::max(maxSymbolPosition, symbolExpr.getPosition());
+      }
+    });
+  }
+
+  if ((hasEncounteredDim && maxDimPosition >= dimCount) ||
+      (hasEncounteredSymbol && maxSymbolPosition >= symbolCount)) {
+    LLVM_DEBUG(
+        llvm::dbgs()
+        << "maximum dimensional identifier position in result expression must "
+           "be less than `dimCount` and maximum symbolic identifier position "
+           "in result expression must be less than `symbolCount`\n");
+    return false;
+  }
+  return true;
+}
+
 AffineMap AffineMap::get(MLIRContext *context) {
   return getImpl(/*dimCount=*/0, /*symbolCount=*/0, /*results=*/{}, context);
 }
@@ -1017,11 +1059,13 @@
 
 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
                          AffineExpr result) {
+  assert(willBeValidAffineMap(dimCount, symbolCount, {result}));
   return getImpl(dimCount, symbolCount, {result}, result.getContext());
 }
 
 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
                          ArrayRef<AffineExpr> results, MLIRContext *context) {
+  assert(willBeValidAffineMap(dimCount, symbolCount, results));
   return getImpl(dimCount, symbolCount, results, context);
 }
 