diff --git a/mlir/include/mlir/Dialect/SCF/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/TileUsingInterface.h --- a/mlir/include/mlir/Dialect/SCF/TileUsingInterface.h +++ b/mlir/include/mlir/Dialect/SCF/TileUsingInterface.h @@ -10,9 +10,12 @@ #define MLIR_DIALECT_SCF_TILEUSINGINTERFACE_H #include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/Tensor/Transforms/Transforms.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/TilingInterface.h" +#include + namespace mlir { class Operation; class PatternRewriter; @@ -55,7 +58,7 @@ SmallVector loops; }; -/// Pattern to tile an op that implementas the `TilingInterface` using +/// Pattern to tile an op that implements the `TilingInterface` using /// `scf.for` for iterating over the tiles. struct TileUsingSCFForOp : public OpInterfaceRewritePattern { /// Construct a generic pattern applied to all TilingInterface ops. @@ -81,6 +84,56 @@ SCFTilingOptions options; }; +/// Pattern to tile and fuse a sequence of operations, by tiling the consumer +/// and fusing its producers. Note that this assumes that it is valid to +/// tile+fuse the producer into the innermost tiled loop. It's up to the caller +/// to ensure that the tile sizes provided make this fusion valid. +/// +/// For example, for the following sequence +/// +/// ```mlir +/// %0 = linalg.fill ... +/// %1 = linalg.matmul ... outs(%0 : ...) ... +/// ``` +/// +/// it is legal to fuse the fill with the matmul only if the matmul is tiled +/// along the parallel dimensions and not the reduction dimension, i.e. the tile +/// size for the reduction dimension should be 0. +struct SCFTileAndFuseResult { + SmallVector tiledAndFusedOps; + SmallVector loops; +}; +struct TileConsumerAndFuseProducersUsingSCFForOp + : public OpInterfaceRewritePattern { + + /// Construct a generic pattern applied to all TilingInterface ops. 
+ TileConsumerAndFuseProducersUsingSCFForOp(MLIRContext *context, + SCFTilingOptions options, + PatternBenefit benefit = 1); + + /// Construct a generic pattern applied to `opName`. + TileConsumerAndFuseProducersUsingSCFForOp(StringRef opName, + MLIRContext *context, + SCFTilingOptions options, + PatternBenefit benefit = 1); + + /// `matchAndRewrite` implementation that returns the significant transformed + /// pieces of IR. + FailureOr + returningMatchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const; + + LogicalResult matchAndRewrite(TilingInterface op, + PatternRewriter &rewriter) const override { + return returningMatchAndRewrite(op, rewriter); + } + +private: + /// This pattern uses the tiling pattern. Instead of using inheritance, use + /// the pattern as a private object that is instantiated at the same time as + /// this pattern. + TileUsingSCFForOp tilingPattern; +}; + } // namespace scf } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMS_H #define MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMS_H +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/PatternMatch.h" namespace mlir { @@ -20,6 +21,14 @@ void populateSplitPaddingPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit = 1); +/// Pattern to swap a `tensor.extract_slice` with its producer when the +/// producer implements the `TilingInterface`. The pattern itself does not +/// provide a mechanism to control where the application happens. With use of +/// transform dialect that control is done within the transform dialect. Other +/// use cases can inherit from this pattern and add necessary controls. 
+FailureOr replaceExtractSliceWithTiledProducer( + OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp); + } // namespace tensor } // namespace mlir diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td --- a/mlir/include/mlir/Interfaces/TilingInterface.td +++ b/mlir/include/mlir/Interfaces/TilingInterface.td @@ -120,7 +120,48 @@ /*defaultImplementation=*/[{ return failure(); }] + >, + InterfaceMethod< + /*desc=*/[{ + Method to generate the code that produces a tile of the result. + + Generates the IR that computes the tile of a result of the + operation. The `offsets` and `sizes` describe the tile of + the output required. This is different from + `getTiledImplementation` which generates the tiled + implementation of the operation given a tile of the + iteration space. This method generates a tiled + implementation of the operation based on the tile of the + result required. This method enables fusion by using tile + and fuse. The method returns failure if the operation cannot be + tiled to generate the result tile. In practical terms this + implies it cannot be tiled and fused with its consumers. + + - `dest` are the Values into which the result of the tiled + operation is to be inserted. The type of the `dest` + Values is same as the types returned by + `getDestinationOperands` method. + - `offsets` provides the offset of the tile within the + result + - `sizes` provides the size of the tile. + - `tileDestOperands` specifies whether to also tile `dest` operands + or not. Avoiding tiling `dest` operands can be useful for + composition with various looping container ops. 
+ }], + /*retType=*/"FailureOr", + /*methodName=*/"generateResultTileValue", + /*args=*/(ins + "OpBuilder &":$b, + "unsigned":$resultNumber, + "ValueRange":$dest, + "ArrayRef":$offsets, + "ArrayRef":$sizes, + "bool":$tileDestOperands), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return failure(); + }] > - ]; + ]; } #endif // MLIR_TILINGINTERFACE diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -30,7 +30,6 @@ struct LinalgOpTilingInterface : public TilingInterface::ExternalModel, LinalgOpTy> { - /// Return the destination operands. SmallVector getDestinationOperands(Operation *op, OpBuilder &b) const { return llvm::cast(op).getOutputOperands(); @@ -47,6 +46,8 @@ /// Return the iteration domain range. SmallVector getIterationDomain(Operation *op, OpBuilder &b) const { + OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(op); Location loc = op->getLoc(); LinalgOp linalgOp = cast(op); auto allShapesSizes = linalgOp.createFlatListOfOperandDims(b, loc); @@ -129,16 +130,65 @@ resultSizes = sliceOp.getMixedSizes(); return success(); } + + FailureOr generateResultTileValue(Operation *op, OpBuilder &b, + unsigned resultNumber, + ValueRange dest, + ArrayRef offsets, + ArrayRef sizes, + bool tileDestOperands) const { + auto linalgOp = cast(op); + + // Check that the indexing map used for the output is a projected + // permutation. This could be relaxed with a more general approach that can + // map the offsets and sizes from the result to iteration space tiles + // (filling in full extent for dimensions not used to access the result). 
+ AffineMap indexingMap = + linalgOp.getTiedIndexingMapForResult(op->getResult(resultNumber)); + if (!indexingMap.isProjectedPermutation()) { + return op->emitOpError( + "uhandled tiled implementation generation when result is not " + "accessed using a permuted projection"); + } + + auto numLoops = linalgOp.getNumLoops(); + auto tilingInterfaceOp = cast(op); + SmallVector iterationTileOffsets(numLoops), + iterationTileSizes(numLoops); + if (!indexingMap.isPermutation()) { + SmallVector iterationDomain = + tilingInterfaceOp.getIterationDomain(b); + for (auto range : llvm::enumerate(iterationDomain)) { + iterationTileOffsets[range.index()] = range.value().offset; + iterationTileSizes[range.index()] = range.value().size; + } + } + for (auto resultExpr : llvm::enumerate(indexingMap.getResults())) { + unsigned dimPosition = + resultExpr.value().cast().getPosition(); + iterationTileOffsets[dimPosition] = offsets[resultExpr.index()]; + iterationTileSizes[dimPosition] = sizes[resultExpr.index()]; + } + + SmallVector tiledOp = tilingInterfaceOp.getTiledImplementation( + b, dest, iterationTileOffsets, iterationTileSizes, tileDestOperands); + if (tiledOp.size() != 1) + return op->emitOpError("failed to generate tiled implementation"); + + return tiledOp[0]->getResult(resultNumber); + } }; } // namespace -template static void registerOne(MLIRContext *ctx) { +template +static void registerOne(MLIRContext *ctx) { OpType::template attachInterface>(*ctx); } /// Variadic helper function. -template static void registerAll(MLIRContext *ctx) { +template +static void registerAll(MLIRContext *ctx) { // FIXME: In c++17 this can be simplified by using 'fold expressions'. 
(void)std::initializer_list{0, (registerOne(ctx), 0)...}; } diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -42,6 +42,10 @@ return *this; } +//===----------------------------------------------------------------------===// +// TileUsingSCFForOp pattern implementation. +//===----------------------------------------------------------------------===// + /// Generate an empty loop nest that represents the tiled loop nest shell. /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. /// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops. @@ -247,3 +251,159 @@ rewriter.replaceOp(op, tilingResult.loops.front().getResults()); return tilingResult; } + +//===----------------------------------------------------------------------===// +// TileConsumerAndFuseProducersUsingSCFForOp pattern implementation. +//===----------------------------------------------------------------------===// + +scf::TileConsumerAndFuseProducersUsingSCFForOp:: + TileConsumerAndFuseProducersUsingSCFForOp(MLIRContext *context, + scf::SCFTilingOptions options, + PatternBenefit benefit) + : OpInterfaceRewritePattern(context, benefit), + tilingPattern(context, std::move(options)) {} + +scf::TileConsumerAndFuseProducersUsingSCFForOp:: + TileConsumerAndFuseProducersUsingSCFForOp(StringRef opName, + MLIRContext *context, + scf::SCFTilingOptions options, + PatternBenefit benefit) + : OpInterfaceRewritePattern(context, benefit), + tilingPattern(context, std::move(options)) {} + +/// Return the `Value` that is defined by an operation that implements +/// the `TilingInterface`. Looks through `iter_args` of scf.for nest +/// if required. 
+static Optional getFusableProducer(Value v) { + while (auto blockArg = v.dyn_cast()) { + auto loopOp = dyn_cast(blockArg.getOwner()->getParentOp()); + if (!loopOp) + return llvm::None; + v = loopOp.getOpOperandForRegionIterArg(blockArg).get(); + } + auto result = v.dyn_cast(); + if (!result) + return llvm::None; + auto fusableOp = dyn_cast(result.getOwner()); + if (!fusableOp) + return llvm::None; + return result; +} + +FailureOr +scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite( + TilingInterface op, PatternRewriter &rewriter) const { + // This transformation is only valid for ops that return values (i.e. not + // valid to use with operations that have memref operands). + if (!op->getNumResults()) { + return rewriter.notifyMatchFailure( + op, "invalid pattern for op with no results"); + } + + // 1. First tile the consumer. + SCFTileAndFuseResult tileAndFuseResult; + { + FailureOr tilingResult = + tilingPattern.returningMatchAndRewrite(op, rewriter); + if (failed(tilingResult)) { + return failure(); + } + tileAndFuseResult.tiledAndFusedOps.push_back(tilingResult->tiledOp); + tileAndFuseResult.loops = std::move(tilingResult->loops); + } + + // 2. Typically, the operands of the tiled operation are slices of the + // operands of the untiled operation. These are expressed in IR using + // `tensor.extract_slice` operations with source being the operands of the + // untiled operation. Create a worklist of these `tensor.extract_slice` + // operations. If the producers of the source of the `tensor.extract_slice` + // can be tiled such that the tiled value is generated in-place, that + // effectively tiles + fuses the operations. 
+ auto addCandidateSlices = [](Operation *fusedOp, + std::deque &candidates) { + for (Value operand : fusedOp->getOperands()) + if (auto sliceOp = operand.getDefiningOp()) + candidates.push_back(sliceOp); + }; + + std::deque candidates; + addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates); + OpBuilder::InsertionGuard g(rewriter); + while (!candidates.empty()) { + // 2a. Traverse the slices in BFS fashion. + tensor::ExtractSliceOp candidateSliceOp = candidates.front(); + candidates.pop_front(); + + // 2b. Get the producer of the source (potentially walking through + // `iter_args` of nested `scf.for`) + Optional fusableProducer = + getFusableProducer(candidateSliceOp.source()); + if (!fusableProducer) + continue; + + // 2c. Generate the tiled implementation of the producer of the source + rewriter.setInsertionPoint(candidateSliceOp); + FailureOr fusedProducerValue = + tensor::replaceExtractSliceWithTiledProducer( + rewriter, candidateSliceOp, fusableProducer.getValue()); + if (failed(fusedProducerValue)) + continue; + rewriter.replaceOp(candidateSliceOp, fusedProducerValue.getValue()); + + // 2d. The operands of the fused producer might themselves be slices of + // values produced by operations that implement the `TilingInterface`. + // Add these operations to the worklist. + Operation *fusedProducer = fusedProducerValue->getDefiningOp(); + tileAndFuseResult.tiledAndFusedOps.push_back(fusedProducer); + addCandidateSlices(fusedProducer, candidates); + + // 2e. If the operation being fused creates a value that is used as `outs` + // in the tiled operation, the result of the unfused operation will be + // used in the `iter_args` of the tiled loop generated. When the + // operation is fused, this use in `iter_args` needs to be modified to + // use the destination of the fused operation. For example, starting + // with + // + // ```mlir + // %0 = linalg.init_tensor ... + // %1 = linalg.fill ... outs(%0:...)... + // %2 = linalg.matmul ... 
outs(%1:...).... + // ``` + // + // First the `linalg.matmul` gets tiled + // + // ```mlir + // %0 = linalg.init_tensor + // %1 = linalg.fill + // %2 = scf.for .... iter_args(%arg0 = %1)... + // ... + // ... = linalg.matmul ... + // + // ``` + // + // When the `linalg.fill` gets fused, the `iter_args` needs to be + // modified + // + // ```mlir + // %0 = linalg.init_tensor + // %1 = scf.for ... iter_args(%arg0 = %0)... + // ... + // %2 = linalg.fill ... + // %3 = linalg.matmul ... outs(%2: ...)... + // ``` + TilingInterface unfusedProducerOp = + cast(fusableProducer->getOwner()); + scf::ForOp outerMostTiledLoop = tileAndFuseResult.loops.front(); + SmallVector unfusedProducerOpDestValues = + unfusedProducerOp.getDestinationOperands(rewriter); + for (OpOperand &uses : unfusedProducerOp->getUses()) { + if (uses.getOwner() == outerMostTiledLoop.getOperation()) { + unsigned resultNumber = uses.get().cast().getResultNumber(); + unsigned operandNumber = uses.getOperandNumber(); + outerMostTiledLoop->setOperand( + operandNumber, unfusedProducerOpDestValues[resultNumber]); + } + } + } + return tileAndFuseResult; +} diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt @@ -2,6 +2,7 @@ BufferizableOpInterfaceImpl.cpp Bufferize.cpp SplitPadding.cpp + SwapExtractSliceWithProducer.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/Transforms @@ -18,5 +19,6 @@ MLIRPass MLIRSCFDialect MLIRTensorDialect + MLIRTilingInterface MLIRTransforms ) diff --git a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducer.cpp b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducer.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducer.cpp @@ -0,0 +1,43 @@ +//===- SwapExtractSliceWithProducer.cpp - Swapping `tensor.extract_slice` 
---=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Swap a `tensor.extract_slice` with the producer of the source if the producer +// implements the `TilingInterface`. When used in conjunction with tiling this +// effectively tiles + fuses the producer with its consumer. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/Interfaces/TilingInterface.h" + +using namespace mlir; + +FailureOr tensor::replaceExtractSliceWithTiledProducer( + OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producer) { + auto producerOp = dyn_cast(producer.getOwner()); + if (!producerOp) + return failure(); + + // `TilingInterface` currently only supports strides being 1. 
+ if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) { + return !isConstantIntValue(ofr, 1); + })) + return failure(); + + FailureOr tiledResult = producerOp.generateResultTileValue( + builder, producer.getResultNumber(), + producerOp.getDestinationOperands(builder), sliceOp.getMixedOffsets(), + sliceOp.getMixedSizes(), true); + if (failed(tiledResult)) + return failure(); + + return tiledResult.getValue(); +} diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir @@ -0,0 +1,185 @@ +// RUN: mlir-opt -test-tiling-interface=tile-consumer-and-fuse-producer-using-scf-for -split-input-file %s | FileCheck %s + +func.func @gemm_fill_fusion(%arg0 : tensor, %arg1 : tensor) -> tensor { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %cst = arith.constant 0.0 : f32 + %d0 = tensor.dim %arg0, %c0 : tensor + %d1 = tensor.dim %arg1, %c1 : tensor + %init = linalg.init_tensor [%d0, %d1] : tensor + %fill = linalg.fill ins(%cst : f32) outs(%init : tensor) -> tensor + %gemm = linalg.matmul {__internal_linalg_transform__ = "fusion"} + ins(%arg0, %arg1 : tensor, tensor) + outs(%fill : tensor) -> tensor + return %gemm : tensor +} +// CHECK: func.func @gemm_fill_fusion( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor) +// CHECK: %[[INIT:.+]] = linalg.init_tensor +// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] = +// CHECK-SAME: iter_args(%[[ITERARG0:.+]] = %[[INIT]]) +// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] = +// CHECK-SAME: iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]]) +// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0] +// CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]] +// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]] 
+// CHECK: %[[FILL_TILE:.+]] = linalg.fill +// CHECK-SAME: outs(%[[INIT_TILE]] : +// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] : +// CHECK-SAME: outs(%[[FILL_TILE]] : +// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GEMM_TILE]] into %[[ITERARG1]][%[[IV0]], %[[IV1]]] +// CHECK scf.yield %[[INSERT]] + +// ----- + +func.func @gemm_generic_fusion(%arg0 : tensor, %arg1 : tensor, + %arg2 : tensor) -> tensor { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %cst = arith.constant 0.0 : f32 + %d0 = tensor.dim %arg0, %c0 : tensor + %d1 = tensor.dim %arg1, %c1 : tensor + %init = linalg.init_tensor [%d0, %d1] : tensor + %fill = linalg.fill ins(%cst : f32) outs(%init : tensor) -> tensor + %gemm = linalg.matmul + ins(%arg0, %arg1 : tensor, tensor) + outs(%fill : tensor) -> tensor + %generic = linalg.generic { + __internal_linalg_transform__ = "fusion", + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%gemm, %arg2 : tensor, tensor) outs(%init : tensor) { + ^bb0(%b0 : f32, %b1 : f32, %b2 : f32): + %add = arith.addf %b0, %b1 : f32 + linalg.yield %add : f32 + } -> tensor + return %generic : tensor +} +// CHECK: func.func @gemm_generic_fusion( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor, +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor) +// CHECK: %[[INIT:.+]] = linalg.init_tensor +// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] = +// CHECK-SAME: iter_args(%[[ITERARG0:.+]] = %[[INIT]]) +// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] = +// CHECK-SAME: iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]]) +// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0] +// CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]] +// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]] +// CHECK: %[[FILL_TILE:.+]] = 
linalg.fill +// CHECK-SAME: outs(%[[INIT_TILE]] : +// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] : +// CHECK-SAME: outs(%[[FILL_TILE]] : +// CHECK-DAG: %[[BIAS_TILE:.+]] = tensor.extract_slice %[[ARG2]][%[[IV1]]] +// CHECK-DAG: %[[OUTS_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV0]], %[[IV1]]] +// CHECK: %[[GENERIC_TILE:.+]] = linalg.generic +// CHECK-SAME: ins(%[[GEMM_TILE]], %[[BIAS_TILE]] : +// CHECK-SAME: outs(%[[OUTS_TILE]] : +// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV0]], %[[IV1]]] +// CHECK scf.yield %[[INSERT]] + +// ----- + +func.func @gemm_gemm_fusion(%lhs0 : tensor, %rhs0 : tensor, %rhs1 : tensor) -> tensor { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %cst = arith.constant 0.0 : f32 + %d0 = tensor.dim %lhs0, %c0 : tensor + %d1 = tensor.dim %rhs0, %c1 : tensor + %init0 = linalg.init_tensor [%d0, %d1] : tensor + %fill0 = linalg.fill ins(%cst : f32) outs(%init0 : tensor) -> tensor + %gemm0 = linalg.matmul + ins(%lhs0, %rhs0 : tensor, tensor) outs(%fill0 : tensor) -> tensor + %d2 = tensor.dim %rhs1, %c1 : tensor + %init1 = linalg.init_tensor [%d0, %d2] : tensor + %fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor) -> tensor + %gemm1 = linalg.matmul {__internal_linalg_transform__ = "gemm_fusion"} + ins(%gemm0, %rhs1 : tensor, tensor) outs(%fill1 : tensor) -> tensor + return %gemm1 : tensor +} +// CHECK: func.func @gemm_gemm_fusion( +// CHECK-SAME: %[[LHS0:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[RHS0:[a-zA-Z0-9]+]]: tensor, +// CHECK-SAME: %[[RHS1:[a-zA-Z0-9]+]]: tensor) +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[LHS0]], %[[C0]] +// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[RHS0]], %[[C1]] +// CHECK-DAG: %[[INIT0:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]] +// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[RHS1]], %[[C1]] +// CHECK: 
%[[INIT1:.+]] = linalg.init_tensor [%[[D0]], %[[D2]]] +// CHECK: scf.for %[[IV:[a-zA-Z0-9]+]] = +// CHECK-SAME: iter_args(%[[ITERARG:.+]] = %[[INIT1]]) +// CHECK-DAG: %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0] +// CHECK-DAG: %[[RHS0_TILE:.+]] = tensor.extract_slice %[[RHS0]][0, 0] +// CHECK-DAG: %[[INIT0_TILE:.+]] = tensor.extract_slice %[[INIT0]][%[[IV]], 0] +// CHECK: %[[FILL0_TILE:.+]] = linalg.fill +// CHECK-SAME: outs(%[[INIT0_TILE]] : +// CHECK: %[[GEMM0_TILE:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[LHS0_TILE]], %[[RHS0_TILE]] : +// CHECK-SAME: outs(%[[FILL0_TILE]] : +// CHECK-DAG: %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0] +// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[INIT1]][%[[IV]], 0] +// CHECK: %[[FILL1_TILE:.+]] = linalg.fill +// CHECK-SAME: outs(%[[INIT1_TILE]] : +// CHECK: %[[GEMM1_TILE:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[GEMM0_TILE]], %[[RHS1_TILE]] : +// CHECK-SAME: outs(%[[FILL1_TILE]] : +// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GEMM1_TILE]] into %[[ITERARG]][%[[IV]], 0] +// CHECK scf.yield %[[INSERT]] + +// ----- + +func.func @gemm_transpose_fusion(%arg0 : tensor, %arg1 : tensor) -> tensor { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %cst = arith.constant 0.0 : f32 + %d0 = tensor.dim %arg0, %c0 : tensor + %d1 = tensor.dim %arg1, %c1 : tensor + %init0 = linalg.init_tensor [%d0, %d1] : tensor + %fill = linalg.fill ins(%cst : f32) outs(%init0 : tensor) -> tensor + %gemm = linalg.matmul + ins(%arg0, %arg1 : tensor, tensor) + outs(%fill : tensor) -> tensor + %init1 = linalg.init_tensor [%d1, %d0] : tensor + %transpose = linalg.generic { + __internal_linalg_transform__ = "fusion", + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], + iterator_types = ["parallel", "parallel"]} + ins(%gemm : tensor) outs(%init1 : tensor) { + ^bb0(%b0 : f32, %b1 : f32): + linalg.yield %b0 : f32 + } -> tensor + return %transpose : tensor +} +// 
CHECK: func.func @gemm_transpose_fusion( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor) +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG1]], %[[C1]] +// CHECK-DAG: %[[INIT0:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]] +// CHECK-DAG: %[[INIT1:.+]] = linalg.init_tensor [%[[D1]], %[[D0]]] +// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] = +// CHECK-SAME: iter_args(%[[ITERARG0:.+]] = %[[INIT1]]) +// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] = +// CHECK-SAME: iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]]) +// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0] +// CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]] +// CHECK-DAG: %[[INIT0_TILE:.+]] = tensor.extract_slice %[[INIT0]][%[[IV0]], %[[IV1]]] +// CHECK: %[[FILL_TILE:.+]] = linalg.fill +// CHECK-SAME: outs(%[[INIT0_TILE]] : +// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] : +// CHECK-SAME: outs(%[[FILL_TILE]] : +// CHECK-DAG: %[[OUTS_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV1]], %[[IV0]]] +// CHECK: %[[GENERIC_TILE:.+]] = linalg.generic +// CHECK-SAME: ins(%[[GEMM_TILE]] : +// CHECK-SAME: outs(%[[OUTS_TILE]] : +// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV1]], %[[IV0]]] +// CHECK scf.yield %[[INSERT]] diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir --- a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -test-tiling-interface -split-input-file %s | FileCheck %s +// RUN: mlir-opt -test-tiling-interface=tile-using-scf-for -split-input-file %s | FileCheck %s func.func 
@simple_matmul(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp @@ -29,8 +29,9 @@ namespace { -/// Construct a generic pattern applied to all TilingInterface ops that verify -/// `filter`. +/// Pattern for testing `TileUsingSCFForOp` pattern (that tiles operations using +/// the `TilingInterface` with `scf.for` ops for iterating over the tiles) while +/// using a `filter` to avoid recursive application. struct TestTileUsingSCFForOpWithFilter : public scf::TileUsingSCFForOp { TestTileUsingSCFForOpWithFilter(MLIRContext *context, scf::SCFTilingOptions options, @@ -52,8 +53,7 @@ if (failed(filter.checkAndNotify(rewriter, op))) return failure(); - FailureOr tilingResult = - returningMatchAndRewrite(op, rewriter); + auto tilingResult = returningMatchAndRewrite(op, rewriter); if (failed(tilingResult)) { return failure(); } @@ -65,6 +65,50 @@ linalg::LinalgTransformationFilter filter; }; +/// Pattern for testing `TileConsumerAndFuseProducersUsingSCFForOp` pattern +/// (that tiles and fuses operations using the `TilingInterface` with `scf.for` +/// ops for iterating over the tiles) while using a `filter` to avoid recursive +/// application. +struct TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter + : public scf::TileConsumerAndFuseProducersUsingSCFForOp { + TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter( + MLIRContext *context, scf::SCFTilingOptions options, + linalg::LinalgTransformationFilter filter = + linalg::LinalgTransformationFilter(), + PatternBenefit benefit = 1) + : scf::TileConsumerAndFuseProducersUsingSCFForOp(context, options, + benefit), + filter(filter) {} + + /// Construct a generic pattern applied to `opName`. 
+ TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter( + StringRef opName, MLIRContext *context, scf::SCFTilingOptions options, + linalg::LinalgTransformationFilter filter = + linalg::LinalgTransformationFilter(), + PatternBenefit benefit = 1) + : scf::TileConsumerAndFuseProducersUsingSCFForOp(context, options, + benefit), + filter(filter) {} + + LogicalResult matchAndRewrite(TilingInterface op, + PatternRewriter &rewriter) const override { + if (failed(filter.checkAndNotify(rewriter, op))) + return failure(); + + auto tileAndFuseResult = returningMatchAndRewrite(op, rewriter); + if (failed(tileAndFuseResult)) { + return failure(); + } + filter.replaceLinalgTransformationFilter( + rewriter, tileAndFuseResult->tiledAndFusedOps.front()); + return success(); + } + +private: + linalg::LinalgTransformationFilter filter; +}; + +/// Test pass for testing the use of `TilingInterface`. struct TestTilingInterfacePass : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTilingInterfacePass) @@ -82,29 +126,63 @@ return "Test tiling using TilingInterface"; } + Option testTiling{ + *this, "tile-using-scf-for", + llvm::cl::desc( + "Test tiling using TilingInterface with scf.for operations"), + llvm::cl::init(false)}; + + Option testTileConsumerAndFuseProducer{ + *this, "tile-consumer-and-fuse-producer-using-scf-for", + llvm::cl::desc("Test tile and fuse transformation using TilingInterface " + "with scf.for operations"), + llvm::cl::init(false)}; + void runOnOperation() override; + +private: + void addTestPatterns(MLIRContext *context, RewritePatternSet &patterns); }; } // namespace -static void addTestPatterns(MLIRContext *context, RewritePatternSet &patterns) { - auto addPatternForTiling = [&](ArrayRef tileSizes, - StringRef filterName) { - scf::SCFTilingOptions tilingOptions; - tilingOptions.setTileSizes(tileSizes); - linalg::LinalgTransformationFilter filter( - StringAttr::get(context, filterName), - StringAttr::get(context, "tiled")); - 
patterns.add(context, tilingOptions, - filter); - }; - // 1. Tiling M and N dims of `linalg.matmul` on tensors. - addPatternForTiling({10, 20}, "simple_gemm"); - // 2. Tiling M, N and K of `linalg.matmul` on buffers. - addPatternForTiling({10, 20, 30}, "simple_gemm_memref"); - // 3. Tiling 3D parallel generic op which implements a transpose - addPatternForTiling({10, 0, 20}, "parallel_generic_transpose"); - // 4. Tiling 2D conv op. - addPatternForTiling({0, 0, 0, 0, 10, 20, 30}, "simple_conv"); +template +static void +addPatternForTiling(MLIRContext *context, ArrayRef tileSizes, + StringRef filterName, RewritePatternSet &patterns) { + scf::SCFTilingOptions tilingOptions; + tilingOptions.setTileSizes(tileSizes); + linalg::LinalgTransformationFilter filter( + StringAttr::get(context, filterName), StringAttr::get(context, "tiled")); + patterns.add(context, tilingOptions, filter); +} + +void TestTilingInterfacePass::addTestPatterns(MLIRContext *context, + RewritePatternSet &patterns) { + if (testTiling) { + // 1. Tiling M and N dims of `linalg.matmul` on tensors. + addPatternForTiling( + context, {10, 20}, "simple_gemm", patterns); + // 2. Tiling M, N and K of `linalg.matmul` on buffers. + addPatternForTiling( + context, {10, 20, 30}, "simple_gemm_memref", patterns); + // 3. Tiling 3D parallel generic op which implements a transpose + addPatternForTiling( + context, {10, 0, 20}, "parallel_generic_transpose", patterns); + // 4. Tiling 2D conv op. + addPatternForTiling( + context, {0, 0, 0, 0, 10, 20, 30}, "simple_conv", patterns); + return; + } + if (testTileConsumerAndFuseProducer) { + // 1. Tile and fuse of gemm with bias-add operation. 
+ addPatternForTiling< + TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>( + context, {10, 20}, "fusion", patterns); + addPatternForTiling< + TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>( + context, {10}, "gemm_fusion", patterns); + return; + } } void TestTilingInterfacePass::runOnOperation() { diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -1884,6 +1884,7 @@ ":SCFUtils", ":Support", ":TensorDialect", + ":TensorTransforms", ":TilingInterface", ":Transforms", "//llvm:Support", @@ -4997,6 +4998,7 @@ ":SCFDialect", ":TensorDialect", ":TensorPassIncGen", + ":TilingInterface", ":Transforms", "//llvm:Support", ],