diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -96,7 +96,7 @@ struct TiledLinalgOp { LinalgOp op; - SmallVector loops; + SmallVector loops; }; /// Performs standalone tiling of a single LinalgOp by `tileSizes`. diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -325,10 +325,10 @@ return res; } +template Optional -mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op, ArrayRef tileSizes, - ArrayRef permutation, - OperationFolder *folder) { +tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ArrayRef tileSizes, + ArrayRef permutation, OperationFolder *folder) { assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); // 1. Enforce the convention that "tiling by zero" skips tiling a particular // dimension. This convention is significantly simpler to handle instead of @@ -369,7 +369,13 @@ LinalgOp res = op; SmallVector ivs(loopRanges.size()); auto pivs = makeHandlePointers(MutableArrayRef(ivs)); - LoopNestRangeBuilder(pivs, loopRanges)([&] { + // Convert SubViewOp::Range to linalg_range. + SmallVector linalgRanges; + for (auto &range : loopRanges) { + linalgRanges.push_back( + linalg_range(range.offset, range.size, range.stride)); + } + GenericLoopNestRangeBuilder(pivs, linalgRanges)([&] { auto b = ScopedContext::getBuilder(); auto loc = ScopedContext::getLocation(); SmallVector ivValues(ivs.begin(), ivs.end()); @@ -393,7 +399,7 @@ transformIndexedGenericOpIndices(b, res, pivs, loopIndexToRangeIndex); // 5. Gather the newly created loops and return them with the new op. - SmallVector loops; + SmallVector loops; loops.reserve(ivs.size()); for (auto iv : ivs) loops.push_back(loop::getForInductionVarOwner(iv)); @@ -401,9 +407,10 @@ return TiledLinalgOp{res, loops}; } -Optional mlir::linalg::tileLinalgOp( - OpBuilder &b, LinalgOp op, ArrayRef tileSizes, - ArrayRef permutation, OperationFolder *folder) { +template +Optional +tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ArrayRef tileSizes, + ArrayRef permutation, OperationFolder *folder) { assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); if (tileSizes.empty()) return llvm::None; @@ -434,9 +441,23 @@ tileSizeValues.push_back(constant_index(folder, 0)); } - return tileLinalgOp(b, op, tileSizeValues, permutation, folder); + return tileLinalgOpImpl(b, op, tileSizeValues, permutation, folder); } +Optional +mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op, ArrayRef tileSizes, + ArrayRef permutation, + OperationFolder *folder) { + return tileLinalgOpImpl(b, op, tileSizes, permutation, folder); +} + +Optional mlir::linalg::tileLinalgOp( + OpBuilder &b, LinalgOp op, ArrayRef tileSizes, + ArrayRef permutation, OperationFolder *folder) { + return tileLinalgOpImpl(b, op, tileSizes, permutation, folder); +} + +template static void tileLinalgOps(FuncOp f, ArrayRef tileSizes) { OpBuilder b(f); OperationFolder folder(f.getContext()); @@ -444,7 +465,7 @@ if (!op.hasBufferSemantics()) return; auto opLoopsPair = - tileLinalgOp(b, op, tileSizes, /*permutation=*/{}, &folder); + tileLinalgOpImpl(b, op, tileSizes, /*permutation=*/{}, &folder); // If tiling occurred successfully, erase old op. if (opLoopsPair) op.erase(); @@ -458,28 +479,31 @@ } namespace { -struct LinalgTilingPass : public FunctionPass { + +template +struct LinalgTilingPass : public FunctionPass> { LinalgTilingPass() = default; - LinalgTilingPass(ArrayRef sizes); + LinalgTilingPass(ArrayRef sizes) { + this->tileSizes.assign(sizes.begin(), sizes.end()); + } - void runOnFunction() override { tileLinalgOps(getFunction(), tileSizes); } + void runOnFunction() override { + tileLinalgOps(this->getFunction(), tileSizes); + } SmallVector tileSizes; }; -} // namespace -LinalgTilingPass::LinalgTilingPass(ArrayRef sizes) { - this->tileSizes.assign(sizes.begin(), sizes.end()); -} +} // namespace std::unique_ptr> mlir::createLinalgTilingPass(ArrayRef tileSizes) { - return std::make_unique(tileSizes); + return std::make_unique>(tileSizes); } -static PassRegistration +static PassRegistration> pass("linalg-tile", "Tile operations in the linalg dialect", [] { - auto pass = std::make_unique(); + auto pass = std::make_unique>(); pass->tileSizes.assign(clTileSizes.begin(), clTileSizes.end()); return pass; });