diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -310,4 +310,30 @@
   }];
 }
 
+def LoopFuseSibling : Op<Transform_Dialect, "loop.fuse_sibling",
+  [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+   DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let summary = "Fuse a loop into another loop, assuming they are independent.";
+
+  let description = [{
+    Fuses the `loop` into the `sibling` assuming they are independent of
+    each other, i.e. neither is an ancestor of the other.
+
+    #### Return modes
+
+    Returns a handle to the fused loop.
+    This operation consumes both the given handles.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$loop,
+                       TransformHandleTypeInterface:$sibling);
+  let results = (outs TransformHandleTypeInterface:$fused_loop);
+  let assemblyFormat = "$loop `into` $sibling attr-dict "
+                       " `:` functional-type(operands, results)";
+
+  let builders = [
+    OpBuilder<(ins "Value":$loop, "Value":$sibling)>
+  ];
+}
+
 #endif // SCF_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -185,6 +185,11 @@
 void getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
                              scf::ForOp root);
 
+/// Given two scf.forall loops, fuses them into a single loop. Assumes that
+/// the given loops are "siblings", i.e. neither of them is an ancestor of the
+/// other.
+scf::ForallOp fuseSiblingForallLoops(scf::ForallOp loop, scf::ForallOp sibling);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -318,6 +318,83 @@
   modifiesPayload(effects);
 }
 
+//===----------------------------------------------------------------------===//
+// LoopFuseSibling
+//===----------------------------------------------------------------------===//
+
+/// Check if `op` and `sibling` are siblings i.e. if neither is an ancestor
+/// of the other.
+static bool isOpSibling(Operation *op, Operation *sibling) {
+  return !sibling->isAncestor(op) && !op->isAncestor(sibling);
+}
+
+/// Check if `loop` can be fused into `sibling`. Both operations must be
+/// scf.forall loops with the same iteration space (lower bounds, upper bounds
+/// and steps) and the same optional mapping.
+static bool isFusionLegal(Operation *loop, Operation *sibling) {
+  // TODO: Add fusion for more operations. Currently, we handle only scf.forall.
+  auto loopOp = dyn_cast<scf::ForallOp>(loop);
+  auto siblingOp = dyn_cast<scf::ForallOp>(sibling);
+  if (!loopOp || !siblingOp)
+    return false;
+
+  // The fused loop is built from the bounds, steps and mapping of `sibling`,
+  // so each of them has to agree with `loop`. Comparing the lower bounds alone
+  // would accept loops that iterate over different ranges.
+  return loopOp.getMixedLowerBound() == siblingOp.getMixedLowerBound() &&
+         loopOp.getMixedUpperBound() == siblingOp.getMixedUpperBound() &&
+         loopOp.getMixedStep() == siblingOp.getMixedStep() &&
+         loopOp.getMapping() == siblingOp.getMapping();
+}
+
+/// Fuse `loop` into `sibling` assuming they are siblings. Returns the fused
+/// loop, or nullptr if either operation is not an scf.forall.
+static Operation *fuseSiblings(Operation *loop, Operation *sibling,
+                               RewriterBase &rewriter) {
+  // TODO: Add fusion for more operations. Currently, we handle only scf.forall.
+  auto loopOp = dyn_cast<scf::ForallOp>(loop);
+  auto siblingOp = dyn_cast<scf::ForallOp>(sibling);
+  if (!loopOp || !siblingOp)
+    return nullptr;
+
+  return fuseSiblingForallLoops(loopOp, siblingOp);
+}
+
+/// Fuses the single payload loop behind the `loop` handle into the single
+/// payload loop behind the `sibling` handle and returns the fused loop.
+DiagnosedSilenceableFailure
+transform::LoopFuseSibling::apply(transform::TransformRewriter &rewriter,
+                                  transform::TransformResults &results,
+                                  transform::TransformState &state) {
+  auto loopOps = state.getPayloadOps(getLoop());
+  auto siblingOps = state.getPayloadOps(getSibling());
+
+  if (!llvm::hasSingleElement(loopOps) || !llvm::hasSingleElement(siblingOps))
+    return emitDefiniteFailure()
+           << "requires exactly one loop handle (got "
+           << llvm::range_size(loopOps) << ") and exactly one "
+           << "sibling handle (got " << llvm::range_size(siblingOps) << ")";
+
+  Operation *loop = *loopOps.begin();
+  Operation *sibling = *siblingOps.begin();
+
+  // Check if the loop and sibling are siblings.
+  if (!isOpSibling(loop, sibling))
+    return emitDefiniteFailure() << "Operations are not siblings";
+
+  // Check if the loop and sibling can be fused.
+  if (!isFusionLegal(loop, sibling))
+    return emitDefiniteFailure() << "Operations cannot be fused";
+
+  // Legality was established above, so a null result here is a logic error.
+  Operation *fusedLoop = fuseSiblings(loop, sibling, rewriter);
+  assert(fusedLoop && "failed to fuse operations");
+
+  results.set(cast<OpResult>(getFusedLoop()),
+              SmallVector<Operation *>{fusedLoop});
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -970,3 +970,69 @@
 
   return tileLoops;
 }
