diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -58,42 +58,29 @@ : SmallVector(ivs.begin(), ivs.end()); } -// Creates a number of ranges equal to the number of results in `map`. -// The returned ranges correspond to the loop ranges, in the proper order, for -// which new loops will be created. -static SmallVector -emitLoopRanges(OpBuilder &b, Location loc, AffineMap map, - ArrayRef allViewSizes) { - // Apply `map` to get view sizes in loop order. - auto sizes = applyMapToValues(b, loc, map, allViewSizes); - // Create a new range with the applied tile sizes. - ScopedContext scope(b, loc); - SmallVector res; - for (unsigned idx = 0, e = map.getNumResults(); idx < e; ++idx) { - res.push_back(SubViewOp::Range{std_constant_index(0), sizes[idx], - std_constant_index(1)}); - } - return res; -} - /// Creates a number of ranges equal to the number of dimensions in the `map`. +/// The returned ranges, one per map dimension and indexed by dimension +/// position, are the loop ranges for which new loops will be created. /// The function supports for now only limited number of expressions inside /// map results. It expects a non-inverted, concatenated map and last values in -/// allViewSizes will be applied to the symbols in the map. -static SmallVector -emitLoopRangesWithSymbols(OpBuilder &b, Location loc, AffineMap map, - ValueRange allViewSizes) { - assert(allViewSizes.size() == map.getNumInputs() && - "Number of provided values must match number of inputs to the map."); +/// viewSizes will be applied to the symbols in the map if it contains some. 
+static SmallVector emitLoopRanges(OpBuilder &b, + Location loc, + AffineMap map, + ValueRange viewSizes) { + assert(viewSizes.size() == + std::max(map.getNumInputs(), map.getNumResults()) && + "Number of provided values must match number of inputs to the map " + "or number of results in case map has more results than inputs."); SmallVector res(map.getNumDims()); - for (auto result : map.getResults()) { + for (unsigned idx = 0, e = map.getNumResults(); idx < e; ++idx) { + auto result = map.getResult(idx); if (auto d = result.dyn_cast()) { if (res[d.getPosition()].offset) continue; - res[d.getPosition()] = - SubViewOp::Range{std_constant_index(0), allViewSizes[d.getPosition()], - std_constant_index(1)}; + res[d.getPosition()] = SubViewOp::Range{ + std_constant_index(0), viewSizes[idx], std_constant_index(1)}; } if (auto binOp = result.dyn_cast()) { @@ -103,14 +90,17 @@ continue; auto m = lhs.getLHS().dyn_cast(); - if (!m) + auto n = lhs.getRHS().dyn_cast(); + auto fDiv = rhs.getLHS().dyn_cast(); + if (!m || !n || !fDiv || fDiv.getKind() != AffineExprKind::FloorDiv) continue; + // Pattern `m + n - s floordiv 2`: m's range starts at the value of the floordiv term (fDiv). NOTE(review): its end reads viewSizes[mPos] (a dim position) while the AffineDimExpr case reads viewSizes[idx] (a result index); confirm both address the same sizes for the supported maps. int mPos = m.getPosition(); AffineMap fromMap = - AffineMap::get(map.getNumDims(), map.getNumSymbols(), rhs.getLHS()); - auto from = applyMapToValues(b, loc, fromMap, allViewSizes).front(); - auto to = b.create(loc, allViewSizes[mPos], from); + AffineMap::get(map.getNumDims(), map.getNumSymbols(), fDiv); + auto from = applyMapToValues(b, loc, fromMap, viewSizes).front(); + auto to = b.create(loc, viewSizes[mPos], from); res[mPos] = SubViewOp::Range{from, to, std_constant_index(1)}; } } @@ -510,35 +500,25 @@ llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); })); auto map = concatAffineMaps(maps); SmallVector loopRanges; + SmallVector sizes = getViewSizes(builder, linalgOp); auto attr = linalgOp.template getAttrOfType("symbol_source"); if (attr) { - // This map has symbols and thus is not a permutation. Therefore we - // cannot invert it. 
- unsigned symbolSource = attr.getInt(); - auto sizes = getViewSizes(builder, linalgOp); - unsigned numIn = map.getNumInputs(), numDims = map.getNumDims(); - unsigned diff = numIn - numDims; - + // Offset in `sizes` of the first dimension of the "symbol_source" operand, i.e. the sum of the ranks of all preceding operands. NOTE(review): `idx < attr.getInt()` presumably compares unsigned with a signed integer; confirm. + unsigned numDims = map.getNumDims(), symbolsPos = 0; + for (unsigned idx = 0; idx < attr.getInt(); idx++) + symbolsPos += linalgOp.getOperand(idx) + .getType() + .template cast() + .getRank(); // Append or rewrite the end of the value list that corresponds to the - // symbols. They are in this case dims of the "symbol_source" operand. - sizes.resize(numIn); - for (unsigned idx = 0; idx < diff; idx++) - sizes[numDims + idx] = sizes[diff * symbolSource + idx]; - loopRanges = emitLoopRangesWithSymbols(scope.getBuilderRef(), - scope.getLocation(), map, sizes); - } else { - AffineMap invertedMap = inversePermutation(map); - if (!invertedMap) - return {}; - if (invertedMap.isEmpty()) { - emitScalarImplementation({}, linalgOp); - return LinalgLoops(); - } - - loopRanges = emitLoopRanges(scope.getBuilderRef(), scope.getLocation(), - invertedMap, getViewSizes(builder, linalgOp)); + // values mapping to symbols, copying them from the sizes of the "symbol_source" operand. + sizes.resize(map.getNumInputs()); + for (unsigned idx = 0; idx < map.getNumSymbols(); idx++) + sizes[numDims + idx] = sizes[symbolsPos + idx]; } + loopRanges = + emitLoopRanges(scope.getBuilderRef(), scope.getLocation(), map, sizes); SmallVector allIvs; GenerateLoopNest::doit( loopRanges, linalgOp.iterator_types().getValue(), [&](ValueRange ivs) {