diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h @@ -0,0 +1,20 @@ +//===- TilingInterfaceImpl.h - Implementation of TilingInterface ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LINALG_TILINGINTERFACEIMPL_H +#define MLIR_DIALECT_LINALG_TILINGINTERFACEIMPL_H + +namespace mlir { +class DialectRegistry; + +namespace linalg { +void registerTilingInterfaceExternalModels(DialectRegistry &registry); +} // namespace linalg +} // namespace mlir + +#endif // MLIR_DIALECT_LINALG_TILINGINTERFACEIMPL_H diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -164,11 +164,11 @@ SmallVector computeTileOffsets(OpBuilder &b, Location loc, ValueRange ivs, ValueRange tileSizes); -/// Compute tile sizes, given a list of loop `ivs`, `tileSizes` and dimension +/// Compute tile sizes, given a list of `tileSizes` and dimension /// sizes (`sizeBounds`). In case a tile size is zero (i.e., no tiling), the /// corresponding result size is the corresponding value from `sizeBounds`. /// Note: The returned tile sizes are closed intervals. 
-SmallVector computeTileSizes(OpBuilder &b, Location loc, ValueRange ivs, +SmallVector computeTileSizes(OpBuilder &b, Location loc, ValueRange tileSizes, ArrayRef sizeBounds); diff --git a/mlir/include/mlir/Dialect/SCF/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/TileUsingInterface.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/SCF/TileUsingInterface.h @@ -0,0 +1,87 @@ +//===- TileUsingInterface.h - Tiling ops using TilingInterface --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SCF_TILEUSINGINTERFACE_H +#define MLIR_DIALECT_SCF_TILEUSINGINTERFACE_H + +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/TilingInterface.h" + +namespace mlir { +class Operation; +class PatternRewriter; +class TilingInterface; +} // namespace mlir + +namespace mlir { +namespace scf { + +using SCFTileSizeComputationFunction = + std::function(OpBuilder &, Operation *)>; + +/// Options to use to control tiling. +struct SCFTilingOptions { + /// Computation function that returns the tile sizes for each operation. + /// Delayed construction of constant tile sizes should occur to interoperate + /// with folding. + SCFTileSizeComputationFunction tileSizeComputationFunction = nullptr; + + SCFTilingOptions & + setTileSizeComputationFunction(SCFTileSizeComputationFunction fun) { + tileSizeComputationFunction = std::move(fun); + return *this; + } + /// Set the `tileSizeComputationFunction` to return the values `ts`. The + /// values must not fold away when tiling. Otherwise, use a more robust + /// `tileSizeComputationFunction`. 
+ SCFTilingOptions &setTileSizes(const SmallVector &ts) { + tileSizeComputationFunction = [=](OpBuilder &, Operation *) { return ts; }; + return *this; + } + /// Convenience function to set the `tileSizeComputationFunction` to a + /// function that computes tile sizes at the point they are needed. Allows + /// proper interaction with folding. + SCFTilingOptions &setTileSizes(ArrayRef ts); +}; + +struct SCFTilingResult { + Operation *tiledOp; + SmallVector loops; +}; + +/// Pattern to tile an op that implements the `TilingInterface` using +/// `scf.for` for iterating over the tiles. +struct TileUsingSCFForOp : public OpInterfaceRewritePattern { + /// Construct a generic pattern applied to all TilingInterface ops. + TileUsingSCFForOp(MLIRContext *context, SCFTilingOptions options, + PatternBenefit benefit = 1); + + /// Construct a generic pattern applied to `opName`. + TileUsingSCFForOp(StringRef opName, MLIRContext *context, + SCFTilingOptions options, PatternBenefit benefit = 1); + + /// `matchAndRewrite` implementation that returns the significant transformed + /// pieces of IR. 
+ FailureOr + returningMatchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const; + + LogicalResult matchAndRewrite(TilingInterface op, + PatternRewriter &rewriter) const override { + return returningMatchAndRewrite(op, rewriter); + } + +private: + /// Options to control tiling. + SCFTilingOptions options; +}; + +} // namespace scf +} // namespace mlir + +#endif // MLIR_DIALECT_SCF_TILEUSINGINTERFACE_H diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h --- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_SCF_UTILS_UTILS_H_ #define MLIR_DIALECT_SCF_UTILS_UTILS_H_ +#include "mlir/Dialect/SCF/SCF.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" @@ -32,12 +33,6 @@ class FuncOp; } // namespace func -namespace scf { -class IfOp; -class ForOp; -class ParallelOp; -} // namespace scf - /// Replace the `loop` with `newIterOperands` added as new initialization /// values. `newYieldValuesFn` is a callback that can be used to specify /// the additional values to be yielded by the loop. The number of @@ -57,6 +52,25 @@ ValueRange newIterOperands, const NewYieldValueFn &newYieldValuesFn); +/// Update a perfectly nested loop nest to yield new values from the innermost +/// loop and propagating it up through the loop nest. This function +/// - Expects `loopNest` to be a perfectly nested loop with outer most loop +/// first and innermost loop last. +/// - `newIterOperands` are the initialization values to be used for the +/// outermost loop +/// - `newYieldValueFn` is the callback that generates the new values to be +/// yielded from within the innermost loop. +/// - The original loops are not erased, but are left in a "no-op" state where +/// the body of the loop just yields the basic block arguments that correspond +/// to the initialization values of a loop. 
The original loops are dead after +/// this method. +/// - All uses of the `newIterOperands` within the generated new loop +/// are replaced with the corresponding `BlockArgument` in the loop body. +SmallVector +replaceLoopNestWithNewYields(OpBuilder &builder, ArrayRef loopNest, + ValueRange newIterOperands, + NewYieldValueFn newYieldValueFn); + /// Outline a region with a single block into a new FuncOp. /// Assumes the FuncOp result types is the type of the yielded operands of the /// single block. This constraint makes it easy to determine the result. diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td --- a/mlir/include/mlir/Interfaces/TilingInterface.td +++ b/mlir/include/mlir/Interfaces/TilingInterface.td @@ -98,6 +98,28 @@ /*defaultImplementation=*/[{ return {}; }] + >, + InterfaceMethod< + /*desc=*/[{ + Method to return the position of the result tile computed by the tiled operation. + + Specifies what tile of the result of the original tensor is computed + by the tiled implementation. Expects the same `offsets` and `sizes` as + used to obtain the tiled implementation of the operation. 
+ }], + /*retType=*/"LogicalResult", + /*methodName=*/"getResultTilePosition", + /*args=*/(ins + "OpBuilder &":$b, + "unsigned":$resultNumber, + "ArrayRef ":$offsets, + "ArrayRef ":$sizes, + "SmallVector &":$resultOffsets, + "SmallVector &":$resultSizes), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return failure(); + }] > ]; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -24,6 +24,7 @@ SparseTensorRewriting.cpp SplitReduction.cpp Tiling.cpp + TilingInterfaceImpl.cpp Transforms.cpp Vectorization.cpp diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -320,8 +320,7 @@ // Compute offsets and sizes of ExtractSliceOp. SmallVector offsets = computeTileOffsets(b, loc, localIvs, tileSizes); - SmallVector sizes = - computeTileSizes(b, loc, localIvs, tileSizes, allDims); + SmallVector sizes = computeTileSizes(b, loc, tileSizes, allDims); // Create ExtractSliceOp: Extract a tile from the tensor::PadOp. // Note: The tensor::PadOp is located outside of the loop nest. It is // later moved inside by ExtractSliceOfPadTensorSwapPattern. diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -0,0 +1,158 @@ +//===- TilingInterfaceImpl.cpp - Implementation of TilingInterface -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Arithmetic/Utils/Utils.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Interfaces/TilingInterface.h" + +using namespace mlir; +using namespace mlir::linalg; + +namespace { + +/// External model implementation of TilingInterface for LinalgOps. An external +/// model implementation is used for now till the use of `TilingInterface` is +/// on-par with the current Linalg tiling + fusion patterns. Once it is, +/// it may be possible to move this into the op-definition (though there are +/// advantages to leaving it as an external model). +template +struct LinalgOpTilingInterface + : public TilingInterface::ExternalModel, + LinalgOpTy> { + + /// Return the destination operands. + SmallVector getDestinationOperands(Operation *op, OpBuilder &b) const { + return llvm::cast(op).getOutputOperands(); + } + + /// Return the loop iterator type. + SmallVector getLoopIteratorTypes(Operation *op) const { + LinalgOpTy concreteOp = cast(op); + return llvm::to_vector( + llvm::map_range(concreteOp.iterator_types(), [](Attribute strAttr) { + return strAttr.cast().getValue(); + })); + } + + /// Return the iteration domain range. 
+ SmallVector getIterationDomain(Operation *op, OpBuilder &b) const { + Location loc = op->getLoc(); + LinalgOp linalgOp = cast(op); + auto allShapesSizes = linalgOp.createFlatListOfOperandDims(b, loc); + AffineMap map = linalgOp.getShapesToLoopsMap(); + Value zero = b.create(loc, 0); + Value one = b.create(loc, 1); + return llvm::to_vector(llvm::map_range( + applyMapToValues(b, loc, map, allShapesSizes), [&](Value v) { + return Range{zero, v, one}; + })); + } + + // Instantiate the tiled implementation of the operation. + SmallVector + getTiledImplementation(Operation *op, OpBuilder &b, ValueRange dest, + ArrayRef offsets, + ArrayRef sizes, + bool tileDestOperands) const { + // Leave the `sizeBounds` value empty. That is only needed when the `sizes` + // specified could lead to out of bounds accesses. + Location loc = op->getLoc(); + LinalgOp linalgOp = cast(op); + SmallVector valuesToTile = linalgOp.getInputAndOutputOperands(); + SmallVector tiledOperands = makeTiledShapes( + b, loc, linalgOp, valuesToTile, + getValueOrCreateConstantIndexOp(b, loc, offsets), + getValueOrCreateConstantIndexOp(b, loc, sizes), {}, true); + + SmallVector resultTensorTypes = llvm::to_vector(llvm::map_range( + linalgOp.getOutputTensorOperands(), [&](OpOperand *opOperand) { + return tiledOperands[opOperand->getOperandNumber()].getType(); + })); + + Operation *tiledOp = + linalgOp.clone(b, loc, resultTensorTypes, tiledOperands); + + return {tiledOp}; + } + + // Return the details of the output tile generated by the tiled + // implementation. 
+ LogicalResult + getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber, + ArrayRef offsets, + ArrayRef sizes, + SmallVector &resultOffsets, + SmallVector &resultSizes) const { + Location loc = op->getLoc(); + LinalgOp linalgOp = cast(op); + + AffineExpr d0; + bindDims(b.getContext(), d0); + + auto fullyComposeAffineMapAndOperands = [](OpBuilder &builder, Location loc, + AffineExpr expr, + ValueRange operands) -> Value { + AffineMap map = AffineMap::inferFromExprList({expr}).front(); + SmallVector normalizedOperands(operands.begin(), operands.end()); + mlir::fullyComposeAffineMapAndOperands(&map, &normalizedOperands); + canonicalizeMapAndOperands(&map, &normalizedOperands); + return builder.createOrFold(loc, map, normalizedOperands); + }; + + SmallVector sizeVals = + getValueOrCreateConstantIndexOp(b, loc, sizes); + SmallVector subShapeSizes = + llvm::to_vector(llvm::map_range(sizeVals, [&](Value v) { + return fullyComposeAffineMapAndOperands(b, loc, d0 - 1, v); + })); + OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber); + Value sliceOpResult = + makeTiledShape(b, loc, outOperand->get(), sizeVals, + linalgOp.getTiedIndexingMap(outOperand), + getValueOrCreateConstantIndexOp(b, loc, offsets), + /*ubs*/ {}, subShapeSizes, true); + auto sliceOp = sliceOpResult.getDefiningOp(); + if (!sliceOp) + return failure(); + resultOffsets = sliceOp.getMixedOffsets(); + resultSizes = sliceOp.getMixedSizes(); + return success(); + } +}; + +} // namespace + +template +static void registerOne(MLIRContext *ctx) { + OpType::template attachInterface>(*ctx); +} + +/// Variadic helper function. +template +static void registerAll(MLIRContext *ctx) { + // FIXME: In c++17 this can be simplified by using 'fold expressions'. 
+ (void)std::initializer_list{0, (registerOne(ctx), 0)...}; +} + +#define GET_OP_LIST + +void mlir::linalg::registerTilingInterfaceExternalModels( + DialectRegistry &registry) { + registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) { + registerOne(ctx); + registerAll< +#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" + >(ctx); + }); +} diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -893,7 +893,7 @@ return offsets; } -SmallVector computeTileSizes(OpBuilder &b, Location loc, ValueRange ivs, +SmallVector computeTileSizes(OpBuilder &b, Location loc, ValueRange tileSizes, ArrayRef sizeBounds) { SmallVector sizes; @@ -923,7 +923,7 @@ // that define tile subshapes. SmallVector lbs = computeTileOffsets(b, loc, ivs, tileSizes); SmallVector subShapeSizes = - computeTileSizes(b, loc, ivs, tileSizes, sizeBounds); + computeTileSizes(b, loc, tileSizes, sizeBounds); assert(static_cast(valuesToTile.size()) == linalgOp.getNumInputsAndOutputs() && diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt @@ -10,6 +10,7 @@ ParallelLoopFusion.cpp ParallelLoopTiling.cpp StructuralTypeConversions.cpp + TileUsingInterface.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SCF diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -0,0 +1,249 @@ +//===- TileUsingInterface.cpp - Tiling using TilingInterface --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the tiling using TilingInterface. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SCF/TileUsingInterface.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/TilingInterface.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "tile-using-interface" + +using namespace mlir; + +scf::SCFTilingOptions & +scf::SCFTilingOptions::setTileSizes(ArrayRef ts) { + assert(!tileSizeComputationFunction && "tile sizes already set"); + SmallVector tileSizes(ts.begin(), ts.end()); + tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart( + &op->getParentOfType().getBody().front()); + return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) { + Value v = b.create(op->getLoc(), s); + return v; + })); + }; + return *this; +} + +/// Generate an empty loop nest that represents the tiled loop nest shell. +/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. +/// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops. +/// - In `offsets` and `sizes` return the multi-dimensional offset and size of +/// the +/// tile processed within the inner most loop. 
+static SmallVector +generateTileLoopNest(OpBuilder &builder, Location loc, + ArrayRef loopRanges, ArrayRef tileSizeVals, + SmallVector &offsets, + SmallVector &sizes) { + assert(!loopRanges.empty() && "expected at least one loop range"); + assert(loopRanges.size() == tileSizeVals.size() && + "expected as many tile sizes as loop ranges"); + OpBuilder::InsertionGuard guard(builder); + SmallVector loops; + offsets.resize(loopRanges.size()); + sizes.resize(loopRanges.size()); + + // The tile size to use (to avoid out of bounds access) is minimum of + // `tileSize` and `ub - iv`, where `iv` is the induction variable + // of the tiled loop. + AffineExpr s0, s1, d0; + bindDims(builder.getContext(), d0); + bindSymbols(builder.getContext(), s0, s1); + AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, builder.getContext()); + + for (auto loopRange : llvm::enumerate(loopRanges)) { + // No loops if tile size is zero. Set offset and size to the loop + // offset and size. + if (matchPattern(tileSizeVals[loopRange.index()], m_Zero())) { + offsets[loopRange.index()] = loopRange.value().offset; + sizes[loopRange.index()] = loopRange.value().size; + continue; + } + + auto loop = builder.create( + loc, loopRange.value().offset, loopRange.value().size, + tileSizeVals[loopRange.index()], ValueRange{}, + [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv, + ValueRange /*iterArgs*/) { + Value boundedTileSize = builder.create( + bodyLoc, minMap, + ValueRange{iv, tileSizeVals[loopRange.index()], + loopRange.value().size}); + sizes[loopRange.index()] = boundedTileSize; + builder.create(loc); + }); + offsets[loopRange.index()] = loop.getInductionVar(); + loops.push_back(loop); + builder.setInsertionPoint(loop.getBody()->getTerminator()); + } + return loops; +} + +scf::TileUsingSCFForOp::TileUsingSCFForOp(MLIRContext *context, + scf::SCFTilingOptions options, + PatternBenefit benefit) + : OpInterfaceRewritePattern(context, benefit), + options(std::move(options)) {} + 
+scf::TileUsingSCFForOp::TileUsingSCFForOp(StringRef opName, + MLIRContext *context, + scf::SCFTilingOptions options, + PatternBenefit benefit) + : OpInterfaceRewritePattern(context, benefit), + options(std::move(options)) {} + +FailureOr +scf::TileUsingSCFForOp::returningMatchAndRewrite( + TilingInterface op, PatternRewriter &rewriter) const { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointAfter(op); + + if (!options.tileSizeComputationFunction) { + return rewriter.notifyMatchFailure( + op, "missing tile size computation function"); + } + + // 1. Get the range of the loops that are represented by the operation. + SmallVector iterationDomain = op.getIterationDomain(rewriter); + size_t numLoops = iterationDomain.size(); + if (numLoops == 0) { + return rewriter.notifyMatchFailure( + op, "unable to tile op with no iteration domain"); + } + + // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero" + // skips tiling a particular dimension. This convention is significantly + // simpler to handle instead of adjusting affine maps to account for missing + // dimensions. + SmallVector tileSizeVector = + options.tileSizeComputationFunction(rewriter, op); + if (tileSizeVector.size() < iterationDomain.size()) { + auto zero = rewriter.create(op.getLoc(), 0); + tileSizeVector.append(numLoops - tileSizeVector.size(), zero); + } + + scf::SCFTilingResult tilingResult; + SmallVector offsets, sizes; + { + // 3. Materialize an empty loop nest that iterates over the tiles. These + // loops for now do not return any values even if the original operation has + // results. + tilingResult.loops = generateTileLoopNest( + rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes); + + LLVM_DEBUG({ + if (!tilingResult.loops.empty()) { + llvm::errs() << "LoopNest shell :\n"; + tilingResult.loops.front().dump(); + llvm::errs() << "\n"; + } + }); + + // 4. Generate the tiled implementation within the inner most loop. 
+ if (!tilingResult.loops.empty()) + rewriter.setInsertionPoint( + tilingResult.loops.back().getBody()->getTerminator()); + SmallVector tiledImplementation = op.getTiledImplementation( + rewriter, op.getDestinationOperands(rewriter), offsets, sizes, true); + if (tiledImplementation.size() != 1) { + return rewriter.notifyMatchFailure( + op, "expected tiled implementation to return a single op"); + } + tilingResult.tiledOp = tiledImplementation[0]; + + LLVM_DEBUG({ + if (!tilingResult.loops.empty()) { + llvm::errs() << "After tiled implementation :\n"; + tilingResult.loops.front().dump(); + llvm::errs() << "\n"; + } + }); + } + + if (op->getNumResults() == 0) { + rewriter.eraseOp(op); + return tilingResult; + } + + // 5. If the original operation has results, modify the loop nest to yield + // the replacement values. + SmallVector replacements; + if (tilingResult.loops.empty()) { + // 5a. If there were no loops, the tiled implementation results are the + // replacements. + rewriter.replaceOp(op, tilingResult.tiledOp->getResults()); + return tilingResult; + } + + // 5b. `scf.for` with tensor semantics requires the loop nest to yield the + // replacement values using destructive updates. Use the `TilingInterface` + // to get the position of the result tiles and use that to generate the + // destructive update pattern, i.e., + // + // ```mlir + // scf.for %iv0 = ... { + // %0 = tiled_op + // } + // ``` + // + // is transformed to + // + // ```mlir + // %result = scf.for %iv0 = ... iter_args(%arg = %init) -> .. { + // %0 = tiled_op + // %1 = tensor.insert_slice %0 into %arg[..] [..] [..] 
+ // scf.yield %1 + // } + // ``` + NewYieldValueFn yieldValueFn = + [&](OpBuilder &b, Location loc, + ArrayRef newBBArgs) -> SmallVector { + SmallVector yieldedValues; + Attribute one = b.getIndexAttr(1); + for (auto resultNum : llvm::seq(0, op->getNumResults())) { + SmallVector resultTileOffsets, resultTileSizes; + if (failed(op.getResultTilePosition(b, resultNum, offsets, sizes, + resultTileOffsets, + resultTileSizes))) { + op.emitOpError("unable to get position of result ") + << resultNum << " of the tiled implementation"; + return {}; + } + SmallVector resultTileStrides(resultTileOffsets.size(), + one); + Value yieldedValue = b.create( + op->getLoc(), tilingResult.tiledOp->getResult(resultNum), + newBBArgs[resultNum], resultTileOffsets, resultTileSizes, + resultTileStrides); + yieldedValues.push_back(yieldedValue); + } + return yieldedValues; + }; + SmallVector newLoops = replaceLoopNestWithNewYields( + rewriter, tilingResult.loops, op.getDestinationOperands(rewriter), + yieldValueFn); + for (auto loop : llvm::enumerate(tilingResult.loops)) { + rewriter.eraseOp(loop.value()); + tilingResult.loops[loop.index()] = newLoops[loop.index()]; + } + rewriter.replaceOp(op, tilingResult.loops.front().getResults()); + return tilingResult; +} diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -23,6 +23,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" using namespace mlir; @@ -101,6 +102,31 @@ return newLoop; } +SmallVector mlir::replaceLoopNestWithNewYields( + OpBuilder &builder, ArrayRef loopNest, + ValueRange newIterOperands, NewYieldValueFn newYieldValueFn) { + if (loopNest.empty()) + return {}; + SmallVector newLoopNest(loopNest.size()); + + newLoopNest.back() = replaceLoopWithNewYields( + builder, loopNest.back(), newIterOperands, newYieldValueFn); + + 
for (unsigned loopDepth : + llvm::reverse(llvm::seq(0, loopNest.size() - 1))) { + NewYieldValueFn fn = [&](OpBuilder &innerBuilder, Location loc, + ArrayRef innerNewBBArgs) { + SmallVector newYields( + newLoopNest[loopDepth + 1]->getResults().take_back( + newIterOperands.size())); + return newYields; + }; + newLoopNest[loopDepth] = replaceLoopWithNewYields( + builder, loopNest[loopDepth], newIterOperands, fn); + } + return newLoopNest; +} + /// Outline a region with a single block into a new FuncOp. /// Assumes the FuncOp result types is the type of the yielded operands of the /// single block. This constraint makes it easy to determine the result. diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir @@ -0,0 +1,194 @@ +// RUN: mlir-opt -test-tiling-interface -split-input-file %s | FileCheck %s + +func.func @simple_matmul(%arg0 : tensor, %arg1 : tensor, + %arg2 : tensor) -> tensor { + %0 = linalg.matmul {__internal_linalg_transform__ = "simple_gemm"} + ins(%arg0, %arg1 : tensor, tensor) + outs(%arg2 : tensor) -> tensor + return %0 : tensor +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)> +// CHECK: func.func @simple_matmul( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index +// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index +// CHECK-DAG: %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[K:.+]] = tensor.dim %[[ARG0]], %[[C1]] +// CHECK-DAG: %[[N:.+]] = tensor.dim %[[ARG1]], %[[C1]] +// CHECK: %[[OUTER:[a-zA-Z0-9]+]] = 
scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C10]] +// CHECK-SAME: iter_args(%[[INIT0:.+]] = %[[ARG2]]) +// CHECK: %[[TS_Y:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[M]]] +// CHECK: %[[INNER:[a-zA-Z0-9]+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C20]] +// CHECK-SAME: iter_args(%[[INIT1:.+]] = %[[INIT0]]) +// CHECK: %[[TS_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[N]]] +// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]] +// CHECK-SAME: [%[[IV0]], 0] [%[[TS_Y]], %[[K]]] [1, 1] +// CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]] +// CHECK-SAME: [0, %[[IV1]]] [%[[K]], %[[TS_X]]] [1, 1] +// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT1]] +// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TS_Y]], %[[TS_X]]] [1, 1] +// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] : +// CHECK-SAME: outs(%[[INIT_TILE]] : +// CHECK: %[[UPDATE:.+]] = tensor.insert_slice %[[GEMM_TILE]] into %[[INIT1]] +// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TS_Y]], %[[TS_X]]] [1, 1] +// CHECK: scf.yield %[[UPDATE]] +// CHECK: scf.yield %[[INNER]] +// CHECK: return %[[OUTER]] + +// ----- + +func.func @simple_matmul_memref(%arg0 : memref, %arg1 : memref, + %arg2 : memref) { + linalg.matmul {__internal_linalg_transform__ = "simple_gemm_memref"} + ins(%arg0, %arg1 : memref, memref) + outs(%arg2 : memref) + return +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (30, -d0 + s1)> +// CHECK: func.func @simple_matmul_memref( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: memref +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index +// CHECK-DAG: 
%[[C20:.+]] = arith.constant 20 : index +// CHECK-DAG: %[[C30:.+]] = arith.constant 30 : index +// CHECK-DAG: %[[M:.+]] = memref.dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[K:.+]] = memref.dim %[[ARG0]], %[[C1]] +// CHECK-DAG: %[[N:.+]] = memref.dim %[[ARG1]], %[[C1]] +// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C10]] +// CHECK: %[[TS_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[M]]] +// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C20]] +// CHECK: %[[TS_N:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[N]]] +// CHECK: scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[K]] step %[[C30]] +// CHECK: %[[TS_K:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[C30]], %[[K]]] +// CHECK-DAG: %[[LHS_TILE:.+]] = memref.subview %[[ARG0]] +// CHECK-SAME: [%[[IV0]], %[[IV2]]] [%[[TS_M]], %[[TS_K]]] [1, 1] +// CHECK-DAG: %[[RHS_TILE:.+]] = memref.subview %[[ARG1]] +// CHECK-SAME: [%[[IV2]], %[[IV1]]] [%[[TS_K]], %[[TS_N]]] [1, 1] +// CHECK-DAG: %[[OUT_TILE:.+]] = memref.subview %[[ARG2]] +// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TS_M]], %[[TS_N]]] [1, 1] +// CHECK: linalg.matmul +// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] : +// CHECK-SAME: outs(%[[OUT_TILE]] : + +// ----- + +#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d2, d0, d1)> +func.func @multi_result(%arg0 : tensor<128x200x300xf32>) -> (tensor<128x300x200xf32>, tensor<300x128x200xf32>) { + %init0 = linalg.init_tensor [128, 300, 200] : tensor<128x300x200xf32> + %init1 = linalg.init_tensor [300, 128, 200] : tensor<300x128x200xf32> + %0:2 = linalg.generic { + indexing_maps = [#map0, #map1, #map2], + iterator_types = ["parallel", "parallel", "parallel"]} + {__internal_linalg_transform__ = "parallel_generic_transpose"} + ins(%arg0 : tensor<128x200x300xf32>) + outs(%init0, %init1 : tensor<128x300x200xf32>, tensor<300x128x200xf32>) { + ^bb0(%b0 : f32, %b1 : f32, %b2 : f32): + linalg.yield %b0, %b0 : 
f32, f32 + } -> (tensor<128x300x200xf32>, tensor<300x128x200xf32>) + return %0#0, %0#1 : tensor<128x300x200xf32>, tensor<300x128x200xf32> +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)> +// CHECK: func.func @multi_result( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<128x200x300xf32>) +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index +// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index +// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index +// CHECK-DAG: %[[C300:.+]] = arith.constant 300 : index +// CHECK-DAG: %[[INIT0:.+]] = linalg.init_tensor [128, 300, 200] +// CHECK-DAG: %[[INIT1:.+]] = linalg.init_tensor [300, 128, 200] +// CHECK: %[[OUTER:[a-zA-Z0-9]+]]:2 = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[C128]] step %[[C10]] +// CHECK-SAME: iter_args(%[[ARG1:[a-zA-Z0-9]+]] = %[[INIT0]], %[[ARG2:[a-zA-Z0-9]+]] = %[[INIT1]]) +// CHECK: %[[TS_Y:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[C128]]] +// CHECK: %[[INNER:[a-zA-Z0-9]+]]:2 = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[C300]] step %[[C20]] +// CHECK-SAME: iter_args(%[[ARG3:[a-zA-Z0-9]+]] = %[[ARG1]], %[[ARG4:[a-zA-Z0-9]+]] = %[[ARG2]]) +// CHECK: %[[TS_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[C300]]] +// CHECK-DAG: %[[ARG_TILE:.+]] = tensor.extract_slice %[[ARG0]] +// CHECK-SAME: [%[[IV0]], 0, %[[IV1]]] [%[[TS_Y]], 200, %[[TS_X]]] [1, 1, 1] +// CHECK-DAG: %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ARG3]] +// CHECK-SAME: [%[[IV0]], %[[IV1]], 0] [%[[TS_Y]], %[[TS_X]], 200] [1, 1, 1] +// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ARG4]] +// CHECK-SAME: [%[[IV1]], %[[IV0]], 0] [%[[TS_X]], %[[TS_Y]], 200] [1, 1, 1] +// CHECK: %[[RESULT_TILE:.+]]:2 = linalg.generic +// CHECK-SAME: ins(%[[ARG_TILE]] : +// CHECK-SAME: outs(%[[INIT0_TILE]], %[[INIT1_TILE]] : +// CHECK: %[[UPDATE0:.+]] = tensor.insert_slice 
%[[RESULT_TILE]]#0 into %[[ARG3]] +// CHECK-SAME: [%[[IV0]], %[[IV1]], 0] [%[[TS_Y]], %[[TS_X]], 200] [1, 1, 1] +// CHECK: %[[UPDATE1:.+]] = tensor.insert_slice %[[RESULT_TILE]]#1 into %[[ARG4]] +// CHECK-SAME: [%[[IV1]], %[[IV0]], 0] [%[[TS_X]], %[[TS_Y]], 200] [1, 1, 1] +// CHECK: scf.yield %[[UPDATE0]], %[[UPDATE1]] +// CHECK: scf.yield %[[INNER]]#0, %[[INNER]]#1 +// CHECK: return %[[OUTER]]#0, %[[OUTER]]#1 + +// ----- + +func.func @conv2D(%arg0 : tensor<?x?x?x?xf32>, %arg1 : tensor<?x?x?x?xf32>, + %arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> { + %0 = linalg.conv_2d_nhwc_hwcf { + strides = dense<[2, 3]> : tensor<2xi64>, + dilation = dense<[4, 5]> : tensor<2xi64>, + __internal_linalg_transform__ = "simple_conv"} + ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) + outs(%arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> + return %0 : tensor<?x?x?x?xf32> +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (30, -d0 + s1)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (d0 + s0 * 2 - 2)> +// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0)[s0] -> (d0 + s0 * 3 - 3)> +// CHECK: func.func @conv2D( +// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32> +// CHECK-SAME: %[[FILTER:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32> +// CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32> +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index +// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index +// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index +// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index +// CHECK-DAG: %[[C30:.+]] = arith.constant 30 : index +// CHECK-DAG: %[[N:.+]] = tensor.dim %[[INPUT]], %[[C0]] +// CHECK-DAG: %[[C:.+]] = tensor.dim %[[INPUT]], %[[C3]] +// CHECK-DAG: %[[P:.+]] = tensor.dim %[[FILTER]], %[[C0]] +// CHECK-DAG: %[[Q:.+]] = tensor.dim %[[FILTER]], %[[C1]] +// CHECK-DAG: %[[F:.+]] = tensor.dim %[[FILTER]], %[[C3]] +// CHECK-DAG: 
%[[R:.+]] = tensor.dim %[[INIT]], %[[C1]] +// CHECK-DAG: %[[S:.+]] = tensor.dim %[[INIT]], %[[C2]] +// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[P]] step %[[C10]] +// CHECK-SAME: iter_args(%[[INIT0:.+]] = %[[INIT]]) +// CHECK: %[[TS_P:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[P]]] +// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[Q]] step %[[C20]] +// CHECK-SAME: iter_args(%[[INIT1:.+]] = %[[INIT0]]) +// CHECK: %[[TS_Q:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[Q]]] +// CHECK: scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[C]] step %[[C30]] +// CHECK-SAME: iter_args(%[[INIT2:.+]] = %[[INIT1]]) +// CHECK-DAG: %[[TS_C:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[C30]], %[[C]]] +// CHECK-DAG: %[[TS_H:.+]] = affine.apply #[[MAP3]](%[[TS_P]])[%[[R]]] +// CHECK-DAG: %[[TS_W:.+]] = affine.apply #[[MAP4]](%[[TS_Q]])[%[[S]]] +// CHECK-DAG: %[[INPUT_TILE:.+]] = tensor.extract_slice %[[INPUT]] +// CHECK-SAME: [0, %[[IV0]], %[[IV1]], %[[IV2]]] [%[[N]], %[[TS_H]], %[[TS_W]], %[[TS_C]]] +// CHECK-DAG: %[[FILTER_TILE:.+]] = tensor.extract_slice %[[FILTER]] +// CHECK-SAME: [%[[IV0]], %[[IV1]], %[[IV2]], 0] [%[[TS_P]], %[[TS_Q]], %[[TS_C]], %[[F]]] +// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT2]] +// CHECK-SAME: [0, 0, 0, 0] [%[[N]], %[[R]], %[[S]], %[[F]]] +// CHECK: %[[CONV_TILE:.+]] = linalg.conv_2d_nhwc_hwcf +// CHECK-SAME: dilation = dense<[4, 5]> : tensor<2xi64>, strides = dense<[2, 3]> : tensor<2xi64> +// CHECK-SAME: ins(%[[INPUT_TILE]], %[[FILTER_TILE]] : +// CHECK-SAME: outs(%[[INIT_TILE]] : +// CHECK: tensor.insert_slice %[[CONV_TILE]] into %[[INIT2]] +// CHECK-SAME: [0, 0, 0, 0] [%[[N]], %[[R]], %[[S]], %[[F]]] diff --git a/mlir/test/lib/CMakeLists.txt b/mlir/test/lib/CMakeLists.txt --- a/mlir/test/lib/CMakeLists.txt +++ b/mlir/test/lib/CMakeLists.txt @@ -1,6 +1,7 @@ add_subdirectory(Analysis) add_subdirectory(Conversion) add_subdirectory(Dialect) +add_subdirectory(Interfaces) add_subdirectory(IR) add_subdirectory(Pass) 
add_subdirectory(Reducer) diff --git a/mlir/test/lib/Interfaces/CMakeLists.txt b/mlir/test/lib/Interfaces/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Interfaces/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(TilingInterface) diff --git a/mlir/test/lib/Interfaces/TilingInterface/CMakeLists.txt b/mlir/test/lib/Interfaces/TilingInterface/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Interfaces/TilingInterface/CMakeLists.txt @@ -0,0 +1,15 @@ +add_mlir_library(MLIRTilingInterfaceTestPasses + TestTilingInterface.cpp + + EXCLUDE_FROM_LIBMLIR + + LINK_LIBS PUBLIC + MLIRAffine + MLIRArithmetic + MLIRLinalg + MLIRLinalgTransforms + MLIRMemRef + MLIRSCF + MLIRSCFTransforms + MLIRTensor + ) diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp @@ -0,0 +1,126 @@ +//===- TestTilingInterface.cpp - Test tiling using `TilingInterface` -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass for testing tiling operations using +// `TilingInterface`. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/SCF/TileUsingInterface.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Interfaces/TilingInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; + +namespace { + +/// Construct a generic pattern applied to all TilingInterface ops that verify +/// `filter`. +struct TestTileUsingSCFForOpWithFilter : public scf::TileUsingSCFForOp { + TestTileUsingSCFForOpWithFilter(MLIRContext *context, + scf::SCFTilingOptions options, + linalg::LinalgTransformationFilter filter = + linalg::LinalgTransformationFilter(), + PatternBenefit benefit = 1) + : scf::TileUsingSCFForOp(context, options, benefit), filter(filter) {} + + /// Construct a generic pattern applied to `opName`. 
+ TestTileUsingSCFForOpWithFilter(StringRef opName, MLIRContext *context, + scf::SCFTilingOptions options, + linalg::LinalgTransformationFilter filter = + linalg::LinalgTransformationFilter(), + PatternBenefit benefit = 1) + : scf::TileUsingSCFForOp(context, options, benefit), filter(filter) {} + + LogicalResult matchAndRewrite(TilingInterface op, + PatternRewriter &rewriter) const override { + if (failed(filter.checkAndNotify(rewriter, op))) + return failure(); + + FailureOr<scf::SCFTilingResult> tilingResult = + returningMatchAndRewrite(op, rewriter); + if (failed(tilingResult)) { + return failure(); + } + filter.replaceLinalgTransformationFilter(rewriter, tilingResult->tiledOp); + return success(); + } + +private: + linalg::LinalgTransformationFilter filter; +}; + +struct TestTilingInterfacePass + : public PassWrapper<TestTilingInterfacePass, OperationPass<func::FuncOp>> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTilingInterfacePass) + + TestTilingInterfacePass() = default; + TestTilingInterfacePass(const TestTilingInterfacePass &pass) + : PassWrapper(pass) {} + void getDependentDialects(DialectRegistry &registry) const override { + registry.insert<AffineDialect, memref::MemRefDialect, scf::SCFDialect, + tensor::TensorDialect>(); + linalg::registerTilingInterfaceExternalModels(registry); + } + StringRef getArgument() const final { return "test-tiling-interface"; } + StringRef getDescription() const final { + return "Test tiling using TilingInterface"; + } + + void runOnOperation() override; +}; +} // namespace + +static void addTestPatterns(MLIRContext *context, RewritePatternSet &patterns) { + auto addPatternForTiling = [&](ArrayRef<int64_t> tileSizes, + StringRef filterName) { + scf::SCFTilingOptions tilingOptions; + tilingOptions.setTileSizes(tileSizes); + linalg::LinalgTransformationFilter filter( + StringAttr::get(context, filterName), + StringAttr::get(context, "tiled")); + patterns.add<TestTileUsingSCFForOpWithFilter>(context, tilingOptions, + filter); + }; + // 1. Tiling M and N dims of `linalg.matmul` on tensors. + addPatternForTiling({10, 20}, "simple_gemm"); + // 2. Tiling M, N and K of `linalg.matmul` on buffers. 
+ addPatternForTiling({10, 20, 30}, "simple_gemm_memref"); + // 3. Tiling 3D parallel generic op which implements a transpose + addPatternForTiling({10, 0, 20}, "parallel_generic_transpose"); + // 4. Tiling 2D conv op. + addPatternForTiling({0, 0, 0, 0, 10, 20, 30}, "simple_conv"); +} + +void TestTilingInterfacePass::runOnOperation() { + MLIRContext *context = &getContext(); + + RewritePatternSet tilingPatterns(context); + addTestPatterns(context, tilingPatterns); + if (failed(applyPatternsAndFoldGreedily(getOperation(), + std::move(tilingPatterns)))) + return signalPassFailure(); +} + +namespace mlir { +namespace test { +void registerTestTilingInterface() { + PassRegistration<TestTilingInterfacePass>(); +} +} // namespace test +} // namespace mlir diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt --- a/mlir/tools/mlir-opt/CMakeLists.txt +++ b/mlir/tools/mlir-opt/CMakeLists.txt @@ -33,6 +33,7 @@ MLIRTestRewrite MLIRTestTransformDialect MLIRTestTransforms + MLIRTilingInterfaceTestPasses MLIRVectorTestPasses ) endif() diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -111,6 +111,7 @@ void registerTestSCFUtilsPass(); void registerTestSliceAnalysisPass(); void registerTestTensorTransforms(); +void registerTestTilingInterface(); void registerTestTransformDialectInterpreterPass(); void registerTestVectorLowerings(); } // namespace test @@ -206,6 +207,7 @@ mlir::test::registerTestSCFUtilsPass(); mlir::test::registerTestSliceAnalysisPass(); mlir::test::registerTestTensorTransforms(); + mlir::test::registerTestTilingInterface(); mlir::test::registerTestTransformDialectInterpreterPass(); mlir::test::registerTestVectorLowerings(); } diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel 
@@ -1864,6 +1864,7 @@ "include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h", "include/mlir/Dialect/SCF/Passes.h", "include/mlir/Dialect/SCF/Patterns.h", + "include/mlir/Dialect/SCF/TileUsingInterface.h", "include/mlir/Dialect/SCF/Transforms.h", ], includes = ["include"], @@ -1883,6 +1884,7 @@ ":SCFUtils", ":Support", ":TensorDialect", + ":TilingInterface", ":Transforms", "//llvm:Support", ], @@ -2646,6 +2648,7 @@ exclude = [ "include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h", "include/mlir/Dialect/SCF/Patterns.h", + "include/mlir/Dialect/SCF/TileUsingInterface.h", "include/mlir/Dialect/SCF/Transforms.h", ], ), @@ -6314,6 +6317,7 @@ "//mlir/test:TestSPIRV", "//mlir/test:TestShapeDialect", "//mlir/test:TestTensor", + "//mlir/test:TestTilingInterface", "//mlir/test:TestTosaDialect", "//mlir/test:TestTransformDialect", "//mlir/test:TestTransforms", @@ -7415,6 +7419,7 @@ "include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h", "include/mlir/Dialect/Linalg/Transforms/HoistPadding.h", "include/mlir/Dialect/Linalg/Transforms/Hoisting.h", + "include/mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h", "include/mlir/Dialect/Linalg/Transforms/Transforms.h", "include/mlir/Dialect/Linalg/Utils/Utils.h", ], @@ -7450,6 +7455,7 @@ ":TensorTilingInterfaceImpl", ":TensorTransforms", ":TensorUtils", + ":TilingInterface", ":TransformUtils", ":Transforms", ":VectorOps", diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -293,6 +293,28 @@ ], ) +cc_library( + name = "TestTilingInterface", + srcs = glob(["lib/Interfaces/TilingInterface/*.cpp"]), + includes = ["lib/Dialect/Test"], + deps = [ + "//llvm:Support", + "//mlir:Affine", + "//mlir:ArithmeticDialect", + "//mlir:FuncDialect", + "//mlir:IR", + "//mlir:LinalgOps", + "//mlir:LinalgTransforms", + "//mlir:MemRefDialect", + 
"//mlir:Pass", + "//mlir:SCFDialect", + "//mlir:SCFTransforms", + "//mlir:TensorDialect", + "//mlir:TilingInterface", + "//mlir:Transforms", + ], +) + cc_library( name = "TestPass", srcs = glob(["lib/Pass/*.cpp"]),