diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -544,8 +544,14 @@
     tileSizes.assign(ts.begin(), ts.end());
     return *this;
   }
+  /// Tile interchange used to permute the tile loops.
   SmallVector<int64_t> tileInterchange;
+  LinalgTilingAndFusionOptions &setInterchange(ArrayRef<int64_t> interchange) {
+    tileInterchange.assign(interchange.begin(), interchange.end());
+    return *this;
+  }
+
   /// When specified, specifies distribution of generated tile loops to
   /// processors.
   Optional<LinalgLoopDistributionOptions> tileDistribution = None;
@@ -554,6 +560,61 @@
     tileDistribution = std::move(distributionOptions);
     return *this;
   }
+
+  /// When specified, return from the tiled loops the tensor values that can
+  /// be used as replacements for the operations that are fused into the
+  /// consumer tile loops. Note that fusion can lead to the same tile of the
+  /// producer being recomputed for producing different tiles of the consumer.
+  /// In such situations it is tricky to construct the return value for the
+  /// fused ops within the loop, so this option is to be used with care, only
+  /// when the caller knows that the fused op is not recomputed. For example,
+  ///
+  /// ```mlir
+  /// %C = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+  ///     outs(%init1 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  /// %E = linalg.matmul ins(%C, %D : tensor<?x?xf32>, tensor<?x?xf32>)
+  ///     outs(%init2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  /// ```
+  ///
+  /// depending on the tiling chosen, the consumer `linalg.matmul` can be
+  /// tiled either along both outer loop dimensions (the `i` and `j` loops),
+  /// or along just the first outer loop dimension (the `i` loop). Fusing the
+  /// producer results in either
+  ///
+  /// ```mlir
+  /// %l0 = scf.for %iv0 = ... iter_args(%arg0 = %init2) {
+  ///   %l1 = scf.for %iv1 = ... iter_args(%arg1 = %arg0) {
+  ///     %C = linalg.matmul ...
+  ///     %E = linalg.matmul ins(%C, ...) ...
+  ///     scf.yield %E
+  ///   }
+  /// }
+  /// ```
+  ///
+  /// or
+  ///
+  /// ```mlir
+  /// %l0 = scf.for %iv0 = ... iter_args(%arg0 = %init2) {
+  ///   %C = linalg.matmul ...
+  ///   %E = linalg.matmul ins(%C, ...) ...
+  ///   scf.yield %E
+  /// }
+  /// ```
+  ///
+  /// In the first case, the same producer tile is recomputed on every
+  /// iteration of the inner loop. In the second case, there is no
+  /// recomputation. Only in such cases can the value of the producer computed
+  /// within the tile loop be returned as well, so that it can replace the
+  /// uses of the untiled producer op. Setting the option to `true` assumes
+  /// that the producer tiles are not recomputed after fusion; the caller is
+  /// expected to ensure this assumption holds.
+  bool returnFusedOpValues = false;
+  LinalgTilingAndFusionOptions &setReturnFusedOpValues(bool v) {
+    returnFusedOpValues = v;
+    return *this;
+  }
 };
 
 struct LinalgTilingOptions {
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -343,10 +343,19 @@
   /// Fuse the producer of `consumerOpOperand` into the tile loop nest. Returns
   /// the fused producer or fails if fusion is not possible.
-  FailureOr<LinalgOp> fuseProducer(OpBuilder &b, OpOperand *consumerOpOperand);
+  FailureOr<LinalgOp> fuseProducer(OpBuilder &b, OpOperand *consumerOpOperand,
+                                   bool returnFusedOpValues);
 
-  /// Returns the replacement results for the original untiled root operation.
-  ValueRange getRootOpReplacementResults();
+  /// For an `untiledOp` fused into the tile loop nest, with `clonedOp` being
+  /// the fused (tiled) operation, modify the tile loop nest to yield a value
+  /// that can be used to replace `untiledOp`.
+  void yieldFusedValues(OpBuilder &b, LinalgOp untiledOp, LinalgOp clonedOp,
+                        ArrayRef<OpFoldResult> mixedOffsets,
+                        ArrayRef<OpFoldResult> mixedSizes,
+                        ArrayRef<OpFoldResult> mixedStrides);
+
+  /// Returns the replacement value for a result of the original untiled
+  /// operation.
+  Value getUntiledOpResultReplacement(Value v);
 
   /// Returns the tiled root operation.
   LinalgOp getRootOp() { return rootOp; }
@@ -354,6 +363,10 @@
   /// Returns the tiled root operation and the fused producers.
   SmallVector<LinalgOp> getAllTiledAndFusedOps();
 
+  /// Returns the ops that were tiled and fused into the loop (returns the
+  /// original ops).
+  ArrayRef<LinalgOp> getFusedOps() { return fusedOps; }
+
   /// Returns the loop ops generated from tiling.
   ArrayRef<scf::ForOp> getLoopOps() { return tileLoopOps; }
 
@@ -386,6 +399,9 @@
   LinalgOp rootOp;
   SmallVector<scf::ForOp> tileLoopOps;
   DenseMap<Operation *, SmallVector<int64_t>> tiledRootAndFusedOpsLoops;
+
+  SmallVector<LinalgOp> fusedOps;
+  DenseMap<Value, Value> replacements;
 };
 
 /// Tiles `consumerOp` and fuses its dependencies if possible. Uses the
@@ -394,7 +410,8 @@
 FailureOr<TileLoopNest> tileConsumerAndFuseProducers(
     OpBuilder &b, LinalgOp consumerOp, ArrayRef<int64_t> tileSizes,
     ArrayRef<int64_t> tileInterchange,
-    const Optional<LinalgLoopDistributionOptions> &tileDistribution);
+    const Optional<LinalgLoopDistributionOptions> &tileDistribution,
+    bool returnFusedOpValues = false);
 
 //===----------------------------------------------------------------------===//
 // Generic op region utilities
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -305,6 +305,9 @@
   // or pattern.
   if (!isEmpty())
     rootOp->replaceAllUsesWith(tiledRootOp->tensorResults);
