diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -105,6 +105,35 @@
   return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
 }
 
+/// Returns the bounded tile size given the current `iv`, `loopRange` and
+/// `tileSize`, i.e., `min(tileSize, loopRange.size - iv)`. Here
+/// `loopRange.size` is the upper bound of the tiled loop, not a trip count.
+static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
+                                       Range loopRange, Value iv,
+                                       Value tileSize) {
+  // A unit tile can never run past the upper bound.
+  Optional<int64_t> ts = getConstantIntValue(tileSize);
+  if (ts && ts.value() == 1)
+    return getAsOpFoldResult(tileSize);
+
+  // If the tile size evenly divides the iteration space, every tile is full
+  // and no clamping is needed.
+  if (tileDividesIterationDomain(
+          Range{loopRange.offset, loopRange.size, tileSize}))
+    return tileSize;
+
+  // The tile size to use (to avoid out of bounds access) is minimum of
+  // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled
+  // loop.
+  AffineExpr s0, s1, d0;
+  bindDims(b.getContext(), d0);
+  bindSymbols(b.getContext(), s0, s1);
+  AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, b.getContext());
+  Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size);
+  return b.create<AffineMinOp>(loc, minMap, ValueRange{iv, tileSize, size})
+      .getResult();
+}
+
 /// Generate an empty loop nest that represents the tiled loop nest shell.
 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
 /// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops.
@@ -124,41 +153,26 @@
   offsets.resize(loopRanges.size());
   sizes.resize(loopRanges.size());
 
-  // The tile size to use (to avoid out of bounds access) is minimum of
-  // `tileSize` and `ub - iv`, where `iv` is the induction variable
-  // of the tiled loop.
-  AffineExpr s0, s1, d0;
-  bindDims(builder.getContext(), d0);
-  bindSymbols(builder.getContext(), s0, s1);
-  AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, builder.getContext());
-
   for (auto loopRange : llvm::enumerate(loopRanges)) {
     Value offset =
         getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().offset);
     Value size =
         getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().size);
+    Value tileSize = tileSizeVals[loopRange.index()];
     // No loops if tile size is zero. Set offset and size to the loop
     // offset and size.
-    if (matchPattern(tileSizeVals[loopRange.index()], m_Zero())) {
+    if (matchPattern(tileSize, m_Zero())) {
       offsets[loopRange.index()] = offset;
       sizes[loopRange.index()] = size;
       continue;
     }
 
     auto loop = builder.create<scf::ForOp>(
-        loc, offset, size, tileSizeVals[loopRange.index()], ValueRange{},
+        loc, offset, size, tileSize, ValueRange{},
         [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
             ValueRange /*iterArgs*/) {
-          bool canAvoidMap = tileDividesIterationDomain(
-              Range{loopRange.value().offset, loopRange.value().size,
-                    tileSizeVals[loopRange.index()]});
-          Value boundedTileSize =
-              (canAvoidMap)
-                  ? tileSizeVals[loopRange.index()]
-                  : builder.create<AffineMinOp>(
-                        bodyLoc, minMap,
-                        ValueRange{iv, tileSizeVals[loopRange.index()], size});
-          sizes[loopRange.index()] = boundedTileSize;
-          builder.create<scf::YieldOp>(loc);
+          sizes[loopRange.index()] = getBoundedTileSize(
+              bodyBuilder, bodyLoc, loopRange.value(), iv, tileSize);
+          bodyBuilder.create<scf::YieldOp>(bodyLoc);
         });
     offsets[loopRange.index()] = loop.getInductionVar();