diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -188,9 +188,9 @@ MLIRContext *context = getContext(); auto r_i = getAffineDimExpr(0, context); return SmallVector{ - AffineMap::get(1, 0, {r_i}), - AffineMap::get(1, 0, {r_i}), - AffineMap::get(1, 0, context)}; + AffineMap::get(1, 0, {r_i}, context), + AffineMap::get(1, 0, {r_i}, context), + AffineMap::get(1, 0, {}, context)}; } }]; @@ -215,8 +215,10 @@ AffineExpr i, r_j; bindDims(context, i, r_j); return SmallVector{ - AffineMap::get(2, 0, {i, r_j}), AffineMap::get(2, 0, {r_j}), - AffineMap::get(2, 0, {i})}; + AffineMap::get(2, 0, {i, r_j}, context), + AffineMap::get(2, 0, {r_j}, context), + AffineMap::get(2, 0, {i}, context) + }; } }]; @@ -242,9 +244,11 @@ MLIRContext *context = getContext(); AffineExpr i, j, r_k; bindDims(context, i, j, r_k); - return SmallVector{AffineMap::get(3, 0, {i, r_k}), - AffineMap::get(3, 0, {r_k, j}), - AffineMap::get(3, 0, {i, j})}; + return SmallVector{ + AffineMap::get(3, 0, {i, r_k}, context), + AffineMap::get(3, 0, {r_k, j},context), + AffineMap::get(3, 0, {i, j}, context) + }; } }]; @@ -403,15 +407,15 @@ auto ws = weightedPoolingInputIndex(*this, xs, zs); return SmallVector{ // filter[z[0], ..., z[N-1], q, k] - AffineMap::get(idx, 0, concat(concat(zs, qs), ks)), + AffineMap::get(idx, 0, concat(concat(zs, qs), ks), context), // input[b, // x[0]*s[0] + d[0]*z[0] - pad_low[0], // ... 
// x[N-1]*s[N-1] + d[N-1]*z[N-1] - pad_low[N-1], // q] - AffineMap::get(idx, 0, concat(concat(bs, ws), qs)), + AffineMap::get(idx, 0, concat(concat(bs, ws), qs), context), // output[b, x[0], ..., x[N-1], k] - AffineMap::get(idx, 0, concat(concat(bs, xs), ks))}; + AffineMap::get(idx, 0, concat(concat(bs, xs), ks), context)}; } }]; @@ -465,11 +469,11 @@ weightedPoolingInputIndex(*this, outputDims, windowDims); return SmallVector{ // input - AffineMap::get(idx, 0, inputDims), + AffineMap::get(idx, 0, inputDims, context), // windowDims - AffineMap::get(idx, 0, windowDims), + AffineMap::get(idx, 0, windowDims, context), // output - AffineMap::get(idx, 0, outputDims) + AffineMap::get(idx, 0, outputDims, context) }; } }]; diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h --- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h @@ -25,22 +25,24 @@ namespace mlir { inline bool isRowMajorMatmul(ArrayAttr indexingMaps) { + auto context = indexingMaps.getContext(); AffineExpr m, n, k; - bindDims(indexingMaps.getContext(), m, n, k); - auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k})); - auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n})); - auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n})); - auto maps = ArrayAttr::get({mapA, mapB, mapC}, indexingMaps.getContext()); + bindDims(context, m, n, k); + auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, context)); + auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, context)); + auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, context)); + auto maps = ArrayAttr::get({mapA, mapB, mapC}, context); return indexingMaps == maps; } inline bool isColumnMajorMatmul(ArrayAttr indexingMaps) { + auto context = indexingMaps.getContext(); AffineExpr m, n, k; - bindDims(indexingMaps.getContext(), m, n, k); - auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, 
{k, n})); - auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k})); - auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {n, m})); - auto maps = ArrayAttr::get({mapA, mapB, mapC}, indexingMaps.getContext()); + bindDims(context, m, n, k); + auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, context)); + auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, context)); + auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {n, m}, context)); + auto maps = ArrayAttr::get({mapA, mapB, mapC}, context); return indexingMaps == maps; } diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h --- a/mlir/include/mlir/IR/AffineMap.h +++ b/mlir/include/mlir/IR/AffineMap.h @@ -49,13 +49,13 @@ static AffineMap get(unsigned dimCount, unsigned symbolCount, MLIRContext *context); - /// Returns an affine map with `dimCount` dimensions and `symbolCount` symbols - /// mapping to the given results. The array of results cannot be empty. + /// Returns an affine map with `dimCount` dimensions and `symbolCount` symbols + /// mapping to the single result expression `result`. static AffineMap get(unsigned dimCount, unsigned symbolCount, - ArrayRef results); + AffineExpr result); /// Returns an affine map with `dimCount` dimensions and `symbolCount` mapping - /// to the given results, where the number of results can be zero. + /// to the given results, which may be empty. static AffineMap get(unsigned dimCount, unsigned symbolCount, ArrayRef results, MLIRContext *context); diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -1464,11 +1464,8 @@ lbExprs.push_back(expr); } - auto lbMap = lbExprs.empty() ? AffineMap() - : AffineMap::get(dimCount, symCount, lbExprs); - - auto ubMap = ubExprs.empty() ?
AffineMap() - : AffineMap::get(dimCount, symCount, ubExprs); + auto lbMap = AffineMap::get(dimCount, symCount, lbExprs, context); + auto ubMap = AffineMap::get(dimCount, symCount, ubExprs, context); return {lbMap, ubMap}; } diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -62,8 +62,8 @@ SmallVector lbSplatExpr(ubValueMap.getNumResults(), lbMap.getResult(0)); - auto lbMapSplat = - AffineMap::get(lbMap.getNumDims(), lbMap.getNumSymbols(), lbSplatExpr); + auto lbMapSplat = AffineMap::get(lbMap.getNumDims(), lbMap.getNumSymbols(), + lbSplatExpr, b.getContext()); AffineValueMap lbSplatValueMap(lbMapSplat, forOp.getLowerBoundOperands()); AffineValueMap tripCountValueMap; diff --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp --- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp +++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp @@ -99,11 +99,11 @@ genericOp.indexing_maps().getValue()[1].cast(); // The indexing map for the input should be `(i) -> (i)`. if (inputMap.getValue() != - AffineMap::get(1, 0, {getAffineDimExpr(0, op->getContext())})) + AffineMap::get(1, 0, getAffineDimExpr(0, op->getContext()))) return llvm::None; // The indexing map for the input should be `(i) -> (0)`. 
if (outputMap.getValue() != - AffineMap::get(1, 0, {getAffineConstantExpr(0, op->getContext())})) + AffineMap::get(1, 0, getAffineConstantExpr(0, op->getContext()))) return llvm::None; return linalg::RegionMatcher::matchAsScalarBinaryOp(genericOp); diff --git a/mlir/lib/Dialect/Affine/EDSC/Builders.cpp b/mlir/lib/Dialect/Affine/EDSC/Builders.cpp --- a/mlir/lib/Dialect/Affine/EDSC/Builders.cpp +++ b/mlir/lib/Dialect/Affine/EDSC/Builders.cpp @@ -129,7 +129,7 @@ if (v1) { operands.push_back(v1); } - auto map = AffineMap::get(numDims, numSymbols, {affCombiner(d0, d1)}); + auto map = AffineMap::get(numDims, numSymbols, affCombiner(d0, d1)); // TODO: createOrFold when available. Operation *op = makeComposedAffineApply(ScopedContext::getBuilder(), diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -522,7 +522,8 @@ "Unexpected number of concatenated symbols"); auto numDims = dimValueToPosition.size(); auto numSymbols = concatenatedSymbols.size() - map.getNumSymbols(); - auto auxiliaryMap = AffineMap::get(numDims, numSymbols, auxiliaryExprs); + auto auxiliaryMap = + AffineMap::get(numDims, numSymbols, auxiliaryExprs, map.getContext()); LLVM_DEBUG(map.print(dbgs() << "\nCompose map: ")); LLVM_DEBUG(auxiliaryMap.print(dbgs() << "\nWith map: ")); @@ -2163,19 +2164,13 @@ void AffineParallelOp::build(Builder *builder, OperationState &result, ArrayRef ranges) { - // Default initialize empty maps. - auto lbMap = AffineMap::get(builder->getContext()); - auto ubMap = AffineMap::get(builder->getContext()); - // If there are ranges, set each to [0, N). 
- if (ranges.size()) { - SmallVector lbExprs(ranges.size(), - builder->getAffineConstantExpr(0)); - lbMap = AffineMap::get(0, 0, lbExprs); - SmallVector ubExprs; - for (int64_t range : ranges) - ubExprs.push_back(builder->getAffineConstantExpr(range)); - ubMap = AffineMap::get(0, 0, ubExprs); - } + SmallVector lbExprs(ranges.size(), + builder->getAffineConstantExpr(0)); + auto lbMap = AffineMap::get(0, 0, lbExprs, builder->getContext()); + SmallVector ubExprs; + for (int64_t range : ranges) + ubExprs.push_back(builder->getAffineConstantExpr(range)); + auto ubMap = AffineMap::get(0, 0, ubExprs, builder->getContext()); build(builder, result, lbMap, {}, ubMap, {}); } diff --git a/mlir/lib/Dialect/Affine/IR/AffineValueMap.cpp b/mlir/lib/Dialect/Affine/IR/AffineValueMap.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineValueMap.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineValueMap.cpp @@ -51,8 +51,9 @@ diffExprs.push_back(normalizer.getAffineMap().getResult(i) - bMap.getResult(i)); - auto diffMap = AffineMap::get(normalizer.getNumDims(), - normalizer.getNumSymbols(), diffExprs); + auto diffMap = + AffineMap::get(normalizer.getNumDims(), normalizer.getNumSymbols(), + diffExprs, aMap.getContext()); canonicalizeMapAndOperands(&diffMap, &bOperands); diffMap = simplifyAffineMap(diffMap); res->reset(diffMap, bOperands); diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp --- a/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp @@ -143,8 +143,9 @@ boundExprs.push_back(dim + tileSizes[i]); boundExprs.append(origUbMap.getResults().begin(), origUbMap.getResults().end()); - auto ubMap = AffineMap::get(origUbMap.getNumDims() + 1, - origUbMap.getNumSymbols(), boundExprs); + auto ubMap = + AffineMap::get(origUbMap.getNumDims() + 1, origUbMap.getNumSymbols(), + boundExprs, b.getContext()); newLoops[width + i].setUpperBound(/*operands=*/ubOperands, ubMap); } else { // No need of 
the min expression. diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -523,8 +523,10 @@ "Expected symbol-less expressions"); SmallVector maps; maps.reserve(reassociation.size()); - for (auto exprs : reassociation) - maps.push_back(AffineMap::get(maxDim + 1, 0, exprs)); + for (auto exprs : reassociation) { + assert(exprs.size() != 0); + maps.push_back(AffineMap::get(maxDim + 1, 0, exprs, exprs[0].getContext())); + } return maps; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp @@ -252,7 +252,8 @@ // so having a max op is enough. auto maxMap = AffineMap::get(/*dimCount=*/1, 0, {getAffineDimExpr(/*position=*/0, context), - getAffineConstantExpr(0, context)}); + getAffineConstantExpr(0, context)}, + context); clampedImIdx.push_back( affine_max(dim.getType(), maxMap, ValueRange{dim})); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -294,7 +294,8 @@ /*dimCount=*/3, /*symbolCount=*/0, {getAffineDimExpr(/*position=*/0, b.getContext()), getAffineDimExpr(/*position=*/1, b.getContext()) - - getAffineDimExpr(/*position=*/2, b.getContext())}); + getAffineDimExpr(/*position=*/2, b.getContext())}, + b.getContext()); auto d = folded_std_dim(folder, view, r); size = folded_affine_min(folder, b.getIndexType(), minMap, ValueRange{size, d, offset}); diff --git a/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopTiling.cpp b/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopTiling.cpp --- a/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopTiling.cpp +++ 
b/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopTiling.cpp @@ -66,7 +66,8 @@ /*dimCount=*/3, /*symbolCount=*/0, {getAffineDimExpr(/*position=*/0, b.getContext()), getAffineDimExpr(/*position=*/1, b.getContext()) - - getAffineDimExpr(/*position=*/2, b.getContext())}); + getAffineDimExpr(/*position=*/2, b.getContext())}, + b.getContext()); // Create the inner loop with adjusted bounds. SmallVector newBounds; diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -1348,9 +1348,7 @@ auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx); results.push_back(targetExpr); } - // The (...) -> () affine map has its own factory method. - return results.empty() ? AffineMap::get(map.getNumDims() - 1, 0, ctx) - : AffineMap::get(map.getNumDims() - 1, 0, results); + return AffineMap::get(map.getNumDims() - 1, 0, results, ctx); } // Helper to drop dimension from vector type. 
diff --git a/mlir/lib/Dialect/Vector/VectorUtils.cpp b/mlir/lib/Dialect/Vector/VectorUtils.cpp --- a/mlir/lib/Dialect/Vector/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/VectorUtils.cpp @@ -176,7 +176,7 @@ "Vectorization prerequisite violated: at most 1 index may be " "invariant wrt a vectorized loop"); } - return AffineMap::get(indices.size(), 0, perm); + return AffineMap::get(indices.size(), 0, perm, context); } /// Implementation detail that walks up the parents and records the ones with diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp --- a/mlir/lib/IR/AffineMap.cpp +++ b/mlir/lib/IR/AffineMap.cpp @@ -104,7 +104,7 @@ for (auto index : permutation) affExprs.push_back(getAffineDimExpr(index, context)); auto m = std::max_element(permutation.begin(), permutation.end()); - auto permutationMap = AffineMap::get(*m + 1, 0, affExprs); + auto permutationMap = AffineMap::get(*m + 1, 0, affExprs, context); assert(permutationMap.isPermutation() && "Invalid permutation vector"); return permutationMap; } @@ -127,13 +127,16 @@ template static SmallVector inferFromExprList(ArrayRef exprsList) { + assert(!exprsList.empty()); + assert(!exprsList[0].empty()); + auto context = exprsList[0][0].getContext(); int64_t maxDim = -1, maxSym = -1; getMaxDimAndSymbol(exprsList, maxDim, maxSym); SmallVector maps; maps.reserve(exprsList.size()); for (const auto &exprs : exprsList) maps.push_back(AffineMap::get(/*dimCount=*/maxDim + 1, - /*symbolCount=*/maxSym + 1, exprs)); + /*symbolCount=*/maxSym + 1, exprs, context)); return maps; } @@ -153,7 +156,7 @@ dimExprs.reserve(numDims); for (unsigned i = 0; i < numDims; ++i) dimExprs.push_back(mlir::getAffineDimExpr(i, context)); - return get(/*dimCount=*/numDims, /*symbolCount=*/0, dimExprs); + return get(/*dimCount=*/numDims, /*symbolCount=*/0, dimExprs, context); } MLIRContext *AffineMap::getContext() const { return map->context; } @@ -255,8 +258,7 @@ results.push_back( expr.replaceDimsAndSymbols(dimReplacements, 
symReplacements)); - return results.empty() ? get(numResultDims, 0, getContext()) - : get(numResultDims, numResultSyms, results); + return get(numResultDims, numResultSyms, results, getContext()); } AffineMap AffineMap::compose(AffineMap map) { @@ -280,8 +282,7 @@ exprs.reserve(getResults().size()); for (auto expr : getResults()) exprs.push_back(expr.compose(newMap)); - return exprs.empty() ? AffineMap::get(numDims, 0, map.getContext()) - : AffineMap::get(numDims, numSymbols, exprs); + return AffineMap::get(numDims, numSymbols, exprs, map.getContext()); } bool AffineMap::isProjectedPermutation() { @@ -312,7 +313,7 @@ for (auto idx : resultPos) { exprs.push_back(getResult(idx)); } - return AffineMap::get(getNumDims(), getNumSymbols(), exprs); + return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext()); } AffineMap mlir::simplifyAffineMap(AffineMap map) { @@ -321,7 +322,8 @@ exprs.push_back( simplifyAffineExpr(e, map.getNumDims(), map.getNumSymbols())); } - return AffineMap::get(map.getNumDims(), map.getNumSymbols(), exprs); + return AffineMap::get(map.getNumDims(), map.getNumSymbols(), exprs, + map.getContext()); } AffineMap mlir::removeDuplicateExprs(AffineMap map) { @@ -354,7 +356,7 @@ seenExprs.push_back(expr); if (seenExprs.size() != map.getNumInputs()) return AffineMap(); - return AffineMap::get(map.getNumResults(), 0, seenExprs); + return AffineMap::get(map.getNumResults(), 0, seenExprs, map.getContext()); } AffineMap mlir::concatAffineMaps(ArrayRef maps) { @@ -369,9 +371,8 @@ results.append(m.getResults().begin(), m.getResults().end()); numDims = std::max(m.getNumDims(), numDims); } - return results.empty() ? 
AffineMap::get(numDims, /*numSymbols=*/0, - maps.front().getContext()) - : AffineMap::get(numDims, /*numSymbols=*/0, results); + return AffineMap::get(numDims, /*numSymbols=*/0, results, + maps.front().getContext()); } //===----------------------------------------------------------------------===// @@ -380,8 +381,7 @@ MutableAffineMap::MutableAffineMap(AffineMap map) : numDims(map.getNumDims()), numSymbols(map.getNumSymbols()), - // A map always has at least 1 result by construction - context(map.getResult(0).getContext()) { + context(map.getContext()) { for (auto result : map.getResults()) results.push_back(result); } @@ -390,8 +390,7 @@ results.clear(); numDims = map.getNumDims(); numSymbols = map.getNumSymbols(); - // A map always has at least 1 result by construction - context = map.getResult(0).getContext(); + context = map.getContext(); for (auto result : map.getResults()) results.push_back(result); } @@ -416,5 +415,5 @@ } AffineMap MutableAffineMap::getAffineMap() const { - return AffineMap::get(numDims, numSymbols, results); + return AffineMap::get(numDims, numSymbols, results, context); } diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -293,12 +293,11 @@ AffineMap Builder::getConstantAffineMap(int64_t val) { return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0, - {getAffineConstantExpr(val)}); + getAffineConstantExpr(val)); } AffineMap Builder::getDimIdentityMap() { - return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, - {getAffineDimExpr(0)}); + return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, getAffineDimExpr(0)); } AffineMap Builder::getMultiDimIdentityMap(unsigned rank) { @@ -306,18 +305,19 @@ dimExprs.reserve(rank); for (unsigned i = 0; i < rank; ++i) dimExprs.push_back(getAffineDimExpr(i)); - return AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, dimExprs); + return AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, dimExprs, + context); } 
AffineMap Builder::getSymbolIdentityMap() { return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1, - {getAffineSymbolExpr(0)}); + getAffineSymbolExpr(0)); } AffineMap Builder::getSingleDimShiftAffineMap(int64_t shift) { // expr = d0 + shift. auto expr = getAffineDimExpr(0) + shift; - return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, {expr}); + return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr); } AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) { @@ -325,7 +325,8 @@ shiftedResults.reserve(map.getNumResults()); for (auto resultExpr : map.getResults()) shiftedResults.push_back(resultExpr + shift); - return AffineMap::get(map.getNumDims(), map.getNumSymbols(), shiftedResults); + return AffineMap::get(map.getNumDims(), map.getNumSymbols(), shiftedResults, + context); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -717,10 +717,8 @@ } AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount, - ArrayRef results) { - // The number of results can't be zero. 
- assert(!results.empty()); - return getImpl(dimCount, symbolCount, results, results[0].getContext()); + AffineExpr result) { + return getImpl(dimCount, symbolCount, {result}, result.getContext()); } AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount, diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp --- a/mlir/lib/IR/StandardTypes.cpp +++ b/mlir/lib/IR/StandardTypes.cpp @@ -723,7 +723,7 @@ simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); if (expr != simplifiedLayoutExpr) return MemRefType::Builder(t).setAffineMaps({AffineMap::get( - m.getNumDims(), m.getNumSymbols(), {simplifiedLayoutExpr})}); + m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)}); return MemRefType::Builder(t).setAffineMaps({}); } diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -3132,12 +3132,8 @@ /*allowEmptyList=*/true)) return failure(); // Parsed a valid affine map. - if (exprs.empty()) - map = AffineMap::get(numDimOperands, dimsAndSymbols.size() - numDimOperands, - getContext()); - else - map = AffineMap::get(numDimOperands, dimsAndSymbols.size() - numDimOperands, - exprs); + map = AffineMap::get(numDimOperands, dimsAndSymbols.size() - numDimOperands, + exprs, getContext()); return success(); } @@ -3166,11 +3162,8 @@ if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true)) return AffineMap(); - if (exprs.empty()) - return AffineMap::get(numDims, numSymbols, getContext()); - // Parsed a valid affine map. - return AffineMap::get(numDims, numSymbols, exprs); + return AffineMap::get(numDims, numSymbols, exprs, getContext()); } /// Parse an affine constraint. diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -942,7 +942,8 @@ } auto indexRemap = zeroOffsetCount == rank ? 
AffineMap() - : AffineMap::get(outerIVs.size() + rank, 0, remapExprs); + : AffineMap::get(outerIVs.size() + rank, 0, remapExprs, + forOp.getContext()); // Replace all users of 'oldMemRef' with 'newMemRef'. LogicalResult res = replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap, diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -98,8 +98,8 @@ // Create 'iv mod 2' value to index the leading dimension. auto d0 = bInner.getAffineDimExpr(0); int64_t step = forOp.getStep(); - auto modTwoMap = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, - {d0.floorDiv(step) % 2}); + auto modTwoMap = + AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, d0.floorDiv(step) % 2); auto ivModTwoOp = bInner.create(forOp.getLoc(), modTwoMap, forOp.getInductionVar()); diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -103,7 +103,8 @@ operands.clear(); operands.push_back(lb); operands.append(bumpValues.begin(), bumpValues.end()); - map = AffineMap::get(1 + tripCountMap.getNumResults(), 0, newUbExprs); + map = AffineMap::get(1 + tripCountMap.getNumResults(), 0, newUbExprs, + b.getContext()); // Simplify the map + operands. fullyComposeAffineMapAndOperands(&map, &operands); map = simplifyAffineMap(map); @@ -485,7 +486,7 @@ if (!forOpIV.use_empty()) { // iv' = iv + 1/2/3...unrollFactor-1; auto d0 = builder.getAffineDimExpr(0); - auto bumpMap = AffineMap::get(1, 0, {d0 + i * step}); + auto bumpMap = AffineMap::get(1, 0, d0 + i * step); auto ivUnroll = builder.create(forOp.getLoc(), bumpMap, forOpIV); operandMap.map(forOpIV, ivUnroll); @@ -616,7 +617,7 @@ if (!forOpIV.use_empty()) { // iv' = iv + i, i = 1 to unrollJamFactor-1. 
auto d0 = builder.getAffineDimExpr(0); - auto bumpMap = AffineMap::get(1, 0, {d0 + i * step}); + auto bumpMap = AffineMap::get(1, 0, d0 + i * step); auto ivUnroll = builder.create(forOp.getLoc(), bumpMap, forOpIV); operandMap.map(forOpIV, ivUnroll); @@ -859,7 +860,8 @@ auto bounds = llvm::to_vector<4>(map->getResults()); bounds.push_back(b.getAffineDimExpr(map->getNumDims()) + offset); operands->insert(operands->begin() + map->getNumDims(), iv); - *map = AffineMap::get(map->getNumDims() + 1, map->getNumSymbols(), bounds); + *map = AffineMap::get(map->getNumDims() + 1, map->getNumSymbols(), bounds, + b.getContext()); canonicalizeMapAndOperands(map, operands); } @@ -1514,7 +1516,7 @@ b = forOp.getBodyBuilder(); auto fastBufOffsetMap = - AffineMap::get(lbOperands.size(), 0, {fastBufOffsets[d]}); + AffineMap::get(lbOperands.size(), 0, fastBufOffsets[d]); auto offset = b.create(loc, fastBufOffsetMap, lbOperands); // Construct the subscript for the fast memref being copied into/from: @@ -1529,7 +1531,8 @@ memIndices.push_back(forOp.getInductionVar()); } - auto fastBufMap = AffineMap::get(2 * rank, /*symbolCount=*/0, fastBufExprs); + auto fastBufMap = + AffineMap::get(2 * rank, /*symbolCount=*/0, fastBufExprs, b.getContext()); fullyComposeAffineMapAndOperands(&fastBufMap, &fastBufMapOperands); fastBufMap = simplifyAffineMap(fastBufMap); canonicalizeMapAndOperands(&fastBufMap, &fastBufMapOperands); @@ -1837,7 +1840,8 @@ auto dimExpr = b.getAffineDimExpr(regionSymbols.size() + i); remapExprs.push_back(dimExpr - fastBufOffsets[i]); } - auto indexRemap = AffineMap::get(regionSymbols.size() + rank, 0, remapExprs); + auto indexRemap = AffineMap::get(regionSymbols.size() + rank, 0, remapExprs, + b.getContext()); // Record the begin since it may be invalidated by memref replacement. 
Block::iterator prevOfBegin; diff --git a/mlir/test/Dialect/Affine/simplify-affine-structures.mlir b/mlir/test/Dialect/Affine/simplify-affine-structures.mlir --- a/mlir/test/Dialect/Affine/simplify-affine-structures.mlir +++ b/mlir/test/Dialect/Affine/simplify-affine-structures.mlir @@ -271,3 +271,13 @@ // CHECK-NEXT: addi return } + +// ----- + +// CHECK-DAG: #[[MAP_0D:.*]] = affine_map<() -> ()> + +// CHECK-LABEL: func @simplify_zero_dim_map +func @simplify_zero_dim_map(%in : memref) -> f32 { + %out = affine.load %in[] : memref + return %out : f32 +}