diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -310,4 +310,39 @@
   }];
 }
 
+def LoopFuseSibling : Op<Transform_Dialect, "loop.fuse_sibling",
+  [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+   DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let summary = "Fuse a loop into another loop, assuming they are independent.";
+
+  let description = [{
+    Fuses the `target` loop into the `source` loop assuming they are
+    independent of each other. It is the responsibility of the user to ensure
+    that the given two loops are independent of each other, this operation will
+    not perform any legality checks and will simply fuse the two given loops.
+
+    Currently, the only fusion supported is when both `target` and `source`
+    are `scf.forall` operations. For `scf.forall` fusion, the bounds and the
+    mapping must match, otherwise a silenceable failure is produced.
+
+    The input handles `target` and `source` must map to exactly one operation,
+    a definite failure is produced otherwise.
+
+    #### Return modes
+
+    This operation consumes the `target` and `source` handles and produces the
+    `fused_loop` handle, which points to the fused loop.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       TransformHandleTypeInterface:$source);
+  let results = (outs TransformHandleTypeInterface:$fused_loop);
+  let assemblyFormat = "$target `into` $source attr-dict "
+                       " `:` functional-type(operands, results)";
+
+  let builders = [
+    OpBuilder<(ins "Value":$loop, "Value":$fused_loop)>
+  ];
+}
+
 #endif // SCF_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -185,6 +185,16 @@
 void getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
                              scf::ForOp root);
 
+/// Given two scf.forall loops, `target` and `source`, fuses `target` into
+/// `source`. Assumes that the given loops are "siblings", i.e. they are
+/// independent of each other.
+///
+/// This function does not perform any legality checks and simply fuses the
+/// loops. The caller is responsible for ensuring that the loops are legal to
+/// fuse.
+scf::ForallOp fuseSiblingForallLoops(scf::ForallOp target, scf::ForallOp source,
+                                     RewriterBase &rewriter);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -318,6 +318,122 @@
   modifiesPayload(effects);
 }
 
