diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h @@ -0,0 +1,21 @@ +//===- TilingInterfaceImpl.h - Implementation of TilingInterface ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LINALG_TILINGINTERFACEIMPL_H +#define MLIR_DIALECT_LINALG_TILINGINTERFACEIMPL_H + +namespace mlir { +class DialectRegistry; + +namespace linalg { +/// Register the `TilingInterface` external models for Linalg structured ops. +void registerTilingInterfaceExternalModels(DialectRegistry &registry); +} // namespace linalg +} // namespace mlir + +#endif // MLIR_DIALECT_LINALG_TILINGINTERFACEIMPL_H diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -164,11 +164,11 @@ SmallVector computeTileOffsets(OpBuilder &b, Location loc, ValueRange ivs, ValueRange tileSizes); -/// Compute tile sizes, given a list of loop `ivs`, `tileSizes` and dimension +/// Compute tile sizes, given a list of `tileSizes` and dimension /// sizes (`sizeBounds`). In case a tile size is zero (i.e., no tiling), the /// corresponding result size is the corresponding value from `sizeBounds`. /// Note: The returned tile sizes are closed intervals.
-SmallVector computeTileSizes(OpBuilder &b, Location loc, ValueRange ivs, +SmallVector computeTileSizes(OpBuilder &b, Location loc, ValueRange tileSizes, ArrayRef sizeBounds); diff --git a/mlir/include/mlir/Dialect/SCF/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/TileUsingInterface.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/SCF/TileUsingInterface.h @@ -0,0 +1,61 @@ +//===- TileUsingInterface.h - Tiling ops using TilingInterface --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SCF_TILEUSINGINTERFACE_H +#define MLIR_DIALECT_SCF_TILEUSINGINTERFACE_H + +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/Interfaces/TilingInterface.h" + +namespace mlir { +class Operation; +class TilingInterface; +} // namespace mlir + +namespace mlir { +/// Result of tiling an operation: the tiled op and the generated loop nest. +struct SCFTilingResult { + Operation *tiledOp; + SmallVector loops; +}; + +/// Pattern to tile an op that implements the `TilingInterface` using +/// `scf.for` for iterating over the tiles. +struct TileUsingSCFForOp : public OpInterfaceRewritePattern { + /// Construct a generic pattern applied to all TilingInterface ops that verify + /// `filter`. + TileUsingSCFForOp(MLIRContext *context, TilingOptions options, + TransformationFilter filter = TransformationFilter(), + PatternBenefit benefit = 1); + + /// Construct a generic pattern applied to `opName`. + TileUsingSCFForOp(StringRef opName, MLIRContext *context, + TilingOptions options, + TransformationFilter filter = TransformationFilter(), + PatternBenefit benefit = 1); + + /// `matchAndRewrite` implementation that returns the significant transformed + /// pieces of IR.
+ FailureOr + returningMatchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const; + + LogicalResult matchAndRewrite(TilingInterface op, + PatternRewriter &rewriter) const override { + return returningMatchAndRewrite(op, rewriter); + } + +private: + /// Filter that selects the ops to tile and updates their marker attribute. + TransformationFilter filter; + /// Options to control tiling. + TilingOptions options; +}; + +} // namespace mlir + +#endif // MLIR_DIALECT_SCF_TILEUSINGINTERFACE_H diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h --- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_SCF_UTILS_UTILS_H_ #define MLIR_DIALECT_SCF_UTILS_UTILS_H_ +#include "mlir/Dialect/SCF/SCF.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" @@ -31,12 +32,6 @@ class FuncOp; } // namespace func -namespace scf { -class IfOp; -class ForOp; -class ParallelOp; -} // namespace scf - /// Replace the `loop` with `newIterOperands` added as new initialization /// values. `newYieldValuesFn` is a callback that can be used to specify /// the additional values to be yielded by the loop. The number of @@ -56,6 +51,26 @@ ValueRange newIterOperands, const NewYieldValueFn &newYieldValuesFn); +/// Update a perfectly nested loop nest to yield new values from the innermost +/// loop and propagate them up through the loop nest. This function +/// - Expects `loopNest` to be a perfectly nested loop with outermost loop +/// first and innermost loop last. +/// - `newIterOperands` are the initialization values to be used for the +/// outermost loop. +/// - `newYieldValueFn` is the callback that generates the new values to be +/// yielded from within the innermost loop.
+/// - The original loops are not erased, but are left in a "no-op" state where +/// the body of the loop just yields the basic block arguments that correspond +/// to the initialization values of a loop. The original loops are dead after +/// this method. +/// - All uses of the `newIterOperands` within the generated new loop +/// are replaced with the corresponding `BlockArgument` in the loop body. +/// - Returns the new loops, outermost first (empty if `loopNest` is empty). +SmallVector +replaceLoopNestWithNewYields(OpBuilder &builder, ArrayRef loopNest, + ValueRange newIterOperands, + NewYieldValueFn newYieldValueFn); + /// Outline a region with a single block into a new FuncOp. /// Assumes the FuncOp result types is the type of the yielded operands of the /// single block. This constraint makes it easy to determine the result. diff --git a/mlir/include/mlir/Interfaces/TilingInterface.h b/mlir/include/mlir/Interfaces/TilingInterface.h --- a/mlir/include/mlir/Interfaces/TilingInterface.h +++ b/mlir/include/mlir/Interfaces/TilingInterface.h @@ -23,4 +23,96 @@ /// Include the ODS generated interface header files. #include "mlir/Interfaces/TilingInterface.h.inc" +namespace mlir { +class PatternRewriter; + +using TileSizeComputationFunction = + std::function(OpBuilder &, Operation *)>; + +/// Options to use to control tiling. +struct TilingOptions { + /// Computation function that returns the tile sizes for each operation. + /// Delayed construction of constant tile sizes should occur to interoperate + /// with folding. + TileSizeComputationFunction tileSizeComputationFunction = nullptr; + + TilingOptions & + setTileSizeComputationFunction(TileSizeComputationFunction fun) { + tileSizeComputationFunction = std::move(fun); + return *this; + } + /// Set the `tileSizeComputationFunction` to return the values `ts`. The + /// values must not fold away when tiling. Otherwise, use a more robust + /// `tileSizeComputationFunction`.
+ TilingOptions &setTileSizes(const SmallVector &ts) { + tileSizeComputationFunction = [=](OpBuilder &, Operation *) { return ts; }; + return *this; + } + /// Convenience function to set the `tileSizeComputationFunction` to a + /// function that computes tile sizes at the point they are needed. Allows + /// proper interaction with folding. + TilingOptions &setTileSizes(ArrayRef ts); +}; + +/// Helper class to control application of transformation patterns. +/// Control comes in 2 forms: +/// 1. attribute matching and setting behavior using the attribute named +/// `kTransformMarker`. This can be used to build a state machine +/// using attributes and incrementally applying patterns to advance states. +/// 2. filter function, which is a simple lambda on the Operation* that +/// returns a LogicalResult. +/// These are added for maintaining uses of tiling patterns with pattern +/// rewrites using, say, `applyPatternsAndFoldGreedily`. When all consumers +/// move to use the PDL-based patterns, this will not be needed anymore.
+struct TransformationFilter { + using FilterFunction = std::function; + + explicit TransformationFilter(ArrayRef matchDisjunction = {}, + Optional replacement = None); + + explicit TransformationFilter(const FilterFunction &f, + ArrayRef matchDisjunction = {}, + Optional replacement = None); + + TransformationFilter(TransformationFilter &&) = default; + TransformationFilter(const TransformationFilter &) = default; + LogicalResult checkAndNotify(PatternRewriter &rewriter, Operation *op) const; + void replaceTransformationFilter(PatternRewriter &rewriter, + Operation *op) const; + bool hasReplacementFilter(Operation *op) const; + + TransformationFilter &addFilter(const FilterFunction &f) { + if (f) + filters.push_back(f); + return *this; + } + + template + TransformationFilter &addOpFilter() { + return addFilter( + [](Operation *op) { return success(isa(op)); }); + } + + /// Match only ops named `opName`; the name is copied into the filter. + TransformationFilter &addOpNameFilter(StringRef opName) { + return addFilter([name = opName.str()](Operation *op) { + return success(op->getName().getStringRef() == name); + }); + } + + TransformationFilter &setMatchByDefault() { + matchByDefault = true; + return *this; + } + +private: + SmallVector filters; + SmallVector matchDisjunction; + Optional replacement; + /// When set to true, if the attribute is not set, it will be treated as + /// a match. Default is false. + bool matchByDefault = false; +}; +} // namespace mlir + #endif // MLIR_INTERFACES_TILINGINTERFACE_H_ diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td --- a/mlir/include/mlir/Interfaces/TilingInterface.td +++ b/mlir/include/mlir/Interfaces/TilingInterface.td @@ -98,6 +98,28 @@ /*defaultImplementation=*/[{ return {}; }] + >, + InterfaceMethod< + /*desc=*/[{ + Method to return the position of the result tile computed by the tiled operation. + + Specifies what tile of the result of the original operation is computed + by the tiled implementation.
Expects the same `offsets` and `sizes` as + used to obtain the tiled implementation of the operation. + }], + /*retType=*/"LogicalResult", + /*methodName=*/"getResultTilePosition", + /*args=*/(ins + "OpBuilder &":$b, + "unsigned":$resultNumber, + "ArrayRef ":$offsets, + "ArrayRef ":$sizes, + "SmallVector &":$resultOffsets, + "SmallVector &":$resultSizes), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return failure(); + }] > ]; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -24,6 +24,7 @@ SparseTensorRewriting.cpp SplitReduction.cpp Tiling.cpp + TilingInterfaceImpl.cpp Transforms.cpp Vectorization.cpp diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -320,8 +320,7 @@ // Compute offsets and sizes of ExtractSliceOp. SmallVector offsets = computeTileOffsets(b, loc, localIvs, tileSizes); - SmallVector sizes = - computeTileSizes(b, loc, localIvs, tileSizes, allDims); + SmallVector sizes = computeTileSizes(b, loc, tileSizes, allDims); // Create ExtractSliceOp: Extract a tile from the tensor::PadOp. // Note: The tensor::PadOp is located outside of the loop nest. It is // later moved inside by ExtractSliceOfPadTensorSwapPattern. diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -0,0 +1,158 @@ +//===- TilingInterfaceImpl.cpp - Implementation of TilingInterface -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Arithmetic/Utils/Utils.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Interfaces/TilingInterface.h" + +using namespace mlir; +using namespace mlir::linalg; + +namespace { + +/// External model implementation of TilingInterface for LinalgOps. An external +/// model implementation is used for now till the use of `TilingInterface` is +/// on-par with the current Linalg tiling + fusion patterns. Once it is, it +/// may be possible to move this into the op-definition (though there are +/// advantages to leaving it as an external model). +template +struct LinalgOpTilingInterface + : public TilingInterface::ExternalModel, + LinalgOpTy> { + + /// Return the destination operands, i.e. the output operands of the op. + SmallVector getDestinationOperands(Operation *op, OpBuilder &b) const { + return llvm::cast(op).getOutputOperands(); + } + + /// Return the loop iterator types. + SmallVector getLoopIteratorTypes(Operation *op) const { + LinalgOpTy concreteOp = cast(op); + return llvm::to_vector( + llvm::map_range(concreteOp.iterator_types(), [](Attribute strAttr) { + return strAttr.cast().getValue(); + })); + } + + /// Return one `{0, size, 1}` range per loop of the iteration domain.
+ SmallVector getIterationDomain(Operation *op, OpBuilder &b) const { + Location loc = op->getLoc(); + LinalgOp linalgOp = cast(op); + auto allShapesSizes = linalgOp.createFlatListOfOperandDims(b, loc); + AffineMap map = linalgOp.getShapesToLoopsMap(); + Value zero = b.create(loc, 0); + Value one = b.create(loc, 1); + return llvm::to_vector(llvm::map_range( + applyMapToValues(b, loc, map, allShapesSizes), [&](Value v) { + return Range{zero, v, one}; + })); + } + + /// Instantiate the tiled implementation of the operation. + SmallVector + getTiledImplementation(Operation *op, OpBuilder &b, ValueRange dest, + ArrayRef offsets, + ArrayRef sizes, + bool tileDestOperands) const { + // Leave the `sizeBounds` value empty. That is only needed when the `sizes` + // specified could lead to out of bounds accesses. + Location loc = op->getLoc(); + LinalgOp linalgOp = cast(op); + SmallVector valuesToTile = linalgOp.getInputAndOutputOperands(); + SmallVector tiledOperands = makeTiledShapes( + b, loc, linalgOp, valuesToTile, + getValueOrCreateConstantIndexOp(b, loc, offsets), + getValueOrCreateConstantIndexOp(b, loc, sizes), {}, true); + + SmallVector resultTensorTypes = llvm::to_vector(llvm::map_range( + linalgOp.getOutputTensorOperands(), [&](OpOperand *opOperand) { + return tiledOperands[opOperand->getOperandNumber()].getType(); + })); + + Operation *tiledOp = + linalgOp.clone(b, loc, resultTensorTypes, tiledOperands); + + return {tiledOp}; + } + + /// Return the details of the output tile generated by the tiled + /// implementation.
+ LogicalResult + getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber, + ArrayRef offsets, + ArrayRef sizes, + SmallVector &resultOffsets, + SmallVector &resultSizes) const { + Location loc = op->getLoc(); + LinalgOp linalgOp = cast(op); + + AffineExpr d0; + bindDims(b.getContext(), d0); + + auto fullyComposeAffineMapAndOperands = [](OpBuilder &builder, Location loc, + AffineExpr expr, + ValueRange operands) -> Value { + AffineMap map = AffineMap::inferFromExprList({expr}).front(); + SmallVector normalizedOperands(operands.begin(), operands.end()); + mlir::fullyComposeAffineMapAndOperands(&map, &normalizedOperands); + canonicalizeMapAndOperands(&map, &normalizedOperands); + return builder.createOrFold(loc, map, normalizedOperands); + }; + + SmallVector sizeVals = + getValueOrCreateConstantIndexOp(b, loc, sizes); + SmallVector subShapeSizes = + llvm::to_vector(llvm::map_range(sizeVals, [&](Value v) { + return fullyComposeAffineMapAndOperands(b, loc, d0 - 1, v); + })); + OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber); + Value sliceOpResult = + makeTiledShape(b, loc, outOperand->get(), sizeVals, + linalgOp.getTiedIndexingMap(outOperand), + getValueOrCreateConstantIndexOp(b, loc, offsets), + /*ubs*/ {}, subShapeSizes, true); + auto sliceOp = sliceOpResult.getDefiningOp(); + if (!sliceOp) + return failure(); + resultOffsets = sliceOp.getMixedOffsets(); + resultSizes = sliceOp.getMixedSizes(); + return success(); + } +}; + +} // namespace + +template +static void registerOne(MLIRContext *ctx) { + OpType::template attachInterface>(*ctx); +} + +/// Variadic helper that calls `registerOne` for each op type in the pack. +template +static void registerAll(MLIRContext *ctx) { + // FIXME: In c++17 this can be simplified by using 'fold expressions'.
+ (void)std::initializer_list{0, (registerOne(ctx), 0)...}; +} + +#define GET_OP_LIST + +void mlir::linalg::registerTilingInterfaceExternalModels( + DialectRegistry &registry) { + registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) { + registerOne(ctx); + registerAll< +#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" + >(ctx); + }); +} diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -893,7 +893,7 @@ return offsets; } -SmallVector computeTileSizes(OpBuilder &b, Location loc, ValueRange ivs, +SmallVector computeTileSizes(OpBuilder &b, Location loc, ValueRange tileSizes, ArrayRef sizeBounds) { SmallVector sizes; @@ -923,7 +923,7 @@ // that define tile subshapes. SmallVector lbs = computeTileOffsets(b, loc, ivs, tileSizes); SmallVector subShapeSizes = - computeTileSizes(b, loc, ivs, tileSizes, sizeBounds); + computeTileSizes(b, loc, tileSizes, sizeBounds); assert(static_cast(valuesToTile.size()) == linalgOp.getNumInputsAndOutputs() && diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt @@ -10,6 +10,7 @@ ParallelLoopFusion.cpp ParallelLoopTiling.cpp StructuralTypeConversions.cpp + TileUsingInterface.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SCF diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -0,0 +1,240 @@ +//===- TileUsingInterface.cpp - Tiling using TilingInterface -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements tiling of operations using the TilingInterface. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SCF/TileUsingInterface.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Interfaces/TilingInterface.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "tile-using-interface" + +using namespace mlir; + +/// Generate an empty loop nest that represents the tiled loop nest shell. +/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. +/// - `tileSizeVals` are the tile sizes to use. Zero represents an untiled loop. +/// - In `offsets` and `sizes` return the multi-dimensional offset and size of +/// the tile processed within the innermost loop. +/// NOTE(review): loops step by the range `stride` rather than the tile size. +static SmallVector +generateTileLoopNest(OpBuilder &builder, Location loc, + ArrayRef loopRanges, ArrayRef tileSizeVals, + SmallVector &offsets, + SmallVector &sizes) { + assert(!loopRanges.empty() && "expected at least one loop range"); + assert(loopRanges.size() == tileSizeVals.size() && + "expected as many tile sizes as loop ranges"); + OpBuilder::InsertionGuard guard(builder); + SmallVector loops; + offsets.resize(loopRanges.size()); + sizes.resize(loopRanges.size()); + + // The tile size to use (to avoid out of bounds access) is minimum of + // `tileSize` and `ub - iv`, where `iv` is the induction variable + // of the tiled loop.
+ AffineExpr s0, s1, d0; + bindDims(builder.getContext(), d0); + bindSymbols(builder.getContext(), s0, s1); + AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, builder.getContext()); + + for (auto loopRange : llvm::enumerate(loopRanges)) { + // No loops if tile size is zero. Set offset and size to the loop + // offset and size. + if (matchPattern(tileSizeVals[loopRange.index()], m_Zero())) { + offsets[loopRange.index()] = loopRange.value().offset; + sizes[loopRange.index()] = loopRange.value().size; + continue; + } + + auto loop = builder.create( + loc, loopRange.value().offset, loopRange.value().size, + loopRange.value().stride, ValueRange{}, + [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv, + ValueRange /*iterArgs*/) { + Value boundedTileSize = bodyBuilder.create( + bodyLoc, minMap, + ValueRange{iv, tileSizeVals[loopRange.index()], + loopRange.value().size}); + sizes[loopRange.index()] = boundedTileSize; + bodyBuilder.create(bodyLoc); + }); + offsets[loopRange.index()] = loop.getInductionVar(); + loops.push_back(loop); + builder.setInsertionPoint(loop.getBody()->getTerminator()); + } + return loops; +} + +/// Construct the tiling pattern.
+TileUsingSCFForOp::TileUsingSCFForOp(MLIRContext *context, + TilingOptions options, + TransformationFilter f, + PatternBenefit benefit) + : OpInterfaceRewritePattern(context, benefit), + filter(std::move(f)), options(std::move(options)) {} + +TileUsingSCFForOp::TileUsingSCFForOp(StringRef opName, MLIRContext *context, + TilingOptions options, + TransformationFilter f, + PatternBenefit benefit) + : OpInterfaceRewritePattern(context, benefit), + filter(f.addOpNameFilter(opName)), options(std::move(options)) {} + +FailureOr +TileUsingSCFForOp::returningMatchAndRewrite(TilingInterface op, + PatternRewriter &rewriter) const { + if (failed(filter.checkAndNotify(rewriter, op))) + return failure(); + + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointAfter(op); + + if (!options.tileSizeComputationFunction) + return failure(); + + // 1. Get the range of the loops that are represented by the operation. + SmallVector iterationDomain = op.getIterationDomain(rewriter); + size_t numLoops = iterationDomain.size(); + if (numLoops == 0) + return failure(); + + // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero" + // skips tiling a particular dimension (simpler than adjusting affine maps for + // missing dimensions). Missing sizes become zero; extra sizes are dropped. + SmallVector tileSizeVector = + options.tileSizeComputationFunction(rewriter, op); + if (tileSizeVector.size() < numLoops) { + auto zero = rewriter.create(op.getLoc(), 0); + tileSizeVector.append(numLoops - tileSizeVector.size(), zero); + } + tileSizeVector.resize(numLoops); + + // 3. Materialize an empty loop nest that iterates over the tiles. These loops + // for now do not return any values even if the original operation has + // results.
+ SmallVector offsets, sizes; + SmallVector loops = generateTileLoopNest( + rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes); + + if (!loops.empty()) { + LLVM_DEBUG({ + llvm::errs() << "LoopNest shell :\n"; + loops.front().dump(); + llvm::errs() << "\n"; + }); + } + + // 4. Generate the tiled implementation within the innermost loop. + if (!loops.empty()) + rewriter.setInsertionPoint(loops.back().getBody()->getTerminator()); + SmallVector tiledImplementation = op.getTiledImplementation( + rewriter, op.getDestinationOperands(rewriter), offsets, sizes, true); + if (tiledImplementation.size() != 1) { + return rewriter.notifyMatchFailure( + op, "expected tiled implementation to return a single op"); + } + + if (!loops.empty()) { + LLVM_DEBUG({ + llvm::errs() << "After tiled implementation :\n"; + loops.front().dump(); + llvm::errs() << "\n"; + }); + } + + // 5. If the original operation has results, modify the loop nest to yield + // the replacement values. + if (op->getNumResults() != 0) { + SmallVector replacements; + if (loops.empty()) { + // 5a. If there were no loops, the tiled implementation results are the + // replacements. + replacements.assign(tiledImplementation[0]->result_begin(), + tiledImplementation[0]->result_end()); + } else { + // 5b. `scf.for` with tensor semantics requires the loop nest to yield the + // replacement values using destructive updates. Use the `TilingInterface` + // to get the position of the result tiles and use that to generate the + // destructive update pattern, i.e., + // + // ```mlir + // scf.for %iv0 = ... { + // %0 = tiled_op + // } + // ``` + // + // is transformed to + // + // ```mlir + // %result = scf.for %iv0 = ... iter_args(%arg = %init) -> .. { + // %0 = tiled_op + // %1 = tensor.insert_slice %0 into %arg[..] [..] [..]
+ // scf.yield %1 + // } + // ``` + NewYieldValueFn yieldValueFn = + [&](OpBuilder &b, Location loc, + ArrayRef newBBArgs) -> SmallVector { + SmallVector yieldedValues; + Attribute one = b.getIndexAttr(1); + for (auto resultNum : llvm::seq(0, op->getNumResults())) { + SmallVector resultTileOffsets, resultTileSizes; + if (failed(op.getResultTilePosition(b, resultNum, offsets, sizes, + resultTileOffsets, + resultTileSizes))) { + op.emitOpError("unable to get position of result ") + << resultNum << " of the tiled implementation"; + return {}; + } + SmallVector resultTileStrides(resultTileOffsets.size(), + one); + Value yieldedValue = b.create( + op->getLoc(), tiledImplementation[0]->getResult(resultNum), + newBBArgs[resultNum], resultTileOffsets, resultTileSizes, + resultTileStrides); + yieldedValues.push_back(yieldedValue); + } + return yieldedValues; + }; + SmallVector newLoops = replaceLoopNestWithNewYields( + rewriter, loops, op.getDestinationOperands(rewriter), yieldValueFn); + for (auto loop : llvm::enumerate(loops)) { + rewriter.eraseOp(loop.value()); + loops[loop.index()] = newLoops[loop.index()]; + } + scf::ForOp outerMost = loops.front(); + replacements.assign(outerMost->result_begin(), outerMost->result_end()); + + if (!loops.empty()) { + LLVM_DEBUG({ + llvm::errs() << "After yielding values :\n"; + loops.front().dump(); + llvm::errs() << "\n"; + }); + } + } + + rewriter.replaceOp(op, replacements); + } else { + rewriter.eraseOp(op); + } + + SCFTilingResult tilingResult; + tilingResult.tiledOp = tiledImplementation[0]; + tilingResult.loops = std::move(loops); + filter.replaceTransformationFilter(rewriter, tilingResult.tiledOp); + return tilingResult; +} diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -23,6 +23,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" +#include
"llvm/ADT/SmallVector.h" using namespace mlir; @@ -101,6 +102,33 @@ return newLoop; } +SmallVector mlir::replaceLoopNestWithNewYields( + OpBuilder &builder, ArrayRef loopNest, + ValueRange newIterOperands, NewYieldValueFn newYieldValueFn) { + if (loopNest.empty()) + return {}; + SmallVector newLoopNest(loopNest.size()); + + // Rewrite the innermost loop first so that it yields the new values. + newLoopNest.back() = replaceLoopWithNewYields( + builder, loopNest.back(), newIterOperands, newYieldValueFn); + + // Each enclosing loop then yields the new results of the loop it contains. + for (unsigned loopDepth : + llvm::reverse(llvm::seq(0, loopNest.size() - 1))) { + NewYieldValueFn fn = [&](OpBuilder &innerBuilder, Location loc, + ArrayRef innerNewBBArgs) { + SmallVector newYields( + newLoopNest[loopDepth + 1]->getResults().take_back( + newIterOperands.size())); + return newYields; + }; + newLoopNest[loopDepth] = replaceLoopWithNewYields( + builder, loopNest[loopDepth], newIterOperands, fn); + } + return newLoopNest; +} + /// Outline a region with a single block into a new FuncOp. /// Assumes the FuncOp result types is the type of the yielded operands of the /// single block. This constraint makes it easy to determine the result.
diff --git a/mlir/lib/Interfaces/TilingInterface.cpp b/mlir/lib/Interfaces/TilingInterface.cpp --- a/mlir/lib/Interfaces/TilingInterface.cpp +++ b/mlir/lib/Interfaces/TilingInterface.cpp @@ -12,6 +12,94 @@ #include "mlir/Interfaces/TilingInterface.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/PatternMatch.h" + using namespace mlir; #include "mlir/Interfaces/TilingInterface.cpp.inc" + +static const StringLiteral kTransformMarker = "__internal_transform__"; + +mlir::TransformationFilter::TransformationFilter( + ArrayRef matchDisjunction, Optional replacement) + : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()), + replacement(replacement), matchByDefault(false) {} + +mlir::TransformationFilter::TransformationFilter( + const FilterFunction &f, ArrayRef matchDisjunction, + Optional replacement) + : filters(), + matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()), + replacement(replacement), matchByDefault(false) { + if (f) + filters.push_back(f); +} + +LogicalResult +mlir::TransformationFilter::checkAndNotify(PatternRewriter &rewriter, + Operation *op) const { + if (llvm::any_of(filters, + [&](const FilterFunction &f) { return failed(f(op)); })) + return failure(); + + auto attr = op->template getAttrOfType(kTransformMarker); + + if (!attr) { + // 1. Has no filter and either none is expected or matchByDefault is set. + if (matchDisjunction.empty() || matchByDefault) + return success(); + + // 2. Has no filter but was expecting a filter. + return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { + diag << " does not have any filter from list: "; + interleaveComma(matchDisjunction, diag); + }); + } + + // 3. Match explicit filter. + for (auto filter : matchDisjunction) + if (attr.getValue() == filter) + return success(); + + // 4. Fail to match.
+ return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { + diag << " does not have any filter from list: "; + interleaveComma(matchDisjunction, diag); + }); +} + +void mlir::TransformationFilter::replaceTransformationFilter( + PatternRewriter &rewriter, Operation *op) const { + if (replacement.hasValue()) + op->setAttr(kTransformMarker, replacement.getValue()); + else + op->removeAttr(rewriter.getStringAttr(kTransformMarker)); +} + +bool mlir::TransformationFilter::hasReplacementFilter(Operation *op) const { + if (!replacement) + return false; + auto attr = op->getAttr(kTransformMarker).dyn_cast_or_null<StringAttr>(); + return attr && attr == replacement.getValue(); +} + +TilingOptions &mlir::TilingOptions::setTileSizes(ArrayRef ts) { + assert(!tileSizeComputationFunction && "tile sizes already set"); + SmallVector tileSizes(ts.begin(), ts.end()); + tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { + OpBuilder::InsertionGuard guard(b); + // Hoist the constants to the entry block of the enclosing function, or + // create them right before `op` when there is no enclosing function. + if (auto funcOp = op->getParentOfType<func::FuncOp>()) + b.setInsertionPointToStart(&funcOp.getBody().front()); + else + b.setInsertionPoint(op); + return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) { + Value v = b.create(op->getLoc(), s); + return v; + })); + }; + return *this; +} diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir @@ -0,0 +1,136 @@ +// RUN: mlir-opt -test-tiling-interface -split-input-file %s | FileCheck %s + +func.func @simple_matmul(%arg0 : tensor, %arg1 : tensor, + %arg2 : tensor) -> tensor { + %0 = linalg.matmul {__internal_transform__ = "simple_gemm"} + ins(%arg0, %arg1 : tensor, tensor) + outs(%arg2 : tensor) -> tensor + return %0 : tensor +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)> +// CHECK: func.func @simple_matmul( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index +// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index +// CHECK-DAG: %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[K:.+]] = tensor.dim %[[ARG0]], %[[C1]] +// CHECK-DAG: %[[N:.+]] = tensor.dim %[[ARG1]], %[[C1]] +// CHECK: %[[OUTER:.+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C1]] +// CHECK-SAME: iter_args(%[[INIT0:.+]] = %[[ARG2]]) +// CHECK: %[[TS_Y:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[M]]] +// CHECK: %[[INNER:.+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C1]] +// CHECK-SAME: iter_args(%[[INIT1:.+]] = %[[INIT0]]) +// CHECK: %[[TS_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[N]]] +// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]] +// CHECK-SAME: [%[[IV0]], 0] [%[[TS_Y]], %[[K]]] [1, 1] +// CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]] +// CHECK-SAME: [0, %[[IV1]]] [%[[K]], %[[TS_X]]] [1, 1] +// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT1]] +// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TS_Y]], %[[TS_X]]] [1, 1] +// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] : +// CHECK-SAME: outs(%[[INIT_TILE]] : +// CHECK: %[[UPDATE:.+]] = tensor.insert_slice %[[GEMM_TILE]] into %[[INIT1]] +// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TS_Y]], %[[TS_X]]] [1, 1] +// CHECK: scf.yield %[[UPDATE]] +// CHECK: scf.yield %[[INNER]] +// CHECK: return %[[OUTER]] + +// ----- + +func.func @simple_matmul_memref(%arg0 : memref, %arg1 : memref, + %arg2 : memref) { + linalg.matmul {__internal_transform__ = "simple_gemm_memref"} + ins(%arg0, %arg1 : memref, memref) + outs(%arg2 : memref) + return +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> +// CHECK-DAG:
#[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (30, -d0 + s1)> +// CHECK: func.func @simple_matmul_memref( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: memref +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index +// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index +// CHECK-DAG: %[[C30:.+]] = arith.constant 30 : index +// CHECK-DAG: %[[M:.+]] = memref.dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[K:.+]] = memref.dim %[[ARG0]], %[[C1]] +// CHECK-DAG: %[[N:.+]] = memref.dim %[[ARG1]], %[[C1]] +// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C1]] +// CHECK: %[[TS_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[M]]] +// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C1]] +// CHECK: %[[TS_N:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[N]]] +// CHECK: scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[K]] step %[[C1]] +// CHECK: %[[TS_K:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[C30]], %[[K]]] +// CHECK-DAG: %[[LHS_TILE:.+]] = memref.subview %[[ARG0]] +// CHECK-SAME: [%[[IV0]], %[[IV2]]] [%[[TS_M]], %[[TS_K]]] [1, 1] +// CHECK-DAG: %[[RHS_TILE:.+]] = memref.subview %[[ARG1]] +// CHECK-SAME: [%[[IV2]], %[[IV1]]] [%[[TS_K]], %[[TS_N]]] [1, 1] +// CHECK-DAG: %[[OUT_TILE:.+]] = memref.subview %[[ARG2]] +// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TS_M]], %[[TS_N]]] [1, 1] +// CHECK: linalg.matmul +// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] : +// CHECK-SAME: outs(%[[OUT_TILE]] : + +// ----- + +#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> +func.func @multi_result(%arg0 : tensor<128x200x300xf32>) -> (tensor<128x300x200xf32>, tensor<200x128x300xf32>) { + %init0 =
linalg.init_tensor [128, 300, 200] : tensor<128x300x200xf32> + %init1 = linalg.init_tensor [200, 128, 300] : tensor<200x128x300xf32> + %0:2 = linalg.generic { + indexing_maps = [#map0, #map1, #map2], + iterator_types = ["parallel", "parallel", "parallel"]} + {__internal_transform__ = "parallel_generic_transpose"} + ins(%arg0 : tensor<128x200x300xf32>) + outs(%init0, %init1 : tensor<128x300x200xf32>, tensor<200x128x300xf32>) { + ^bb0(%b0 : f32, %b1 : f32, %b2 : f32): + linalg.yield %b0, %b0 : f32, f32 + } -> (tensor<128x300x200xf32>, tensor<200x128x300xf32>) + return %0#0, %0#1 : tensor<128x300x200xf32>, tensor<200x128x300xf32> +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)> +// CHECK: func.func @multi_result( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<128x200x300xf32>) +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index +// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index +// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index +// CHECK-DAG: %[[C300:.+]] = arith.constant 300 : index +// CHECK-DAG: %[[INIT0:.+]] = linalg.init_tensor [128, 300, 200] +// CHECK-DAG: %[[INIT1:.+]] = linalg.init_tensor [200, 128, 300] +// CHECK: %[[OUTER:.+]]:2 = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[C128]] step %[[C1]] +// CHECK-SAME: iter_args(%[[ARG1:[a-zA-Z0-9]+]] = %[[INIT0]], %[[ARG2:[a-zA-Z0-9]+]] = %[[INIT1]]) +// CHECK: %[[TS_Y:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[C128]]] +// CHECK: %[[INNER:.+]]:2 = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[C300]] step %[[C1]] +// CHECK-SAME: iter_args(%[[ARG3:[a-zA-Z0-9]+]] = %[[ARG1]], %[[ARG4:[a-zA-Z0-9]+]] = %[[ARG2]]) +// CHECK: %[[TS_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[C300]]] +// CHECK-DAG: %[[ARG_TILE:.+]] = tensor.extract_slice %[[ARG0]] +// CHECK-SAME: [%[[IV0]], 0, %[[IV1]]]
[%[[TS_Y]], 200, %[[TS_X]]] [1, 1, 1] +// CHECK-DAG: %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ARG3]] +// CHECK-SAME: [%[[IV0]], %[[IV1]], 0] [%[[TS_Y]], %[[TS_X]], 200] [1, 1, 1] +// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ARG4]] +// CHECK-SAME: [0, %[[IV0]], %[[IV1]]] [200, %[[TS_Y]], %[[TS_X]]] [1, 1, 1] +// CHECK: %[[RESULT_TILE:.+]]:2 = linalg.generic +// CHECK-SAME: ins(%[[ARG_TILE]] : +// CHECK-SAME: outs(%[[INIT0_TILE]], %[[INIT1_TILE]] : +// CHECK: %[[UPDATE0:.+]] = tensor.insert_slice %[[RESULT_TILE]]#0 into %[[ARG3]] +// CHECK-SAME: [%[[IV0]], %[[IV1]], 0] [%[[TS_Y]], %[[TS_X]], 200] [1, 1, 1] +// CHECK: %[[UPDATE1:.+]] = tensor.insert_slice %[[RESULT_TILE]]#1 into %[[ARG4]] +// CHECK-SAME: [0, %[[IV0]], %[[IV1]]] [200, %[[TS_Y]], %[[TS_X]]] [1, 1, 1] +// CHECK: scf.yield %[[UPDATE0]], %[[UPDATE1]] +// CHECK: scf.yield %[[INNER]]#0, %[[INNER]]#1 +// CHECK: return %[[OUTER]]#0, %[[OUTER]]#1 diff --git a/mlir/test/lib/CMakeLists.txt b/mlir/test/lib/CMakeLists.txt --- a/mlir/test/lib/CMakeLists.txt +++ b/mlir/test/lib/CMakeLists.txt @@ -1,6 +1,7 @@ add_subdirectory(Analysis) add_subdirectory(Conversion) add_subdirectory(Dialect) +add_subdirectory(Interfaces) add_subdirectory(IR) add_subdirectory(Pass) add_subdirectory(Reducer) diff --git a/mlir/test/lib/Interfaces/CMakeLists.txt b/mlir/test/lib/Interfaces/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Interfaces/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(TilingInterface) diff --git a/mlir/test/lib/Interfaces/TilingInterface/CMakeLists.txt b/mlir/test/lib/Interfaces/TilingInterface/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Interfaces/TilingInterface/CMakeLists.txt @@ -0,0 +1,15 @@ +add_mlir_library(MLIRTilingInterfaceTestPasses + TestTilingInterface.cpp + + EXCLUDE_FROM_LIBMLIR + + LINK_LIBS PUBLIC + MLIRAffine + MLIRArithmetic + MLIRLinalg + MLIRLinalgTransforms + MLIRMemRef + MLIRSCF + MLIRSCFTransforms + MLIRTensor + ) diff --git
a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp @@ -0,0 +1,84 @@ +//===- TestTilingInterface.cpp - Test tiling using `TilingInterface` -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass for testing tiling operations using +// `TilingInterface`. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/SCF/TileUsingInterface.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Interfaces/TilingInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; + +namespace { +struct TestTilingInterfacePass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTilingInterfacePass) + + TestTilingInterfacePass() = default; + TestTilingInterfacePass(const TestTilingInterfacePass &pass) + : PassWrapper(pass) {} + void getDependentDialects(DialectRegistry &registry) const override { + registry.insert(); + linalg::registerTilingInterfaceExternalModels(registry); + } + StringRef getArgument() const final { return "test-tiling-interface"; } + StringRef getDescription() const final { + return "Test tiling using TilingInterface"; + } + + void runOnOperation() override;
+}; +} // namespace + +static void addTestPatterns(MLIRContext *context, RewritePatternSet &patterns) { + auto addPatternForTiling = [&](ArrayRef tileSizes, + StringRef filterName) { + TilingOptions tilingOptions; + tilingOptions.setTileSizes(tileSizes); + TransformationFilter filter(StringAttr::get(context, filterName), + StringAttr::get(context, "tiled")); + patterns.add(context, tilingOptions, filter); + }; + // 1. Tiling M and N dims of `linalg.matmul` on tensors. + addPatternForTiling({10, 20}, "simple_gemm"); + // 2. Tiling M, N and K of `linalg.matmul` on buffers. + addPatternForTiling({10, 20, 30}, "simple_gemm_memref"); + // 3. Tiling 3D parallel generic op which implements a transpose. + addPatternForTiling({10, 0, 20}, "parallel_generic_transpose"); +} + +void TestTilingInterfacePass::runOnOperation() { + MLIRContext *context = &getContext(); + + RewritePatternSet tilingPatterns(context); + addTestPatterns(context, tilingPatterns); + if (failed(applyPatternsAndFoldGreedily(getOperation(), + std::move(tilingPatterns)))) + return signalPassFailure(); +} + +namespace mlir { +namespace test { +void registerTestTilingInterface() { + PassRegistration(); +} +} // namespace test +} // namespace mlir diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt --- a/mlir/tools/mlir-opt/CMakeLists.txt +++ b/mlir/tools/mlir-opt/CMakeLists.txt @@ -33,6 +33,7 @@ MLIRTestRewrite MLIRTestTransformDialect MLIRTestTransforms + MLIRTilingInterfaceTestPasses MLIRVectorTestPasses ) endif() diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -111,6 +111,7 @@ void registerTestSCFUtilsPass(); void registerTestSliceAnalysisPass(); void registerTestTensorTransforms(); +void registerTestTilingInterface(); void registerTestTransformDialectInterpreterPass(); void registerTestVectorLowerings(); } // namespace test @@ -206,6 +207,7 @@
mlir::test::registerTestSCFUtilsPass(); mlir::test::registerTestSliceAnalysisPass(); mlir::test::registerTestTensorTransforms(); + mlir::test::registerTestTilingInterface(); mlir::test::registerTestTransformDialectInterpreterPass(); mlir::test::registerTestVectorLowerings(); }