+  fusedOps.push_back(rootOp);
+  for (auto result : llvm::enumerate(rootOp->getResults()))
+    replacements[result.value()] = tiledRootOp->tensorResults[result.index()];
 
   // Transfer the stored `rootOp` loop dimensions if it has been tiled before.
   if (tiledRootAndFusedOpsLoops.count(rootOp) != 0) {
@@ -325,8 +328,99 @@
   return success();
 }
 
+void TileLoopNest::yieldFusedValues(OpBuilder &b, LinalgOp untiledOp,
+                                    LinalgOp tiledOp,
+                                    ArrayRef<OpFoldResult> offsets,
+                                    ArrayRef<OpFoldResult> sizes,
+                                    ArrayRef<OpFoldResult> strides) {
+  if (tileLoopOps.empty())
+    return;
+
+  SmallVector<Value> newIterOperands;
+  SmallVector<Value> yieldedResults;
+  SmallVector<Value> returnedFusedOpResults;
+  for (auto result : llvm::enumerate(untiledOp->getResults())) {
+    // Check if any of the uses are outside the tile loop nest. Dim uses are
+    // not treated as "real" uses; furthermore, the dim ops could be used to
+    // compute the bounds of the loop itself. They are expected to be resolved
+    // using `ReifyRankedShapedTypeOpInterface`.
+    if (llvm::all_of(result.value().getUsers(), [&](Operation *user) {
+          return tileLoopOps.front()->isAncestor(user) ||
+                 isa<tensor::DimOp>(user) ||
+                 llvm::is_contained(fusedOps, user);
+        }))
+      continue;
+
+    unsigned resultNumber = result.value().getResultNumber();
+    // 1. Replace all existing uses of the untiled op in the iter_args of the
+    //    tile loops with the `outs` of the untiled operation.
+    OpOperand *untiledOpOutOperand = untiledOp.getOutputOperand(resultNumber);
+    tileLoopOps.front()->replaceUsesOfWith(result.value(),
+                                           untiledOpOutOperand->get());
+
+    // 2. The `outs` operands of the untiled op are the "new" iter operands to
+    //    be added.
+    newIterOperands.push_back(untiledOpOutOperand->get());
+    yieldedResults.push_back(tiledOp->getResult(resultNumber));
+    returnedFusedOpResults.push_back(result.value());
+  }
+
+  // The innermost loop body needs to be modified to add `tensor.insert_slice`
+  // ops that reconstruct the full tensor from the values produced by the
+  // tiled ops.
+  scf::ForOp outerMostLoop = tileLoopOps.front();
+  scf::ForOp innerMostLoop = tileLoopOps.back();
+  NewYieldValueFn innerFn = [&](OpBuilder &innerBuilder, Location innerLoc,
+                                ArrayRef<BlockArgument> innerNewBBArgs) {
+    SmallVector<Value> newYieldVals;
+    for (auto yieldedResult : llvm::enumerate(yieldedResults)) {
+      newYieldVals.push_back(innerBuilder.create<tensor::InsertSliceOp>(
+          tiledOp->getLoc(), yieldedResult.value(),
+          innerNewBBArgs[yieldedResult.index()], offsets, sizes, strides));
+    }
+    return newYieldVals;
+  };
+  scf::ForOp newInnerMostLoop =
+      replaceLoopWithNewYields(b, innerMostLoop, newIterOperands, innerFn);
+  tileLoopOps.back() = newInnerMostLoop;
+
+  // Propagate the newly yielded values back up the loop nest.
+  for (unsigned loopDepth :
+       llvm::reverse(llvm::seq<unsigned>(0, tileLoopOps.size() - 1))) {
+    NewYieldValueFn fn = [&](OpBuilder &innerBuilder, Location innerLoc,
+                             ArrayRef<BlockArgument> innerNewBBArgs) {
+      return llvm::to_vector(
+          llvm::map_range(tileLoopOps[loopDepth + 1].getResults().take_back(
+                              newIterOperands.size()),
+                          [](OpResult r) -> Value { return r; }));
+    };
+    tileLoopOps[loopDepth] = replaceLoopWithNewYields(b, tileLoopOps[loopDepth],
+                                                      newIterOperands, fn);
+  }
+
+  // The `replacements` map uses values returned by the tile loop nest. Update
+  // these to use values from the new loop nest created here. Build a map from
+  // the results of the previous outermost loop to those of the new outermost
+  // loop.
+  llvm::SmallDenseMap<Value, Value> resultValRemap;
+  scf::ForOp newOuterMostLoop = tileLoopOps.front();
+  for (auto origResult : llvm::enumerate(outerMostLoop->getResults())) {
+    resultValRemap[origResult.value()] =
+        newOuterMostLoop->getResult(origResult.index());
+  }
+  for (auto &replacement : replacements) {
+    auto newReplacement = resultValRemap.lookup(replacement.second);
+    if (newReplacement)
+      replacement.second = newReplacement;
+  }
+  for (auto fusedOpResult : llvm::enumerate(returnedFusedOpResults)) {
+    replacements[fusedOpResult.value()] = newOuterMostLoop.getResult(
+        outerMostLoop.getNumResults() + fusedOpResult.index());
+  }
+}
+
 FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b,