+//===----------------------------------------------------------------------===//
+// LoopFuseSibling
+//===----------------------------------------------------------------------===//
+
+/// Check if `target` and `source` are siblings.
+///
+/// This is a simple check that just checks if both operations are distinct,
+/// in the same block, and that the fused IR does not violate dominance. Only
+/// the operands and results of `target` are inspected; values captured
+/// implicitly by its regions are not checked.
+static bool isOpSibling(Operation *target, Operation *source) {
+  // An operation cannot be fused into itself.
+  if (target == source)
+    return false;
+
+  // Check if both operations are in the same block.
+  Block *block = source->getBlock();
+  if (target->getBlock() != block)
+    return false;
+
+  // Check if fusion will violate dominance. We check that every operand of
+  // `target` dominates `source` and every result of `target` is dominated by
+  // `source`.
+  for (Value operand : target->getOperands()) {
+    // Block arguments have no defining op, and `isBeforeInBlock` requires both
+    // operations to be in the same block. A value that is visible to `target`
+    // but is not defined by an operation of this block dominates every
+    // operation in the block, including `source`.
+    Operation *definingOp = operand.getDefiningOp();
+    if (!definingOp || definingOp->getBlock() != block)
+      continue;
+    // Operand should be strictly before `source` in the block.
+    if (!definingOp->isBeforeInBlock(source))
+      return false;
+  }
+  for (Operation *user : target->getUsers()) {
+    // A user may be nested inside a region of an operation of this block;
+    // order it by that enclosing operation. A user that is not nested in this
+    // block at all lives in a block dominated by this one.
+    Operation *ancestor = block->findAncestorOpInBlock(*user);
+    if (!ancestor)
+      continue;
+    // User should be strictly after `source` in the block.
+    if (!source->isBeforeInBlock(ancestor))
+      return false;
+  }
+
+  return true;
+}
+
+/// Check if `target` can be fused into `source`.
+///
+/// This is a simple check that just checks if both loops have same
+/// bounds, steps and mapping. This check does not ensure that the side effects
+/// of `target` are independent of `source` or vice-versa. It is the
+/// responsibility of the caller to ensure that.
+static bool isFusionLegal(Operation *target, Operation *source) {
+  if (auto targetOp = dyn_cast<scf::ForallOp>(target)) {
+    if (auto sourceOp = dyn_cast<scf::ForallOp>(source)) {
+      return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
+             targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
+             targetOp.getMixedStep() == sourceOp.getMixedStep() &&
+             targetOp.getMapping() == sourceOp.getMapping();
+    }
+    return false;
+  }
+  // TODO: Add fusion for more operations. Currently, we handle only scf.forall.
+  return false;
+}
+
+/// Fuse `target` into `source` assuming they are siblings.
+/// Returns nullptr when fusing this pair of operations is not supported.
+static Operation *fuseSiblings(Operation *target, Operation *source,
+                               RewriterBase &rewriter) {
+  if (auto targetOp = dyn_cast<scf::ForallOp>(target)) {
+    if (auto sourceOp = dyn_cast<scf::ForallOp>(source))
+      return fuseSiblingForallLoops(targetOp, sourceOp, rewriter);
+    return nullptr;
+  }
+  // TODO: Add fusion for more operations. Currently, we handle only scf.forall.
+  return nullptr;
+}
+
+DiagnosedSilenceableFailure
+transform::LoopFuseSibling::apply(transform::TransformRewriter &rewriter,
+                                  transform::TransformResults &results,
+                                  transform::TransformState &state) {
+  auto targetOps = state.getPayloadOps(getTarget());
+  auto sourceOps = state.getPayloadOps(getSource());
+
+  if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(sourceOps))
+    return emitDefiniteFailure()
+           << "requires exactly one target handle (got "
+           << llvm::range_size(targetOps) << ") and exactly one "
+           << "source handle (got " << llvm::range_size(sourceOps) << ")";
+
+  Operation *target = *targetOps.begin();
+  Operation *source = *sourceOps.begin();
+
+  // Check if the target and source are siblings.
+  if (!isOpSibling(target, source))
+    return emitSilenceableFailure(target->getLoc())
+           << "operations are not siblings";
+
+  // Check if the target can be fused into source.
+  if (!isFusionLegal(target, source))
+    return emitSilenceableFailure(target->getLoc())
+           << "operations cannot be fused";
+
+  Operation *fusedLoop = fuseSiblings(target, source, rewriter);
+  assert(fusedLoop && "failed to fuse operations");
+
+  results.set(cast<OpResult>(getFusedLoop()), {fusedLoop});
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -970,3 +970,68 @@
 
   return tileLoops;
 }
