diff --git a/mlir/include/mlir/Interfaces/TilingInterface.h b/mlir/include/mlir/Interfaces/TilingInterface.h
--- a/mlir/include/mlir/Interfaces/TilingInterface.h
+++ b/mlir/include/mlir/Interfaces/TilingInterface.h
@@ -21,6 +21,36 @@
 #include "mlir/Interfaces/ViewLikeInterface.h"
 #include "mlir/Support/LLVM.h"
 
+namespace mlir {
+
+/// Result of tiling an operation with `TilingInterface`.
+struct TilingResult {
+  /// Operations created to implement the tile.
+  SmallVector<Operation *> generatedOperations;
+  /// Values produced for the tile that correspond to the results of the
+  /// original operation.
+  SmallVector<Value> tiledResults;
+};
+
+/// Describes a subset (currently only a tile) of an iteration space or of a
+/// shaped value.
+struct SubsetSpecification {
+  /// Per-dimension offsets, sizes and strides of the tile.
+  struct Tile {
+    SmallVector<OpFoldResult> offsets;
+    SmallVector<OpFoldResult> sizes;
+    SmallVector<OpFoldResult> strides;
+  };
+  Tile tile;
+};
+
+/// Callback used by `getTiledImplementation` to extract the subset described
+/// by a `SubsetSpecification` from a value (e.g. with `tensor.extract_slice`).
+using MaterializeSubsetFn = function_ref<Value(
+    OpBuilder &, Location, Value, const SubsetSpecification &)>;
+
+} // namespace mlir
+
 /// Include the ODS generated interface header files.
 #include "mlir/Interfaces/TilingInterface.h.inc"
 
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -63,16 +63,17 @@
-        The method returns the operation that is the tiled
-        implementation.
+        `subsetSpec` describes the requested tile of the iteration space and
+        `materializeFn` is used to extract the corresponding operand subsets.
+        The method returns a `TilingResult` holding the operations generated
+        for the tiled implementation and the values corresponding to the
+        results of the original operation.
       }],
-      /*retType=*/"SmallVector<Operation *>",
+      /*retType=*/"TilingResult",
       /*methodName=*/"getTiledImplementation",
       /*args=*/(ins
         "OpBuilder &":$b,
-        "ArrayRef<OpFoldResult> ":$offsets,
-        "ArrayRef<OpFoldResult> ":$sizes),
+        "SubsetSpecification ":$subsetSpec,
+        "MaterializeSubsetFn ":$materializeFn),
       /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        return {};
-      }]
+      /*defaultImplementation=*/"return {};"
     >,
     InterfaceMethod<
       /*desc=*/[{
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -110,26 +110,44 @@
   }
 
   // Instantiate the tiled implementation of the operation.
-  SmallVector<Operation *>
-  getTiledImplementation(Operation *op, OpBuilder &b,
-                         ArrayRef<OpFoldResult> offsets,
-                         ArrayRef<OpFoldResult> sizes) const {
+  TilingResult getTiledImplementation(Operation *op, OpBuilder &b,
+                                      SubsetSpecification subsetSpec,
+                                      MaterializeSubsetFn materializeFn) const {
     // Leave the `sizeBounds` value empty. That is only needed when the `sizes`
     // specified could lead to out of bounds accesses.
     Location loc = op->getLoc();
     LinalgOp linalgOp = cast<LinalgOp>(op);
     SmallVector<Value> valuesToTile = linalgOp->getOperands();
-    SmallVector<Value, 4> tiledOperands = makeTiledShapes(
-        b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true);
+
+    // Compute the slice of every operand accessed by the requested tile. Only
+    // the tile offsets and sizes are read: `subsetSpec.tile.strides` is not
+    // consulted, so non-unit strides are not honored here.
+    SmallVector<Optional<SliceParameters>> allSliceParameter =
+        computeAllSliceParameters(b, loc, linalgOp, valuesToTile,
+                                  subsetSpec.tile.offsets,
+                                  subsetSpec.tile.sizes, {}, true);
+
+    SmallVector<Value> tiledOperands;
+    for (auto item : llvm::zip(valuesToTile, allSliceParameter)) {
+      Value valueToTile = std::get<0>(item);
+      const Optional<SliceParameters> &sliceParams = std::get<1>(item);
+      // Operands without slice parameters are used as-is.
+      tiledOperands.push_back(
+          sliceParams.has_value()
+              ? materializeFn(b, loc, valueToTile,
+                              {{sliceParams->offsets, sliceParams->sizes,
+                                sliceParams->strides}})
+              : valueToTile);
+    }
 
     SmallVector<Type> resultTensorTypes =
         getTensorOutputTypes(linalgOp, tiledOperands);
 
     Operation *tiledOp =
         linalgOp.clone(b, loc, resultTensorTypes, tiledOperands);
-    offsetIndices(b, cast<LinalgOp>(tiledOp), offsets);
+    offsetIndices(b, cast<LinalgOp>(tiledOp), subsetSpec.tile.offsets);
 
-    return {tiledOp};
+    return TilingResult{{tiledOp}, tiledOp->getResults()};
   }
 
   // Return the details of the output tile generated by the tiled
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -272,6 +272,17 @@
   }
 }
 
+/// Materializes the tile described by `subsetSpec` out of `valueToTile` as a
+/// `tensor.extract_slice`. Passed to `getTiledImplementation` as the
+/// `MaterializeSubsetFn` callback; `valueToTile` is expected to be a tensor.
+static Value materializeSlice(OpBuilder &builder, Location loc,
+                              Value valueToTile,
+                              const SubsetSpecification &subsetSpec) {
+  return builder.create<tensor::ExtractSliceOp>(
+      loc, valueToTile, subsetSpec.tile.offsets, subsetSpec.tile.sizes,
+      subsetSpec.tile.strides);
+}
+
 /// Implementation of tiling transformation of `op` that implements the
 /// `TilingInterface` using `scf.for` to iterate over the tiles.
 FailureOr<scf::SCFTilingResult>
@@ -358,9 +369,14 @@
   if (!tilingResult.loops.empty())
     rewriter.setInsertionPoint(
         tilingResult.loops.back().getBody()->getTerminator());
-  SmallVector<Operation *> tiledImplementation =
-      op.getTiledImplementation(rewriter, offsets, sizes);
-  tilingResult.tiledOps.append(tiledImplementation);
+  SubsetSpecification subsetSpec;
+  subsetSpec.tile = {
+      offsets, sizes,
+      SmallVector<OpFoldResult>(offsets.size(), rewriter.getIndexAttr(1))};
+
+  TilingResult tiledImplementationResult =
+      op.getTiledImplementation(rewriter, subsetSpec, materializeSlice);
+  tilingResult.tiledOps.append(tiledImplementationResult.generatedOperations);
   if (op->getNumResults() == 0) {
     // nothing more to do.
     return tilingResult;
@@ -369,9 +385,7 @@
-  // If loops are empty, the tiled op is used as the replacement for the untiled
-  // op.
+  // If loops are empty, the values produced by the tiled implementation are
+  // used as the replacements for the results of the untiled op.
   if (tilingResult.loops.empty()) {
-    tilingResult.replacements = llvm::to_vector(
-        llvm::map_range(tiledImplementation[0]->getResults(),
-                        [](OpResult result) -> Value { return result; }));
+    tilingResult.replacements = tiledImplementationResult.tiledResults;
     return tilingResult;
   }
 