[mlir][Linalg] Merge emitLoopRanges and emitLoopRangesWithSymbols

The inverted-map variant of emitLoopRanges is removed, and
emitLoopRangesWithSymbols takes over the name emitLoopRanges. Its parameter
`allViewSizes` is renamed to `viewSizes`. Its size assertion now accepts either
map.getNumInputs() or map.getNumResults() values, whichever is larger.

The caller now passes the concatenated, non-inverted indexing map straight to
emitLoopRanges. It no longer branches on map.getNumSymbols() and no longer calls
inversePermutation.

Review notes:
* The caller reuses the already-computed `sizes` rather than calling
  getViewSizes(builder, linalgOp) a second time. Without this, `sizes` is an
  unused local, and any IR that getViewSizes builds would be emitted twice.
* NOTE(review): the caller no longer returns `{}` for a non-invertible map.
  It also no longer calls emitScalarImplementation directly when the inverted
  map is empty. Please confirm two things:
  - GenerateLoopNest::doit runs the body exactly once for an empty
    `loopRanges` (zero-dimensional ops).
  - Maps that emitLoopRanges cannot fully cover (which leave null
    SubViewOp::Range entries) cannot reach this path.

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -58,33 +58,21 @@
                      : SmallVector(ivs.begin(), ivs.end());
 }
 
-// Creates a number of ranges equal to the number of results in `map`.
-// The returned ranges correspond to the loop ranges, in the proper order, for
-// which new loops will be created.
-static SmallVector
-emitLoopRanges(OpBuilder &b, Location loc, AffineMap map,
-               ArrayRef allViewSizes) {
-  // Apply `map` to get view sizes in loop order.
-  auto sizes = applyMapToValues(b, loc, map, allViewSizes);
-  // Create a new range with the applied tile sizes.
-  ScopedContext scope(b, loc);
-  SmallVector res;
-  for (unsigned idx = 0, e = map.getNumResults(); idx < e; ++idx) {
-    res.push_back(SubViewOp::Range{std_constant_index(0), sizes[idx],
-                                   std_constant_index(1)});
-  }
-  return res;
-}
-
 /// Creates a number of ranges equal to the number of dimensions in the `map`.
-/// The function supports for now only limited number of expressions inside
-/// map results. It expects a non-inverted, concatenated map and last values in
-/// allViewSizes will be applied to the symbols in the map.
-static SmallVector
-emitLoopRangesWithSymbols(OpBuilder &b, Location loc, AffineMap map,
-                          ValueRange allViewSizes) {
-  assert(allViewSizes.size() == map.getNumInputs() &&
-         "Number of provided values must match number of inputs to the map.");
+/// The returned ranges correspond to the loop ranges, in the proper order, for
+/// which new loops will be created.
+/// The function supports only maps that are invertible and have results of type
+/// DimExpr or (DimExpr + DimExpr - SymbolExpr floordiv ConstExpr).
+/// It expects a non-inverted, concatenated map and last values in
+/// allViewSizes will be applied to the symbols in the map if it contains any.
+static SmallVector emitLoopRanges(OpBuilder &b,
+                                  Location loc,
+                                  AffineMap map,
+                                  ValueRange viewSizes) {
+  assert(viewSizes.size() ==
+             std::max(map.getNumInputs(), map.getNumResults()) &&
+         "Number of provided values must match number of inputs to the map "
+         "or number of results in case map has more results than inputs.");
 
   SmallVector res(map.getNumDims());
   for (unsigned idx = 0, e = map.getNumResults(); idx < e; ++idx) {
@@ -93,7 +81,7 @@
       if (res[d.getPosition()].offset)
         continue;
       res[d.getPosition()] = SubViewOp::Range{
-          std_constant_index(0), allViewSizes[idx], std_constant_index(1)};
+          std_constant_index(0), viewSizes[idx], std_constant_index(1)};
     }
 
     // m + n - s floordiv 2
@@ -114,8 +102,8 @@
       int mPos = m.getPosition();
       AffineMap fromMap = AffineMap::get(map.getNumDims(), map.getNumSymbols(),
                                          fDiv);
-      Value from = applyMapToValues(b, loc, fromMap, allViewSizes).front();
-      auto to = b.create(loc, allViewSizes[mPos], from);
+      Value from = applyMapToValues(b, loc, fromMap, viewSizes).front();
+      auto to = b.create(loc, viewSizes[mPos], from);
       res[mPos] = SubViewOp::Range{from, to, std_constant_index(1)};
     }
   }
@@ -512,23 +500,8 @@
   SmallVector maps = getIndexingMaps(linalgOp);
   SmallVector sizes = getViewSizes(builder, linalgOp);
   AffineMap map = concatAffineMaps(maps);
-  SmallVector loopRanges;
-
-  if (map.getNumSymbols()) {
-    loopRanges = emitLoopRangesWithSymbols(scope.getBuilderRef(),
-                                           scope.getLocation(), map, sizes);
-  } else {
-    AffineMap invertedMap = inversePermutation(map);
-    if (!invertedMap)
-      return {};
-    if (invertedMap.isEmpty()) {
-      emitScalarImplementation({}, linalgOp);
-      return LinalgLoops();
-    }
-
-    loopRanges = emitLoopRanges(scope.getBuilderRef(), scope.getLocation(),
-                                invertedMap, sizes);
-  }
+  auto loopRanges = emitLoopRanges(scope.getBuilderRef(), scope.getLocation(),
+                                   map, sizes);
   SmallVector allIvs;
   GenerateLoopNest::doit(
       loopRanges, linalgOp.iterator_types().getValue(), [&](ValueRange ivs) {