+
+scf::ForallOp mlir::fuseSiblingForallLoops(scf::ForallOp target,
+                                           scf::ForallOp source,
+                                           RewriterBase &rewriter) {
+  unsigned numTargetOuts = target.getNumResults();
+  unsigned numSourceOuts = source.getNumResults();
+
+  OperandRange targetOuts = target.getOutputs();
+  OperandRange sourceOuts = source.getOutputs();
+
+  // Create fused shared_outs.
+  SmallVector<Value> fusedOuts;
+  fusedOuts.reserve(numTargetOuts + numSourceOuts);
+  fusedOuts.append(targetOuts.begin(), targetOuts.end());
+  fusedOuts.append(sourceOuts.begin(), sourceOuts.end());
+
+  // Create a new scf::forall op after the source loop.
+  rewriter.setInsertionPointAfter(source);
+  scf::ForallOp fusedLoop = rewriter.create<scf::ForallOp>(
+      source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(),
+      source.getMixedStep(), fusedOuts, source.getMapping());
+
+  // Map control operands.
+  IRMapping fusedMapping;
+  fusedMapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
+  fusedMapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
+
+  // Map shared outs.
+  fusedMapping.map(target.getOutputBlockArguments(),
+                   fusedLoop.getOutputBlockArguments().slice(0, numTargetOuts));
+  fusedMapping.map(
+      source.getOutputBlockArguments(),
+      fusedLoop.getOutputBlockArguments().slice(numTargetOuts, numSourceOuts));
+
+  // Append everything except the terminator into the fused operation.
+  rewriter.setInsertionPointToStart(fusedLoop.getBody());
+  for (Operation &op : target.getLoopBody().begin()->without_terminator())
+    rewriter.clone(op, fusedMapping);
+  for (Operation &op : source.getLoopBody().begin()->without_terminator())
+    rewriter.clone(op, fusedMapping);
+
+  // Fuse the old terminator in_parallel ops into the new one.
+  scf::InParallelOp targetTerm = target.getTerminator();
+  scf::InParallelOp sourceTerm = source.getTerminator();
+  scf::InParallelOp fusedTerm = fusedLoop.getTerminator();
+
+  rewriter.setInsertionPointToStart(fusedTerm.getBody());
+  for (Operation &op : targetTerm.getYieldingOps())
+    rewriter.clone(op, fusedMapping);
+  for (Operation &op : sourceTerm.getYieldingOps())
+    rewriter.clone(op, fusedMapping);
+
+  // Replace all uses of the old loops with the fused loop.
+  rewriter.replaceAllUsesWith(target.getResults(),
+                              fusedLoop.getResults().slice(0, numTargetOuts));
+  rewriter.replaceAllUsesWith(
+      source.getResults(),
+      fusedLoop.getResults().slice(numTargetOuts, numSourceOuts));
+
+  // Erase the old loops.
+  rewriter.eraseOp(target);
+  rewriter.eraseOp(source);
+
+  return fusedLoop;
+}
diff --git a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter --cse --canonicalize -split-input-file -verify-diagnostics | FileCheck %s
+
+func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+  %zero = arith.constant 0.0 : f32
+  %out_alloc = tensor.empty() : tensor<128x128xf32>
+  %out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
+
+  // CHECK: scf.forall ([[I:%.*]]) in (4) shared_outs([[S1:%.*]] = [[IN1:%.*]], [[S2:%.*]] = [[IN2:%.*]]) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+  // CHECK:   [[T:%.*]] = affine.apply
+  // CHECK:   tensor.extract_slice [[S1]][[[T]], 0] [32, 128] [1, 1]
+  // CHECK:   [[OUT1:%.*]] = linalg.matmul
+  // CHECK:   tensor.extract_slice [[S2]][[[T]], 0] [32, 128] [1, 1]
+  // CHECK:   [[OUT2:%.*]] = linalg.matmul
+  // CHECK:   scf.forall.in_parallel {
+  // CHECK:     tensor.parallel_insert_slice [[OUT1]] into [[S1]][[[T]], 0] [32, 128] [1, 1]
+  // CHECK:     tensor.parallel_insert_slice [[OUT2]] into [[S2]][[[T]], 0] [32, 128] [1, 1]
+  // CHECK:   }
+  // CHECK: }
+  %out1 = linalg.matmul ins(%A, %B1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+  %out2 = linalg.matmul ins(%A, %B2 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+
+  func.return %out1, %out2 : tensor<128x128xf32>, tensor<128x128xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb0(%variant_op : !transform.any_op):
+  %matched = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
+
+  %mm1, %mm2 = transform.split_handle %matched : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+  %loop1, %tiled_mm1 = transform.structured.tile_to_forall_op %mm1 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+  %loop2, %tiled_mm2 = transform.structured.tile_to_forall_op %mm2 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+  %fused_loop = transform.loop.fuse_sibling %loop1 into %loop2 : (!transform.any_op, !transform.any_op) -> !transform.any_op
+}