+
+/// Builds a new scf.forall right after `sibling` that carries the shared
+/// outputs of `loop` followed by those of `sibling`, clones both bodies into
+/// it, redirects all uses of the old results and erases both original loops.
+scf::ForallOp mlir::fuseSiblingForallLoops(scf::ForallOp loop,
+                                           scf::ForallOp sibling) {
+  OpBuilder b(sibling);
+
+  // Create fused shared_outs.
+  auto loopSharedOuts = loop.getOutputs();
+  auto siblingSharedOuts = sibling.getOutputs();
+  SmallVector<Value> fusedOuts;
+  fusedOuts.reserve(loopSharedOuts.size() + siblingSharedOuts.size());
+  fusedOuts.append(loopSharedOuts.begin(), loopSharedOuts.end());
+  fusedOuts.append(siblingSharedOuts.begin(), siblingSharedOuts.end());
+
+  // Create a new scf::forall op after the sibling loop.
+  b.setInsertionPointAfter(sibling);
+  auto fusedLoop = b.create<scf::ForallOp>(
+      sibling.getLoc(), sibling.getMixedLowerBound(),
+      sibling.getMixedUpperBound(), sibling.getMixedStep(), fusedOuts,
+      sibling.getMapping());
+
+  // Map control operands.
+  IRMapping fusedMapping;
+  fusedMapping.map(loop.getInductionVars(), fusedLoop.getInductionVars());
+  fusedMapping.map(sibling.getInductionVars(), fusedLoop.getInductionVars());
+
+  // Map shared outs.
+  fusedMapping.map(
+      loop.getOutputBlockArguments(),
+      fusedLoop.getOutputBlockArguments().slice(0, loopSharedOuts.size()));
+  fusedMapping.map(sibling.getOutputBlockArguments(),
+                   fusedLoop.getOutputBlockArguments().slice(
+                       loopSharedOuts.size(), siblingSharedOuts.size()));
+
+  // Append everything except the terminator into the fused operation.
+  b.setInsertionPointToStart(fusedLoop.getBody());
+  for (Operation &op : loop.getLoopBody().begin()->without_terminator())
+    b.clone(op, fusedMapping);
+  for (Operation &op : sibling.getLoopBody().begin()->without_terminator())
+    b.clone(op, fusedMapping);
+
+  // Fuse the old terminator in_parallel ops into the new one.
+  auto loopTerm = loop.getTerminator();
+  auto siblingTerm = sibling.getTerminator();
+  auto fusedTerm = fusedLoop.getTerminator();
+
+  b.setInsertionPointToStart(fusedTerm.getBody());
+  for (Operation &op : loopTerm.getYieldingOps())
+    b.clone(op, fusedMapping);
+  for (Operation &op : siblingTerm.getYieldingOps())
+    b.clone(op, fusedMapping);
+
+  // Replace all uses of the old loops with the fused loop.
+  loop->replaceAllUsesWith(
+      fusedLoop.getResults().slice(0, loopSharedOuts.size()));
+  sibling->replaceAllUsesWith(fusedLoop.getResults().slice(
+      loopSharedOuts.size(), siblingSharedOuts.size()));
+
+  // Erase the old loops.
+  loop->erase();
+  sibling->erase();
+
+  return fusedLoop;
+}
diff --git a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter --cse --canonicalize -split-input-file -verify-diagnostics | FileCheck %s
+
+func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+  %zero = arith.constant 0.0 : f32
+  %out_alloc = tensor.empty() : tensor<128x128xf32>
+  %out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
+
+  // CHECK: scf.forall ([[I:%.*]]) in (4) shared_outs([[S1:%.*]] = [[IN1:%.*]], [[S2:%.*]] = [[IN2:%.*]]) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+  // CHECK:   [[T:%.*]] = affine.apply
+  // CHECK:   tensor.extract_slice [[S1]][[[T]], 0] [32, 128] [1, 1]
+  // CHECK:   [[OUT1:%.*]] = linalg.matmul
+  // CHECK:   tensor.extract_slice [[S2]][[[T]], 0] [32, 128] [1, 1]
+  // CHECK:   [[OUT2:%.*]] = linalg.matmul
+  // CHECK:   scf.forall.in_parallel {
+  // CHECK:     tensor.parallel_insert_slice [[OUT1]] into [[S1]][[[T]], 0] [32, 128] [1, 1]
+  // CHECK:     tensor.parallel_insert_slice [[OUT2]] into [[S2]][[[T]], 0] [32, 128] [1, 1]
+  // CHECK:   }
+  // CHECK: }
+  %out1 = linalg.matmul ins(%A, %B1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+  %out2 = linalg.matmul ins(%A, %B2 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+
+  func.return %out1, %out2 : tensor<128x128xf32>, tensor<128x128xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb0(%variant_op : !transform.any_op):
+  %matched = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
+  // Split the handle to the matched matmuls into one handle per matmul.
+  %mm1, %mm2 = transform.split_handle %matched : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+  // Tile each matmul with tile size 32, producing one scf.forall per matmul.
+  %loop1, %tiled_mm1 = transform.structured.tile_to_forall_op %mm1 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+  %loop2, %tiled_mm2 = transform.structured.tile_to_forall_op %mm2 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+  // Fuse the first loop into the second; %fused_loop is the handle to the result.
+  %fused_loop = transform.loop.fuse_sibling %loop1 into %loop2 : (!transform.any_op, !transform.any_op) -> !transform.any_op
+}