diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -1143,6 +1143,19 @@ LogicalResult reifyResultShapes(OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes); + //========================================================================// + // Implementation of `TilingInterface` for `LinalgOp`s + //========================================================================// + SmallVector getIterationDomain(OpBuilder &b); + SmallVector getTiledImplementation(OpBuilder &b, + ValueRange dest, ArrayRef offsets, + ArrayRef sizes, + bool tileDestOperands); + LogicalResult getResultTilePosition(OpBuilder &b, unsigned resultNumber, + ArrayRef offsets, ArrayRef sizes, + SmallVector &resultOffsets, + SmallVector &resultSizes); + // TODO: Remove once prefixing is flipped. ArrayAttr getIteratorTypes() { return iterator_types(); } diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -19,6 +19,7 @@ include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/TilingInterface.td" // Base Tablegen class for Linalg ops. // Linalg ops that correspond to library calls operate on ShapedType as their @@ -30,13 +31,15 @@ DeclareOpInterfaceMethods, LinalgStructuredInterface, RegionBranchOpInterface, - ReifyRankedShapedTypeOpInterface], props)> { + ReifyRankedShapedTypeOpInterface, + TilingInterface], props)> { code structuredOpsBaseDecls = [{ // Return whether the op accesses the iteration indices. 
bool hasIndexSemantics() { return !this->getBody()->getOps().empty(); } + // Implementation of methods from `ReifyRankedShapedTypeOpInterface`. LogicalResult reifyResultShapes(OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { return llvm::cast(getOperation()).reifyResultShapes(b, @@ -49,6 +52,34 @@ // Op has a region, but conceptually the control flow does not enter the // region. } + + // Implementation of methods from `TilingInterface` for all + // structured ops by redirecting to the implementation in `LinalgInterface`. + SmallVector getDestinationOperands(OpBuilder &b) { + return llvm::cast(getOperation()).getOutputOperands(); + } + SmallVector getLoopIteratorTypes() { + return llvm::to_vector(llvm::map_range(iterator_types(), + [](Attribute strAttr) { + return strAttr.cast().getValue(); + })); + } + SmallVector getIterationDomain(OpBuilder &b) { + return llvm::cast(getOperation()).getIterationDomain(b); + } + SmallVector getTiledImplementation(OpBuilder &b, + ValueRange dest, ArrayRef offsets, + ArrayRef sizes, bool tileDestOperands) { + return llvm::cast(getOperation()).getTiledImplementation(b, + dest, offsets, sizes, tileDestOperands); + } + LogicalResult getResultTilePosition(OpBuilder &b, + unsigned resultNumber, ArrayRef offsets, + ArrayRef sizes, SmallVector &resultOffsets, + SmallVector &resultSizes) { + return llvm::cast(getOperation()).getResultTilePosition(b, + resultNumber, offsets, sizes, resultOffsets, resultSizes); + } }]; } diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/Linalg/Transforms/TileUsingInterface.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/TileUsingInterface.h @@ -0,0 +1,66 @@ +//===- TileUsingInterface.h - Tiling ops using TilingInterface --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LINALG_TRANSFORMS_TILEUSINGINTERFACE_H +#define MLIR_DIALECT_LINALG_TRANSFORMS_TILEUSINGINTERFACE_H + +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" + +namespace mlir { +class Operation; +class TilingInterface; + +namespace scf { +class ForOp; +} + +} // namespace mlir + +namespace mlir { +struct SCFTilingResult { + Operation *tiledOp; + SmallVector loops; +}; + +/// Pattern to tile an op that implements the `TilingInterface` using +/// `scf.for` for iterating over the tiles. +struct TileUsingSCFForOp : public OpInterfaceRewritePattern { + /// Construct a generic pattern applied to all TilingInterface ops that verify + /// `filter`. + TileUsingSCFForOp(MLIRContext *context, linalg::LinalgTilingOptions options, + linalg::LinalgTransformationFilter filter = + linalg::LinalgTransformationFilter(), + PatternBenefit benefit = 1); + + /// Construct a generic pattern applied to `opName`. + TileUsingSCFForOp(StringRef opName, MLIRContext *context, + linalg::LinalgTilingOptions options, + linalg::LinalgTransformationFilter filter = + linalg::LinalgTransformationFilter(), + PatternBenefit benefit = 1); + + /// `matchAndRewrite` implementation that returns the significant transformed + /// pieces of IR. + FailureOr + returningMatchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const; + + LogicalResult matchAndRewrite(TilingInterface op, + PatternRewriter &rewriter) const override { + return returningMatchAndRewrite(op, rewriter); + } + +private: + /// LinalgTransformationFilter handles special attribute manipulations. 
+ linalg::LinalgTransformationFilter filter; + /// Options to control tiling. + linalg::LinalgTilingOptions options; +}; + +} // namespace mlir + +#endif // MLIR_DIALECT_LINALG_TRANSFORMS_TILEUSINGINTERFACE_H \ No newline at end of file diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -164,11 +164,11 @@ SmallVector computeTileOffsets(OpBuilder &b, Location loc, ValueRange ivs, ValueRange tileSizes); -/// Compute tile sizes, given a list of loop `ivs`, `tileSizes` and dimension +/// Compute tile sizes, given a list of `tileSizes` and dimension /// sizes (`sizeBounds`). In case a tile size is zero (i.e., no tiling), the /// corresponding result size is the corresponding value from `sizeBounds`. /// Note: The returned tile sizes are closed intervals. -SmallVector computeTileSizes(OpBuilder &b, Location loc, ValueRange ivs, +SmallVector computeTileSizes(OpBuilder &b, Location loc, ValueRange tileSizes, ArrayRef sizeBounds); diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h --- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h @@ -56,6 +56,25 @@ ValueRange newIterOperands, const NewYieldValueFn &newYieldValuesFn); +/// Update a perfectly nested loop nest to yield new values from the innermost +/// loop and propagating it up through the loop nest. This function +/// - Expects `loopNest` to be a perfectly nested loop with outermost loop +/// first and innermost loop last. +/// - `newIterOperands` are the initialization values to be used for the +/// outermost loop +/// - `newYieldValueFn` is the callback that generates the new values to be +/// yielded from within the innermost loop. 
+/// - The original loops are not erased, but are left in a "no-op" state where +/// the body of the loop just yields the basic block arguments that correspond +/// to the initialization values of a loop. The original loops are dead after +/// this method. +/// - All uses of the `newIterOperands` within the generated new loop +/// are replaced with the corresponding `BlockArgument` in the loop body. +SmallVector +replaceLoopNestWithNewYields(OpBuilder &builder, ArrayRef loopNest, + ValueRange newIterOperands, + NewYieldValueFn newYieldValueFn); + /// Outline a region with a single block into a new FuncOp. /// Assumes the FuncOp result types is the type of the yielded operands of the /// single block. This constraint makes it easy to determine the result. diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td --- a/mlir/include/mlir/Interfaces/TilingInterface.td +++ b/mlir/include/mlir/Interfaces/TilingInterface.td @@ -98,6 +98,28 @@ /*defaultImplementation=*/[{ return {}; }] + >, + InterfaceMethod< + /*desc=*/[{ + Method to return the position of the result tile computed by the tiled operation. + + Specifies what tile of the result of the original tensor is computed + by the tiled implementation. Expects the same `offsets` and `sizes` as + used to obtain the tiled implementation of the operation. 
+ }], + /*retType=*/"LogicalResult", + /*methodName=*/"getResultTilePosition", + /*args=*/(ins + "OpBuilder &":$b, + "unsigned":$resultNumber, + "ArrayRef ":$offsets, + "ArrayRef ":$sizes, + "SmallVector &":$resultOffsets, + "SmallVector &":$resultSizes), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return failure(); + }] > ]; } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -10,7 +10,9 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Arithmetic/Utils/Utils.h" #include "mlir/Dialect/Complex/IR/Complex.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/AffineExprVisitor.h" @@ -474,17 +476,6 @@ return result; } -/// Helper function that creates a memref::DimOp or tensor::DimOp depending on -/// the type of `source`. 
-static Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, - int64_t dim) { - if (source.getType().isa()) - return b.createOrFold(loc, source, dim); - if (source.getType().isa()) - return b.createOrFold(loc, source, dim); - llvm_unreachable("Expected MemRefType or TensorType"); -} - SmallVector LinalgOp::createFlatListOfOperandDims(OpBuilder &b, Location loc) { SmallVector res; @@ -608,6 +599,84 @@ return success(); } +//===---------------------------------------------------------------------===// +// Implementation of methods required by `TilingInterface` +//===---------------------------------------------------------------------===// + +SmallVector LinalgOp::getIterationDomain(OpBuilder &b) { + Location loc = getLoc(); + auto allShapesSizes = createFlatListOfOperandDims(b, loc); + AffineMap map = getShapesToLoopsMap(); + Value zero = b.create(loc, 0); + Value one = b.create(loc, 1); + return llvm::to_vector(llvm::map_range( + applyMapToValues(b, loc, map, allShapesSizes), [&](Value v) { + return Range{zero, v, one}; + })); +} + +SmallVector LinalgOp::getTiledImplementation( + OpBuilder &b, ValueRange dest, ArrayRef offsets, + ArrayRef sizes, bool tileDestOperands) { + // Leave the `sizeBounds` value empty. That is only needed when the `sizes` + // specified could lead to out of bounds accesses. 
+ Location loc = getLoc(); + SmallVector valuesToTile = getInputAndOutputOperands(); + SmallVector tiledOperands = + makeTiledShapes(b, loc, *this, valuesToTile, + getValueOrCreateConstantIndexOp(b, loc, offsets), + getValueOrCreateConstantIndexOp(b, loc, sizes), {}, true); + + SmallVector resultTensorTypes = llvm::to_vector( + llvm::map_range(getOutputTensorOperands(), [&](OpOperand *opOperand) { + return tiledOperands[opOperand->getOperandNumber()].getType(); + })); + + Operation *tiledOp = clone(b, loc, resultTensorTypes, tiledOperands); + + return {tiledOp}; +} + +LogicalResult LinalgOp::getResultTilePosition( + OpBuilder &b, unsigned resultNumber, ArrayRef offsets, + ArrayRef sizes, SmallVector &resultOffset, + SmallVector &resultSizes) { + Location loc = getLoc(); + AffineExpr d0; + bindDims(b.getContext(), d0); + + auto fullyComposeAffineMapAndOperands = [](OpBuilder &builder, Location loc, + AffineExpr expr, + ValueRange operands) -> Value { + AffineMap map = AffineMap::inferFromExprList({expr}).front(); + SmallVector normalizedOperands(operands.begin(), operands.end()); + mlir::fullyComposeAffineMapAndOperands(&map, &normalizedOperands); + canonicalizeMapAndOperands(&map, &normalizedOperands); + return builder.createOrFold(loc, map, normalizedOperands); + }; + + SmallVector sizeVals = getValueOrCreateConstantIndexOp(b, loc, sizes); + SmallVector subShapeSizes = + llvm::to_vector(llvm::map_range(sizeVals, [&](Value v) { + return fullyComposeAffineMapAndOperands(b, loc, d0 - 1, v); + })); + OpOperand *outOperand = getOutputOperand(resultNumber); + Value sliceOpResult = makeTiledShape( + b, loc, outOperand->get(), sizeVals, getTiedIndexingMap(outOperand), + getValueOrCreateConstantIndexOp(b, loc, offsets), + /*ubs*/ {}, subShapeSizes, true); + auto sliceOp = sliceOpResult.getDefiningOp(); + if (!sliceOp) + return failure(); + resultOffset = sliceOp.getMixedOffsets(); + resultSizes = sliceOp.getMixedSizes(); + return success(); +} + 
+//===---------------------------------------------------------------------===// +// Structured op verification. +//===---------------------------------------------------------------------===// + LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) { LinalgOp linalgOp = cast(op); // Expect at least one output operand. diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -23,6 +23,7 @@ Promotion.cpp SparseTensorRewriting.cpp SplitReduction.cpp + TileUsingInterface.cpp Tiling.cpp Transforms.cpp Vectorization.cpp diff --git a/mlir/lib/Dialect/Linalg/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/TileUsingInterface.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/TileUsingInterface.cpp @@ -0,0 +1,241 @@ +//===- Tiling.cpp - Implementation of tiling using TilingInterface +//---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the tiling using TilingInterface. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/Transforms/TileUsingInterface.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Interfaces/TilingInterface.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "tile-using-interface" + +using namespace mlir; + +/// Generate an empty loop nest that represents the tiled loop nest shell. 
+/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. +/// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops. +/// - In `offsets` and `sizes` return the multi-dimensional offset and size of +/// the +/// tile processed within the inner most loop. +static SmallVector +generateTileLoopNest(OpBuilder &builder, Location loc, + ArrayRef loopRanges, ArrayRef tileSizeVals, + SmallVector &offsets, + SmallVector &sizes) { + assert(!loopRanges.empty() && "expected at least one loop range"); + assert(loopRanges.size() == tileSizeVals.size() && + "expected as many tile sizes as loop ranges"); + OpBuilder::InsertionGuard guard(builder); + SmallVector loops; + offsets.resize(loopRanges.size()); + sizes.resize(loopRanges.size()); + + // The tile size to use (to avoid out of bounds access) is minimum of + // `tileSize` and `ub - iv`, where `iv` is the induction variable + // of the tiled loop. + AffineExpr s0, s1, d0; + bindDims(builder.getContext(), d0); + bindSymbols(builder.getContext(), s0, s1); + AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, builder.getContext()); + + for (auto loopRange : llvm::enumerate(loopRanges)) { + // No loops if tile size is zero. Set offset and size to the loop + // offset and size. 
+ if (matchPattern(tileSizeVals[loopRange.index()], m_Zero())) { + offsets[loopRange.index()] = loopRange.value().offset; + sizes[loopRange.index()] = loopRange.value().size; + continue; + } + + auto loop = builder.create( + loc, loopRange.value().offset, loopRange.value().size, + tileSizeVals[loopRange.index()], ValueRange{}, + [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv, + ValueRange /*iterArgs*/) { + Value boundedTileSize = builder.create( + bodyLoc, minMap, + ValueRange{iv, tileSizeVals[loopRange.index()], + loopRange.value().size}); + sizes[loopRange.index()] = boundedTileSize; + builder.create(loc); + }); + offsets[loopRange.index()] = loop.getInductionVar(); + loops.push_back(loop); + builder.setInsertionPoint(loop.getBody()->getTerminator()); + } + return loops; +} + +/// Linalg tiling pattern. +TileUsingSCFForOp::TileUsingSCFForOp(MLIRContext *context, + linalg::LinalgTilingOptions options, + linalg::LinalgTransformationFilter f, + PatternBenefit benefit) + : OpInterfaceRewritePattern(context, benefit), + filter(std::move(f)), options(std::move(options)) {} + +TileUsingSCFForOp::TileUsingSCFForOp(StringRef opName, MLIRContext *context, + linalg::LinalgTilingOptions options, + linalg::LinalgTransformationFilter f, + PatternBenefit benefit) + : OpInterfaceRewritePattern(context, benefit), + filter(f.addOpNameFilter(opName)), options(std::move(options)) {} + +FailureOr +TileUsingSCFForOp::returningMatchAndRewrite(TilingInterface op, + PatternRewriter &rewriter) const { + if (failed(filter.checkAndNotify(rewriter, op))) + return failure(); + + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointAfter(op); + + if (!options.tileSizeComputationFunction) + return failure(); + + // 1. Get the range of the loops that are represented by the operation. + SmallVector iterationDomain = op.getIterationDomain(rewriter); + size_t numLoops = iterationDomain.size(); + if (numLoops == 0) + return failure(); + + // 2. Materialize the tile sizes. 
Enforce the convention that "tiling by zero" + // skips tiling a particular dimension. This convention is significantly + // simpler to handle instead of adjusting affine maps to account for missing + // dimensions. + SmallVector tileSizeVector = + options.tileSizeComputationFunction(rewriter, op); + if (tileSizeVector.size() < iterationDomain.size()) { + auto zero = rewriter.create(op.getLoc(), 0); + tileSizeVector.append(numLoops - tileSizeVector.size(), zero); + } + + // 3. Materialize an empty loop nest that iterates over the tiles. These loops + // for now do not return any values even if the original operation has + // results. + SmallVector offsets, sizes; + SmallVector loops = generateTileLoopNest( + rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes); + + if (!loops.empty()) { + LLVM_DEBUG({ + llvm::errs() << "LoopNest shell :\n"; + loops.front().dump(); + llvm::errs() << "\n"; + }); + } + + // 4. Generate the tiled implementation within the inner most loop. + if (!loops.empty()) + rewriter.setInsertionPoint(loops.back().getBody()->getTerminator()); + SmallVector tiledImplementation = op.getTiledImplementation( + rewriter, op.getDestinationOperands(rewriter), offsets, sizes, true); + if (tiledImplementation.size() != 1) { + return rewriter.notifyMatchFailure( + op, "expected tiled implementation to return a single op"); + } + + if (!loops.empty()) { + LLVM_DEBUG({ + llvm::errs() << "After tiled implementation :\n"; + loops.front().dump(); + llvm::errs() << "\n"; + }); + } + + // 5. If the original operations has results, modify the loop nest to yield + // the replacement values. + if (op->getNumResults() != 0) { + SmallVector replacements; + if (loops.empty()) { + // 5a. If there were no loops, the tiled implementation results are the + // replacements. + replacements.assign(tiledImplementation[0]->result_begin(), + tiledImplementation[0]->result_end()); + } else { + // 5b. 
`scf.for` with tensor semantics requires the loop nest to yield the + // replacement values using destructive updates. Use the `TilingInterface` + // to get the position of the result tiles and use that to generate the + // destructive update pattern, i.e., + // + // ```mlir + // scf.for %iv0 = ... { + // %0 = tiled_op + // } + // ``` + // + // is transformed to + // + // ```mlir + // %result = scf.for %iv0 = ... iter_args(%arg = %init) -> .. { + // %0 = tiled_op + // %1 = tensor.insert_slice %0 into %arg[..] [..] [..] + // scf.yield %1 + // } + // ``` + NewYieldValueFn yieldValueFn = + [&](OpBuilder &b, Location loc, + ArrayRef newBBArgs) -> SmallVector { + SmallVector yieldedValues; + Attribute one = b.getIndexAttr(1); + for (auto resultNum : llvm::seq(0, op->getNumResults())) { + SmallVector resultTileOffsets, resultTileSizes; + if (failed(op.getResultTilePosition(b, resultNum, offsets, sizes, + resultTileOffsets, + resultTileSizes))) { + op.emitOpError("unable to get position of result ") + << resultNum << " of the tiled implementation"; + return {}; + } + SmallVector resultTileStrides(resultTileOffsets.size(), + one); + Value yieldedValue = b.create( + op->getLoc(), tiledImplementation[0]->getResult(resultNum), + newBBArgs[resultNum], resultTileOffsets, resultTileSizes, + resultTileStrides); + yieldedValues.push_back(yieldedValue); + } + return yieldedValues; + }; + SmallVector newLoops = replaceLoopNestWithNewYields( + rewriter, loops, op.getDestinationOperands(rewriter), yieldValueFn); + for (auto loop : llvm::enumerate(loops)) { + rewriter.eraseOp(loop.value()); + loops[loop.index()] = cast(newLoops[loop.index()]); + } + scf::ForOp outerMost = loops.front(); + replacements.assign(outerMost->result_begin(), outerMost->result_end()); + + if (!loops.empty()) { + LLVM_DEBUG({ + llvm::errs() << "After yielding values :\n"; + loops.front().dump(); + llvm::errs() << "\n"; + }); + } + } + + rewriter.replaceOp(op, replacements); + } else { + rewriter.eraseOp(op); + 
} + + SCFTilingResult tilingResult; + tilingResult.tiledOp = tiledImplementation[0]; + tilingResult.loops = std::move(loops); + filter.replaceLinalgTransformationFilter(rewriter, tilingResult.tiledOp); + return tilingResult; +} diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -320,8 +320,7 @@ // Compute offsets and sizes of ExtractSliceOp. SmallVector offsets = computeTileOffsets(b, loc, localIvs, tileSizes); - SmallVector sizes = - computeTileSizes(b, loc, localIvs, tileSizes, allDims); + SmallVector sizes = computeTileSizes(b, loc, tileSizes, allDims); // Create ExtractSliceOp: Extract a tile from the tensor::PadOp. // Note: The tensor::PadOp is located outside of the loop nest. It is // later moved inside by ExtractSliceOfPadTensorSwapPattern. diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -893,7 +893,7 @@ return offsets; } -SmallVector computeTileSizes(OpBuilder &b, Location loc, ValueRange ivs, +SmallVector computeTileSizes(OpBuilder &b, Location loc, ValueRange tileSizes, ArrayRef sizeBounds) { SmallVector sizes; @@ -923,7 +923,7 @@ // that define tile subshapes. 
SmallVector lbs = computeTileOffsets(b, loc, ivs, tileSizes); SmallVector subShapeSizes = - computeTileSizes(b, loc, ivs, tileSizes, sizeBounds); + computeTileSizes(b, loc, tileSizes, sizeBounds); assert(static_cast(valuesToTile.size()) == linalgOp.getNumInputsAndOutputs() && diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -23,6 +23,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" using namespace mlir; @@ -101,6 +102,31 @@ return newLoop; } +SmallVector mlir::replaceLoopNestWithNewYields( + OpBuilder &builder, ArrayRef loopNest, + ValueRange newIterOperands, NewYieldValueFn newYieldValueFn) { + if (loopNest.empty()) + return {}; + SmallVector newLoopNest(loopNest.size()); + + newLoopNest.back() = replaceLoopWithNewYields( + builder, loopNest.back(), newIterOperands, newYieldValueFn); + + for (unsigned loopDepth : + llvm::reverse(llvm::seq(0, loopNest.size() - 1))) { + NewYieldValueFn fn = [&](OpBuilder &innerBuilder, Location loc, + ArrayRef innerNewBBArgs) { + SmallVector newYields( + newLoopNest[loopDepth + 1]->getResults().take_back( + newIterOperands.size())); + return newYields; + }; + newLoopNest[loopDepth] = replaceLoopWithNewYields( + builder, loopNest[loopDepth], newIterOperands, fn); + } + return newLoopNest; +} + /// Outline a region with a single block into a new FuncOp. /// Assumes the FuncOp result types is the type of the yielded operands of the /// single block. This constraint makes it easy to determine the result. 
diff --git a/mlir/test/Dialect/Linalg/tile-using-interface.mlir b/mlir/test/Dialect/Linalg/tile-using-interface.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/tile-using-interface.mlir @@ -0,0 +1,42 @@ +// RUN: mlir-opt -test-tiling-interface -split-input-file %s | FileCheck %s + +func.func @simple_matmul(%arg0 : tensor, %arg1 : tensor, + %arg2 : tensor) -> tensor { + %0 = linalg.matmul {__internal_linalg_transform__ = "simple_gemm"} + ins(%arg0, %arg1 : tensor, tensor) + outs(%arg2 : tensor) -> tensor + return %0 : tensor +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)> +// CHECK: func.func @simple_matmul +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index +// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index +// CHECK-DAG: %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[K:.+]] = tensor.dim %[[ARG0]], %[[C1]] +// CHECK-DAG: %[[N:.+]] = tensor.dim %[[ARG1]], %[[C1]] +// CHECK: %[[OUTER:.+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C10]] +// CHECK-SAME: iter_args(%[[INIT0:.+]] = %[[ARG2]]) +// CHECK: %[[TS_Y:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[M]]] +// CHECK: %[[INNER:.+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C20]] +// CHECK-SAME: iter_args(%[[INIT1:.+]] = %[[INIT0]]) +// CHECK: %[[TS_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[N]]] +// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]] +// CHECK-SAME: [%[[IV0]], 0] [%[[TS_Y]], %[[K]]] [1, 1] +// CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]] +// CHECK-SAME: [0, %[[IV1]]] [%[[K]], %[[TS_X]]] [1, 1] +// CHECK-DAG: %[[INIT_TILE:.+]] = 
tensor.extract_slice %[[INIT1]] +// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TS_Y]], %[[TS_X]]] [1, 1] +// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] : +// CHECK-SAME: outs(%[[INIT_TILE]] : +// CHECK: %[[UPDATE:.+]] = tensor.insert_slice %[[GEMM_TILE]] into %[[INIT1]] +// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TS_Y]], %[[TS_X]]] [1, 1] +// CHECK: scf.yield %[[UPDATE]] +// CHECK: scf.yield %[[INNER]] +// CHECK: return %[[OUTER]] diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt --- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt @@ -6,6 +6,7 @@ TestLinalgHoisting.cpp TestLinalgTransforms.cpp TestPadFusion.cpp + TestTilingInterface.cpp EXCLUDE_FROM_LIBMLIR diff --git a/mlir/test/lib/Dialect/Linalg/TestTilingInterface.cpp b/mlir/test/lib/Dialect/Linalg/TestTilingInterface.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/Linalg/TestTilingInterface.cpp @@ -0,0 +1,67 @@ +//===- TestTilingInterface.cpp - Test tiling using `TilingInterface` -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass for testing tiling operations using +// `TilingInterface`. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/Transforms/TileUsingInterface.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; + +namespace { +struct TestTilingInterfacePass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTilingInterfacePass) + + TestTilingInterfacePass() = default; + TestTilingInterfacePass(const TestTilingInterfacePass &pass) + : PassWrapper(pass) {} + void getDependentDialects(DialectRegistry &registry) const override { + registry.insert(); + } + StringRef getArgument() const final { return "test-tiling-interface"; } + StringRef getDescription() const final { + return "Test tiling using TilingInterface"; + } + + void runOnOperation() override; +}; +} // namespace + +void TestTilingInterfacePass::runOnOperation() { + MLIRContext *context = &getContext(); + linalg::LinalgTilingOptions tilingOptions; + tilingOptions.setTileSizes({10, 20}); + linalg::LinalgTransformationFilter filter( + StringAttr::get(context, "simple_gemm"), + StringAttr::get(context, "tiled")); + + RewritePatternSet tilingPatterns(context); + tilingPatterns.add(context, tilingOptions, filter); + if (failed(applyPatternsAndFoldGreedily(getOperation(), + std::move(tilingPatterns)))) + return signalPassFailure(); +} + +namespace mlir { +namespace test { +void registerTestTilingInterface() { + PassRegistration(); +} +} // namespace test +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -111,6 +111,7 @@ void registerTestSCFUtilsPass(); void registerTestSliceAnalysisPass(); void 
registerTestTensorTransforms(); +void registerTestTilingInterface(); void registerTestTransformDialectInterpreterPass(); void registerTestVectorLowerings(); } // namespace test @@ -206,6 +207,7 @@ mlir::test::registerTestSCFUtilsPass(); mlir::test::registerTestSliceAnalysisPass(); mlir::test::registerTestTensorTransforms(); + mlir::test::registerTestTilingInterface(); mlir::test::registerTestTransformDialectInterpreterPass(); mlir::test::registerTestVectorLowerings(); }