-                                               OpOperand *consumerOpOperand) {
+                                               OpOperand *consumerOpOperand,
+                                               bool returnFusedOpValues) {
   // Check if the consumer has been tiled before. For example, it may not have
   // been tiled if the outermost tile loop is a reduction loop.
   if (tiledRootAndFusedOpsLoops.count(consumerOpOperand->getOwner()) == 0)
@@ -361,13 +455,18 @@
   OpOperand *iterArg = nullptr;
   auto producerResult = sliceOp.source().dyn_cast<OpResult>();
   if (auto bbArg = sliceOp.source().dyn_cast<BlockArgument>()) {
+    if (bbArg.getParentBlock() != sliceOp->getBlock())
+      return failure();
     iterArg = getTiedIterArg(bbArg);
     // Check the iteration argument may be used to pass in the producer output.
     if (!iterArg || hasOtherUses(bbArg, sliceOp))
       return failure();
     producerResult = iterArg->get().dyn_cast<OpResult>();
   }
-  if (!producerResult || !isa<LinalgOp>(producerResult.getOwner()))
+  if (!producerResult)
+    return failure();
+  LinalgOp producer = dyn_cast<LinalgOp>(producerResult.getOwner());
+  if (!producer)
     return failure();
 
   // Compute the tiled producer slice dimensions given the tiled consumer loops.
@@ -385,25 +484,30 @@
       getTiledProducer(b, producerResult, sliceOp, tiledSliceDimIndices,
                        tiledProducerLoopIndices, iterArg);
   tiledRootAndFusedOpsLoops[clonedOp] = tiledProducerLoopIndices;
+  OpBuilder::InsertionGuard guard(b);
+  b.setInsertionPointAfter(clonedOp);
 
   // Cast the `clonedOp` result to gap type mismatches before canonicalization.
   Type consumerOperandType = consumerOpOperand->get().getType();
   Value newResult = clonedOp->getResult(producerResult.getResultNumber());
   if (newResult.getType() != consumerOperandType) {
-    OpBuilder::InsertionGuard guard(b);
-    b.setInsertionPointAfter(clonedOp);
     newResult = b.create<tensor::CastOp>(producerResult.getLoc(),
                                          consumerOperandType, newResult);
   }
 
   // Replace the `sliceOp` uses except for the `clonedOp` output uses.
   sliceOp.getResult().replaceAllUsesExcept(newResult, clonedOp);
+
+  if (returnFusedOpValues) {
+    fusedOps.push_back(producer);
+    yieldFusedValues(b, producer, clonedOp, sliceOp.getMixedOffsets(),
+                     sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
+  }
   return clonedOp;
 }
 
-ValueRange TileLoopNest::getRootOpReplacementResults() {
-  assert(!isEmpty() && "expect tile loop nest to be non-empty");
-  return tileLoopOps.front()->getOpResults();
+Value TileLoopNest::getUntiledOpResultReplacement(Value v) {
+  return replacements.lookup(v);
 }
 
 SmallVector<LinalgOp> TileLoopNest::getAllTiledAndFusedOps() {
@@ -424,7 +528,8 @@
 FailureOr<TileLoopNest> mlir::linalg::tileConsumerAndFuseProducers(
     OpBuilder &b, LinalgOp consumerOp, ArrayRef<int64_t> tileSizes,
     ArrayRef<int64_t> tileInterchange,
-    const Optional<LinalgLoopDistributionOptions> &tileDistribution) {
+    const Optional<LinalgLoopDistributionOptions> &tileDistribution,
+    bool returnFusedOpValues) {
   assert(tileSizes.size() == tileInterchange.size() &&
          "expect the number of tile sizes and interchange dims to match");
   assert(isPermutation(tileInterchange) &&
@@ -444,11 +549,12 @@
   int64_t split = std::distance(iterTypes.begin(), it);
 
   // Helper to fuse the producers greedily using a queue of fusion candidates.
-  auto fuseProducersGreedily = [&](ArrayRef<OpOperand *> operands) {
+  auto fuseProducersGreedily = [&](ArrayRef<OpOperand *> operands,
+                                   bool returnFusedOpValues) {
     SmallVector<OpOperand *> candidates(operands.begin(), operands.end());
     while (!candidates.empty()) {
-      FailureOr<LinalgOp> fusedProducer =
-          tileLoopNest.fuseProducer(b, candidates.pop_back_val());
+      FailureOr<LinalgOp> fusedProducer = tileLoopNest.fuseProducer(
+          b, candidates.pop_back_val(), returnFusedOpValues);
       if (failed(fusedProducer))
         continue;
       candidates.append(fusedProducer->getInputAndOutputOperands());
@@ -462,7 +568,12 @@
   if (failed(tileLoopNest.tileRootOp(b, outerTileSizes, tileInterchange,
                                      tileDistribution)))
     return failure();
-  fuseProducersGreedily(tileLoopNest.getRootOp().getOutputOperands());
+  if (returnFusedOpValues) {
+    fuseProducersGreedily(tileLoopNest.getRootOp().getInputAndOutputOperands(),
+                          true);
+  } else {
+    fuseProducersGreedily(tileLoopNest.getRootOp().getOutputOperands(), false);
+  }
 
   // Tile the remaining loops and fuse the input operands.
SmallVector innerTileSizes; @@ -471,7 +582,7 @@ if (failed(tileLoopNest.tileRootOp(b, innerTileSizes, tileInterchange, tileDistribution))) return failure(); - fuseProducersGreedily(tileLoopNest.getRootOp().getInputOperands()); + fuseProducersGreedily(tileLoopNest.getRootOp().getInputOperands(), false); // Exit if the tile loop nest is empty since all tile sizes are zero. if (tileLoopNest.isEmpty()) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -601,7 +601,7 @@ rootOp.getNumLoops()); SmallVector rootInterchange = options.tileInterchange.empty() - ? llvm::to_vector<6>(llvm::seq(0, rootOp.getNumLoops())) + ? llvm::to_vector(llvm::seq(0, rootOp.getNumLoops())) : SmallVector(options.tileInterchange.begin(), options.tileInterchange.begin() + rootOp.getNumLoops()); @@ -619,19 +619,45 @@ op, "expect the tile interchange permutes the root loops"); // Tile `rootOp` and fuse its producers. - FailureOr tileLoopNest = - tileConsumerAndFuseProducers(rewriter, rootOp, rootTileSizes, - rootInterchange, options.tileDistribution); + FailureOr tileLoopNest = tileConsumerAndFuseProducers( + rewriter, rootOp, rootTileSizes, rootInterchange, + options.tileDistribution, options.returnFusedOpValues); + if (failed(tileLoopNest)) return rewriter.notifyMatchFailure( op, "tileConsumerAndFuseProducers failed unexpectedly"); - // Replace all uses of the tiled loop operation. - rootOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults()); + // Replace uses of the tiled operation(s) which are + // - Not within the loop itself (that creates an SSA use-def violation) + // - Ignore `tensor.dim` uses. These `tensor.dim` operations might be + // used in computation of the loop bounds themselves. Replacing these uses + // causes use-def violations. These uses should be handled using + // `ReifyRankedShapedTypeOpInterface`. + scf::ForOp outerMostLoop = tileLoopNest->getLoopOps().front(); + auto useOutsideLoopFn = [&outerMostLoop](OpOperand &use) { + Operation *user = use.getOwner(); + return !isa(user) && + !outerMostLoop->isAncestor(use.getOwner()); + }; + for (Value rootResult : rootOp->getResults()) { + if (Value replacement = + tileLoopNest->getUntiledOpResultReplacement(rootResult)) + rootResult.replaceUsesWithIf(replacement, useOutsideLoopFn); + } + if (options.returnFusedOpValues) { + for (auto fusedOp : tileLoopNest->getFusedOps()) { + for (auto fusedOpResult : fusedOp->getResults()) { + if (Value replacement = + tileLoopNest->getUntiledOpResultReplacement(fusedOpResult)) + fusedOpResult.replaceUsesWithIf(replacement, useOutsideLoopFn); + } + } + } // Apply the filter if specified. 
- for (LinalgOp linalgOp : tileLoopNest->getAllTiledAndFusedOps()) + for (LinalgOp linalgOp : tileLoopNest->getAllTiledAndFusedOps()) { filter.replaceLinalgTransformationFilter(rewriter, linalgOp); + } return tileLoopNest; } diff --git a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir --- a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir +++ b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir @@ -1,193 +1,210 @@ -// RUN: mlir-opt %s -test-linalg-tensor-fusion-transform-patterns -resolve-shaped-type-result-dims -canonicalize -cse --split-input-file | FileCheck %s +// RUN: mlir-opt %s -test-linalg-tile-and-fuse-on-tensors -cse --split-input-file | FileCheck %s -module { - func.func @matmul_fusion(%A: tensor, %B: tensor, - %AB_init: tensor, %C: tensor, - %ABC_init: tensor) -> tensor { - %AB = linalg.matmul ins(%A, %B : tensor, tensor) - outs(%AB_init : tensor) -> tensor // - %ABC = linalg.matmul {__internal_linalg_transform__ = "lhs_fusion"} - ins(%AB, %C : tensor, tensor) - outs(%ABC_init : tensor) -> tensor // - return %ABC : tensor - } +func.func @matmul_bias_add(%lhs : tensor, %rhs : tensor, %bias : tensor) + -> (tensor, tensor) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %cst = arith.constant 0.0 : f32 + %d0 = tensor.dim %lhs, %c0 : tensor + %d1 = tensor.dim %rhs, %c1 : tensor + %init = linalg.init_tensor [%d0, %d1] : tensor + %fill = linalg.fill ins(%cst : f32) outs(%init : tensor) -> tensor + %matmul = linalg.matmul ins(%lhs, %rhs : tensor, tensor) + outs(%fill : tensor) -> tensor + %bias_add = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d1)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"], + __internal_linalg_transform__ = "return_fused_values"} + ins(%matmul, %bias : tensor, tensor) + outs(%init : tensor) { + ^bb0(%b0 : f32, %b1 : f32, %b3 : f32): + %add = arith.addf %b0, %b1 : f32 + linalg.yield %add : f32 + } -> tensor + return %matmul, %bias_add : tensor, tensor } -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 32)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 16)> -// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 64)> -// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s1, -d0 + s0, 32)> +// CHECK-LABEL: func @matmul_bias_add +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor +// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor +// CHECK: %[[OUTER:.+]]:2 = scf.for %[[IV0:.+]] = %{{.+}} to %{{.+}} step +// CHECK-SAME: iter_args(%[[OUTER_ITER0:[a-zA-Z0-9]+]] = %[[INIT]], +// CHECK-SAME: %[[OUTER_ITER1:[a-zA-Z0-9]+]] = %[[INIT]]) +// CHECK: %[[INNER:.+]]:2 = scf.for %[[IV1:.+]] = %{{.+}} to %[[UB1:.+]] step +// CHECK-SAME: iter_args(%[[INNER_ITER0:[a-zA-Z0-9]+]] = %[[OUTER_ITER0]], +// CHECK-SAME: %[[INNER_ITER1:[a-zA-Z0-9]+]] = %[[OUTER_ITER1]]) +// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0] +// CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]] +// CHECK-DAG: %[[OUTS_TILE:.+]] = tensor.extract_slice %[[INNER_ITER1]][%[[IV0]], %[[IV1]]] +// CHECK-DAG: %[[FILL_TILE:.+]] = linalg.fill +// CHECK-SAME: outs(%[[OUTS_TILE]] : +// CHECK-DAG: %[[MATMUL_TILE:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] : +// CHECK-SAME: outs(%[[FILL_TILE:.+]] : +// CHECK-DAG: %[[BIAS_TILE:.+]] = 
tensor.extract_slice %[[ARG2]][%[[IV1]]] +// CHECK-DAG: %[[OUTS2_TILE:.+]] = tensor.extract_slice %[[INNER_ITER0]][%[[IV0]], %[[IV1]]] +// CHECK: %[[ROOT_TILE:.+]] = linalg.generic +// CHECK-SAME: ins(%[[MATMUL_TILE]], %[[BIAS_TILE]] : +// CHECK-SAME: outs(%[[OUTS2_TILE]] : +// CHECK-DAG: %[[ROOT_INSERT:.+]] = tensor.insert_slice %[[ROOT_TILE]] into %[[INNER_ITER0]][%[[IV0]], %[[IV1]]] +// CHECK-DAG: %[[MATMUL_INSERT:.+]] = tensor.insert_slice %[[MATMUL_TILE]] into %[[INNER_ITER1]][%[[IV0]], %[[IV1]]] +// CHECK: scf.yield %[[ROOT_INSERT]], %[[MATMUL_INSERT]] +// CHECK: scf.yield %[[INNER]]#0, %[[INNER]]#1 +// CHECK: return %[[OUTER]]#1, %[[OUTER]]#0 -// CHECK: func @matmul_fusion -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: tensor +// ----- -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index -// CHECK-DAG: %[[C64:.+]] = arith.constant 64 : index -// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index -// CHECK-DAG: %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]] -// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] = -// CHECK-SAME: %[[C0]] to %[[M]] step %[[C32]] -// CHECK-SAME: iter_args(%[[ARG6:.+]] = %[[ARG4]]) -> (tensor) { -// CHECK: %[[TILE_M_2:.+]] = affine.min #[[MAP1]](%[[IV0]])[%[[M]]] -// CHECK: %[[N3:.+]] = tensor.dim %[[ARG6]], %[[C1]] -// CHECK: %[[ST_ARG6:.+]] = tensor.extract_slice %[[ARG6]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M_2]], %[[N3]]] -// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP5]](%[[IV0]])[%[[M]], %[[M]]] -// CHECK: %[[N1:.+]] = tensor.dim %[[ARG0]], %[[C1]] -// CHECK: %[[ST_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M_3]], %[[N1]]] -// CHECK: %[[N2_2:.+]] = tensor.dim %[[ARG2]], %[[C1]] -// CHECK: %[[ST_ARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M_3]], %[[N2_2]]] -// CHECK: %[[LHS:.+]] = linalg.matmul -// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion_producer" -// CHECK-SAME: ins(%[[ST_ARG0]], %[[ARG1]] : tensor, tensor) -// CHECK-SAME: outs(%[[ST_ARG2]] : tensor) -// CHECK: %[[N2:.+]] = tensor.dim %[[ARG1]], %[[C1]] -// CHECK: %[[N3_2:.+]] = tensor.dim %[[ARG3]], %[[C1]] -// CHECK: %[[YIELD0:.+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] = -// CHECK-SAME: %[[C0]] to %[[N3_2]] step %[[C64]] -// CHECK-SAME: iter_args(%[[ARG8:.+]] = %[[ST_ARG6]]) -> (tensor) { -// CHECK: %[[YIELD1:.+]] = scf.for %[[IV2:[a-zA-Z0-9]+]] = -// CHECK-SAME: %[[C0]] to %[[N2]] step %[[C16]] -// CHECK-SAME: iter_args(%[[ARG10:.+]] = %[[ARG8]]) -> (tensor) { -// CHECK: %[[TILE_N2:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[N2]]] -// CHECK: %[[ST_LHS:.+]] = tensor.extract_slice %[[LHS]][0, %[[IV2]]] -// CHECK-SAME: [%[[TILE_M_3]], %[[TILE_N2]]] -// CHECK: %[[TILE_N3:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[N3_2]]] -// CHECK: %[[ST_ARG3:.+]] = tensor.extract_slice %[[ARG3]][%[[IV2]], %[[IV1]]] -// CHECK-SAME: [%[[TILE_N2]], %[[TILE_N3]]] -// CHECK: %[[M_4:.+]] = tensor.dim %[[ARG10]], %[[C0]] -// CHECK: %[[ST_ARG4:.+]] = tensor.extract_slice %[[ARG10]][0, %[[IV1]]] -// CHECK-SAME: [%[[M_4]], %[[TILE_N3]]] -// CHECK: %[[ST_RESULT:.+]] = linalg.matmul -// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion" -// CHECK-SAME: ins(%[[ST_LHS]], %[[ST_ARG3]] -// CHECK-SAME: : tensor, tensor) -// CHECK-SAME: 
outs(%[[ST_ARG4]] : tensor) -// CHECK: %[[UPDATE1:.+]] = tensor.insert_slice %[[ST_RESULT]] -// CHECK-SAME: into %[[ARG10]][0, %[[IV1]]] [%[[M_4]], %[[TILE_N3]]] -// CHECK: scf.yield %[[UPDATE1]] -// CHECK: } -// CHECK: scf.yield %[[YIELD1]] -// CHECK: } -// CHECK: %[[UPDATE0:.+]] = tensor.insert_slice %[[YIELD0]] into -// CHECK-SAME: %[[ARG6]][%[[IV0]], 0] [%[[TILE_M_2]], %[[N3]]] -// CHECK: scf.yield %[[UPDATE0]] -// CHECK: } -// CHECK: return %[[RESULT]] +func.func @matmul_lhs_fusion(%A: tensor, %B: tensor, + %AB_init: tensor, %C: tensor, %ABC_init: tensor) + -> (tensor, tensor) { + %AB = linalg.matmul ins(%A, %B : tensor, tensor) + outs(%AB_init : tensor) -> tensor // + %ABC = linalg.matmul {__internal_linalg_transform__ = "lhs_fusion"} + ins(%AB, %C : tensor, tensor) + outs(%ABC_init : tensor) -> tensor // + return %AB, %ABC : tensor, tensor +} +// CHECK-LABEL: func @matmul_lhs_fusion +// CHECK-SAME: %[[A:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[B:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[AB_INIT:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[C:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ABC_INIT:[a-zA-Z0-9]+]]: tensor +// CHECK: %[[LOOP:.+]]:2 = scf.for %[[IV0:[a-zA-Z0-9]+]] +// CHECK-SAME: iter_args(%[[ITER0:[a-zA-Z0-9]+]] = %[[ABC_INIT]], +// CHECK-SAME: %[[ITER1:[a-zA-Z0-9]+]] = %[[AB_INIT]]) +// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[A]][%[[IV0]], 0] +// CHECK-DAG: %[[OUTS0_TILE:.+]] = tensor.extract_slice %[[ITER1]][%[[IV0]], 0] +// CHECK-DAG: %[[GEMM0_TILE:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[LHS_TILE]], %[[B]] : +// CHECK-SAME: outs(%[[OUTS0_TILE]] : +// CHECK-DAG: %[[OUTS1_TILE:.+]] = tensor.extract_slice %[[ITER0]][%[[IV0]], 0] +// CHECK-DAG: %[[GEMM1_TILE:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[GEMM0_TILE]], %[[C]] : +// CHECK-SAME: outs(%[[OUTS1_TILE:.+]] : +// CHECK-DAG: %[[GEMM1_INSERT:.+]] = tensor.insert_slice %[[GEMM1_TILE]] into %[[ITER0]][%[[IV0]], 0] +// CHECK-DAG: %[[GEMM0_INSERT:.+]] = tensor.insert_slice %[[GEMM0_TILE]] into %[[ITER1]][%[[IV0]], 0] +// CHECK: scf.yield %[[GEMM1_INSERT]], %[[GEMM0_INSERT]] +// CHECK: return %[[LOOP]]#1, %[[LOOP]]#0 // ----- -module { - func.func @matmul_plus_matmul(%arg0: tensor, %arg1: tensor, - %arg2: tensor) -> tensor{ - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %0 = tensor.dim %arg2, %c0 : tensor - %1 = tensor.dim %arg2, %c1 : tensor - %2 = linalg.matmul ins(%arg0, %arg1 : tensor, tensor) - outs(%arg2 : tensor) -> tensor - %3 = tensor.dim %2, %c0 : tensor - %4 = tensor.dim %2, %c1 : tensor - %5 = linalg.init_tensor [%3, %4] : tensor - %6 = linalg.generic - {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"], - __internal_linalg_transform__ = "transpose_fusion"} - ins(%2, %2 : tensor, tensor) - outs(%5 : tensor) { - ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) : - %7 = arith.addf %arg3, %arg4 : f32 - linalg.yield %7 : f32 - } -> tensor - return %6 : tensor - } +func.func @matmul_rhs_fusion(%A: tensor, %B: tensor, + %AB_init: tensor, %C: tensor, %ABC_init: tensor) + -> (tensor, tensor) { + %AB = linalg.matmul ins(%A, %B : tensor, tensor) + outs(%AB_init : tensor) -> tensor // + %ABC = linalg.matmul {__internal_linalg_transform__ = "rhs_fusion"} + ins(%C, %AB : tensor, tensor) + outs(%ABC_init : tensor) -> tensor // + return %AB, %ABC : tensor, tensor } -// CHECK: func @matmul_plus_matmul -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: 
%[[ARG1:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor -// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9_]+]] -// CHECK-SAME: iter_args(%[[ARG4:.+]] = %{{[a-zA-Z0-9_]+}}) -// CHECK: %[[YIELD:.+]] = scf.for %[[IV1:[a-zA-Z0-9_]+]] -// CHECK-SAME: iter_args(%[[ARG6:.+]] = %[[ARG4]]) -// CHECK: %[[ST_ARG6:.+]] = tensor.extract_slice %[[ARG6]][%[[IV0]], %[[IV1]]] -// CHECK: %[[ST_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0] -// CHECK: %[[ST_ARG1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]] -// CHECK: %[[ST_ARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV0]], %[[IV1]]] -// CHECK: %[[LHS:.+]] = linalg.matmul -// CHECK-SAME: ins(%[[ST_ARG0]], %[[ST_ARG1]] -// CHECK-SAME: : tensor, tensor) -// CHECK-SAME: outs(%[[ST_ARG2]] : tensor) -// CHECK: %[[ST_RESULT:.+]] = linalg.generic -// CHECK-SAME: ins(%[[LHS]] : tensor) -// CHECK-SAME: outs(%[[ST_ARG6]] : tensor) -// CHECK: %[[UPDATE:.+]] = tensor.insert_slice %[[ST_RESULT]] -// CHECK-SAME: into %[[ARG6]][%[[IV0]], %[[IV1]]] -// CHECK: scf.yield %[[UPDATE]] -// CHECK: scf.yield %[[YIELD]] -// CHECK: return %[[RESULT]] +// CHECK-LABEL: func @matmul_rhs_fusion +// CHECK-SAME: %[[A:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[B:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[AB_INIT:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[C:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ABC_INIT:[a-zA-Z0-9]+]]: tensor +// CHECK: %[[LOOP:.+]]:2 = scf.for %[[IV0:[a-zA-Z0-9]+]] +// CHECK-SAME: iter_args(%[[ITER0:[a-zA-Z0-9]+]] = %[[ABC_INIT]], +// CHECK-SAME: %[[ITER1:[a-zA-Z0-9]+]] = %[[AB_INIT]]) +// CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[B]][0, %[[IV0]]] +// CHECK-DAG: %[[OUTS0_TILE:.+]] = tensor.extract_slice %[[ITER1]][0, %[[IV0]]] +// CHECK-DAG: %[[GEMM0_TILE:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[A]], %[[RHS_TILE]] : +// CHECK-SAME: outs(%[[OUTS0_TILE]] : +// CHECK-DAG: %[[OUTS1_TILE:.+]] = tensor.extract_slice %[[ITER0]][0, %[[IV0]]] +// CHECK-DAG: %[[GEMM1_TILE:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[C]], %[[GEMM0_TILE]] : +// CHECK-SAME: outs(%[[OUTS1_TILE:.+]] : +// CHECK-DAG: %[[GEMM1_INSERT:.+]] = tensor.insert_slice %[[GEMM1_TILE]] into %[[ITER0]][0, %[[IV0]]] +// CHECK-DAG: %[[GEMM0_INSERT:.+]] = tensor.insert_slice %[[GEMM0_TILE]] into %[[ITER1]][0, %[[IV0]]] +// CHECK: scf.yield %[[GEMM1_INSERT]], %[[GEMM0_INSERT]] +// CHECK: return %[[LOOP]]#1, %[[LOOP]]#0 // ----- -module { - func.func @matmul_out_fusion(%arg0: tensor, %arg1: tensor, - %arg2: tensor) -> tensor { - %c0 = arith.constant 0.0 : f32 - %0 = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor) -> tensor - %1 = linalg.matmul {__internal_linalg_transform__ = "out_fusion"} - ins(%arg1, %arg2 : tensor, tensor) - outs(%0 : tensor) -> tensor - return %1 : tensor - } +func.func @matmul_out_fusion(%arg0: tensor, %arg1: tensor) -> tensor { + %cst = arith.constant 0.0 : f32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %d0 = tensor.dim %arg0, %c0 : tensor + %d1 = tensor.dim %arg1, %c1 : tensor + %init = linalg.init_tensor [%d0, %d1] : tensor + %fill = linalg.fill ins(%cst : f32) outs(%init : tensor) -> tensor + %gemm = linalg.matmul {__internal_linalg_transform__ = "matmul_outs_fusion"} + ins(%arg0, %arg1 : tensor, tensor) + outs(%fill : tensor) -> tensor + return %gemm : tensor } - -// CHECK-LABEL: func @matmul_out_fusion( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor -// CHECK: %[[C0:.*]] = arith.constant 0.0{{.*}} : f32 -// 
CHECK-NOT: fill -// CHECK: scf.for %[[I:.*]]{{.*}}iter_args(%{{.*}} = %[[ARG0]]) -> (tensor) { -// CHECK: scf.for %[[J:.*]] -// CHECK: %[[ST:.*]] = tensor.extract_slice %[[ARG0]] -// CHECK: %[[ST_FILL:.*]] = linalg.fill -// CHECK-SAME: {__internal_linalg_transform__ = "after_out_fusion_producer"} -// CHECK-SAME: ins(%[[C0]] : f32) outs(%[[ST]] : tensor) -> tensor -// CHECK: %[[ST_MM_RES:.*]] = scf.for %[[K:.*]]{{.*}}iter_args(%[[BB:.*]] = %[[ST_FILL]]) -> (tensor) { -// CHECK-NOT: fill -// CHECK: %[[ST_FILL_SUB:.*]] = tensor.extract_slice %[[BB]][0, 0] -// CHECK: %[[ST_MM_SUB:.*]] = linalg.matmul {__internal_linalg_transform__ = "after_out_fusion"} ins(%{{.*}}, %{{.*}} : tensor, tensor) outs(%[[ST_FILL_SUB]] : tensor) -> tensor -// CHECK: %[[ST_MM:.*]] = tensor.insert_slice %[[ST_MM_SUB]] into %[[BB]] -// CHECK: scf.yield %[[ST_MM]] : tensor -// CHECK: %[[MM:.*]] = tensor.insert_slice %[[ST_MM_RES]] into {{.*}} -// CHECK: scf.yield %[[MM]] : tensor +// CHECK-LABEL: func @matmul_out_fusion +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor +// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor +// CHECK: %[[OUTER:.+]] = scf.for %[[IV0:.+]] = %{{.+}} to %{{.+}} step +// CHECK-SAME: iter_args(%[[OUTER_ITER:[a-zA-Z0-9]+]] = %[[INIT]]) +// CHECK: %[[INNER:.+]] = scf.for %[[IV1:.+]] = %{{.+}} to %[[UB1:.+]] step +// CHECK-SAME: iter_args(%[[INNER_ITER:[a-zA-Z0-9]+]] = %[[OUTER_ITER]]) +// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0] +// CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]] +// CHECK-DAG: %[[OUTS_TILE:.+]] = tensor.extract_slice %[[INNER_ITER]][%[[IV0]], %[[IV1]]] +// CHECK-DAG: %[[FILL_TILE:.+]] = linalg.fill +// CHECK-SAME: outs(%[[OUTS_TILE]] : +// CHECK-DAG: %[[MATMUL_TILE:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] : +// CHECK-SAME: outs(%[[FILL_TILE:.+]] : +// CHECK-DAG: %[[MATMUL_INSERT:.+]] = tensor.insert_slice %[[MATMUL_TILE]] into %[[INNER_ITER]][%[[IV0]], %[[IV1]]] +// CHECK: scf.yield %[[MATMUL_INSERT]] +// CHECK: scf.yield %[[INNER]] +// CHECK: return %[[OUTER]] // ----- -module { - func.func @generic_plus_matmul(%arg0: tensor, %arg1: tensor, - %arg2: tensor) -> tensor { - %c0 = arith.constant 0.0 : f32 - %0 = linalg.generic { - indexing_maps = [affine_map<(m, n) -> ()>, affine_map<(m, n) -> (m, n)>], +func.func @generic_plus_matmul(%arg0: tensor, %arg1: tensor, %arg2 : tensor) + -> (tensor, tensor) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %d0 = tensor.dim %arg0, %c0 : tensor + %d1 = tensor.dim %arg1, %c1 : tensor + %init = linalg.init_tensor [%d0, %d1] : tensor + %bias = linalg.generic { + indexing_maps = [affine_map<(m, n) -> (n)>, affine_map<(m, n) -> (m, n)>], iterator_types = ["parallel", "parallel"]} - ins(%c0 : f32) - outs(%arg0: tensor) { - ^bb(%0: f32, %1: f32) : - linalg.yield %0 : f32 + ins(%arg2 : tensor) outs(%init: tensor) { + ^bb(%b0: f32, %b1: f32) : + linalg.yield %b0 : f32 } -> tensor - %1 = linalg.matmul {__internal_linalg_transform__ = "out_fusion"} - ins(%arg1, %arg2 : tensor, tensor) - outs(%0 : tensor) -> tensor - return %1 : tensor - } + %gemm = linalg.matmul {__internal_linalg_transform__ = "matmul_outs_fusion"} + ins(%arg0, %arg1 : tensor, tensor) + outs(%bias : tensor) -> tensor + return %bias, %gemm : tensor, tensor } +// CHECK-LABEL: func @generic_plus_matmul +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor 
+// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor +// CHECK: %[[OUTER:.+]]:2 = scf.for %[[IV0:.+]] = %{{.+}} to %{{.+}} step +// CHECK-SAME: iter_args(%[[OUTER_ITER0:[a-zA-Z0-9]+]] = %[[INIT]], +// CHECK-SAME: %[[OUTER_ITER1:[a-zA-Z0-9]+]] = %[[INIT]]) +// CHECK: %[[INNER:.+]]:2 = scf.for %[[IV1:.+]] = %{{.+}} to %[[UB1:.+]] step +// CHECK-SAME: iter_args(%[[INNER_ITER0:[a-zA-Z0-9]+]] = %[[OUTER_ITER0]], +// CHECK-SAME: %[[INNER_ITER1:[a-zA-Z0-9]+]] = %[[OUTER_ITER1]]) +// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0] +// CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]] +// CHECK-DAG: %[[OUTS_TILE:.+]] = tensor.extract_slice %[[INNER_ITER0]][%[[IV0]], %[[IV1]]] +// CHECK-DAG: %[[BIAS_TILE:.+]] = tensor.extract_slice %[[ARG2]][%[[IV1]]] +// CHECK: %[[BCAST_TILE:.+]] = linalg.generic +// CHECK-SAME: ins(%[[BIAS_TILE]] : +// CHECK-SAME: outs(%[[OUTS_TILE]] : +// CHECK-DAG: %[[MATMUL_TILE:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] : +// CHECK-SAME: outs(%[[BCAST_TILE:.+]] : +// CHECK-DAG: %[[MATMUL_INSERT:.+]] = tensor.insert_slice %[[MATMUL_TILE]] into %[[INNER_ITER0]][%[[IV0]], %[[IV1]]] +// CHECK-DAG: %[[BCAST_INSERT:.+]] = tensor.insert_slice %[[BCAST_TILE]] into %[[INNER_ITER1]][%[[IV0]], %[[IV1]]] +// CHECK: scf.yield %[[MATMUL_INSERT]], %[[BCAST_INSERT]] +// CHECK: scf.yield %[[INNER]]#0, %[[INNER]]#1 +// CHECK: return %[[OUTER]]#1, %[[OUTER]]#0 diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt --- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt @@ -2,6 +2,7 @@ add_mlir_library(MLIRLinalgTestPasses TestLinalgCodegenStrategy.cpp TestLinalgElementwiseFusion.cpp + TestLinalgFusionOnTensors.cpp TestLinalgFusionTransforms.cpp TestLinalgHoisting.cpp TestLinalgTransforms.cpp diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionOnTensors.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionOnTensors.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionOnTensors.cpp @@ -0,0 +1,99 @@ +//===- TestLinalgFusionOnTensorsPass.cpp - Test Linalg fusion patterns ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements logic for testing Linalg fusion patterns on tensors. +// Specifically the `tileConsumerAndFuseProducers` method. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/SCF/Transforms.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" + +using namespace mlir; +using namespace mlir::linalg; + +namespace { +struct TestLinalgFusionOnTensorsPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgFusionOnTensorsPass) + + StringRef getArgument() const final { + return "test-linalg-tile-and-fuse-on-tensors"; + } + + StringRef getDescription() const final { + return "Test Linalg tiling and fusion on tensor operations."; + } + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + TestLinalgFusionOnTensorsPass() = default; + TestLinalgFusionOnTensorsPass(const TestLinalgFusionOnTensorsPass &pass) {} + + void runOnOperation() override { + MLIRContext *context = &this->getContext(); + func::FuncOp funcOp = this->getOperation(); + + RewritePatternSet fusionPatterns(context); + + fusionPatterns.insert( + context, + LinalgTilingAndFusionOptions() + .setTileSizes({10, 20}) + .setReturnFusedOpValues(true), + LinalgTransformationFilter( + StringAttr::get(context, "return_fused_values"), + StringAttr::get(context, "fused_ops"))); + + fusionPatterns.insert( + context, + LinalgTilingAndFusionOptions() + .setTileSizes({10, 0, 0}) + .setReturnFusedOpValues(true), + LinalgTransformationFilter(StringAttr::get(context, "lhs_fusion"), + StringAttr::get(context, "fused_ops"))); + + fusionPatterns.insert( + context, + LinalgTilingAndFusionOptions() + .setTileSizes({0, 20, 0}) + .setReturnFusedOpValues(true), + LinalgTransformationFilter(StringAttr::get(context, "rhs_fusion"), + StringAttr::get(context, "fused_ops"))); + + fusionPatterns.insert( + context, + LinalgTilingAndFusionOptions() + .setTileSizes({10, 20, 0}) + .setReturnFusedOpValues(true), + LinalgTransformationFilter( + StringAttr::get(context, "matmul_outs_fusion"), + StringAttr::get(context, "fused_ops"))); + + (void)applyPatternsAndFoldGreedily(funcOp, std::move(fusionPatterns)); + } +}; +} // namespace + +namespace mlir { +namespace test { +void registerTestLinalgFusionOnTensorsPass() { + PassRegistration(); +} +} // namespace test +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -86,6 +86,7 @@ void registerTestInterfaces(); void registerTestLinalgCodegenStrategy(); void registerTestLinalgElementwiseFusion(); +void registerTestLinalgFusionOnTensorsPass(); void registerTestLinalgFusionTransforms(); void registerTestLinalgTensorFusionTransforms(); void registerTestLinalgTiledLoopFusionTransforms(); @@ -182,6 +183,7 @@ mlir::test::registerTestInterfaces(); mlir::test::registerTestLinalgCodegenStrategy(); mlir::test::registerTestLinalgElementwiseFusion(); + mlir::test::registerTestLinalgFusionOnTensorsPass(); mlir::test::registerTestLinalgFusionTransforms(); mlir::test::registerTestLinalgTensorFusionTransforms(); mlir::test::registerTestLinalgTiledLoopFusionTransforms();