diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -310,4 +310,41 @@
   }];
 }
 
+def LoopFuseSibling : Op<Transform_Dialect, "loop.fuse_sibling",
+  [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+   DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let summary = "Fuse a loop into another loop, assuming the fusion is legal.";
+
+  let description = [{
+    Fuses the `target` loop into the `source` loop assuming they are
+    independent of each other. It is the responsibility of the user to ensure
+    that the given two loops are independent of each other. This operation
+    only checks that the loops are distinct siblings in the same block and
+    that fusing them does not break dominance, producing a silenceable
+    failure otherwise; it never checks that their side effects are
+    independent.
+
+    Currently, the only fusion supported is when both `target` and `source`
+    are `scf.forall` operations. For `scf.forall` fusion, the bounds, steps
+    and mapping must match, otherwise a silenceable failure is produced.
+
+    The fused loop is created right after `source`; its shared outputs, body
+    and results are those of `target` followed by those of `source`.
+
+    The input handles `target` and `source` must each map to exactly one
+    operation; a definite failure is produced otherwise.
+
+    #### Return modes
+
+    This operation consumes the `target` and `source` handles and produces the
+    `fused_loop` handle, which points to the fused loop.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       TransformHandleTypeInterface:$source);
+  let results = (outs TransformHandleTypeInterface:$fused_loop);
+  let assemblyFormat = "$target `into` $source attr-dict "
+                       " `:` functional-type(operands, results)";
+}
+
 #endif // SCF_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -185,6 +185,17 @@
 void getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
                              scf::ForOp root);
 
+/// Given two scf.forall loops, `target` and `source`, fuses `target` into
+/// `source` and returns the fused loop, which is inserted right after
+/// `source`. Both original loops are erased. In the fused loop, the
+/// shared_outs, body and results of `target` precede those of `source`.
+///
+/// No legality checks are performed; the caller must ensure the loops are
+/// independent siblings with identical bounds, steps and mapping.
+scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target,
+                                                scf::ForallOp source,
+                                                RewriterBase &rewriter);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformOps.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Dominance.h"
 
 using namespace mlir;
 using namespace mlir::affine;
@@ -318,6 +319,146 @@
   modifiesPayload(effects);
 }
 
+//===----------------------------------------------------------------------===//
+// LoopFuseSibling
+//===----------------------------------------------------------------------===//
+
+/// Check if `target` and `source` are siblings, in the context that `target`
+/// is being fused into `source`.
+///
+/// Both operations must be distinct and live in the same block, and fusing
+/// them must not violate dominance: the fused loop is created right after
+/// `source`, so whatever `target` uses or produces must still be valid there.
+static DiagnosedSilenceableFailure isOpSibling(Operation *target,
+                                               Operation *source) {
+  // Check if both operations are same.
+  if (target == source)
+    return emitSilenceableFailure(source)
+           << "target and source need to be different loops";
+
+  // Check if both operations are in the same block.
+  if (target->getBlock() != source->getBlock())
+    return emitSilenceableFailure(source)
+           << "target and source are not in the same block";
+
+  // Check if fusion will violate dominance.
+  DominanceInfo domInfo(source);
+  if (target->isBeforeInBlock(source)) {
+    // Since `target` is before `source`, all users of results of `target`
+    // need to be dominated by `source`.
+    for (Operation *user : target->getUsers()) {
+      if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) {
+        return emitSilenceableFailure(target)
+               << "user of results of target should be properly dominated by "
+                  "source";
+      }
+    }
+  } else {
+    // Since `target` is after `source`, all values used by `target` need
+    // to dominate `source`.
+
+    // Check if operands of `target` are dominated by `source`.
+    for (Value operand : target->getOperands()) {
+      Operation *operandOp = operand.getDefiningOp();
+      // If operand does not have a defining operation, it is a block argument
+      // of the common block of `target` and `source` (or of an enclosing
+      // block), which always dominates `source`.
+      if (!operandOp)
+        continue;
+
+      // Operand's defining operation should properly dominate `source`.
+      if (!domInfo.properlyDominates(operandOp, source,
+                                     /*enclosingOpOk=*/false))
+        return emitSilenceableFailure(target)
+               << "operands of target should be properly dominated by source";
+    }
+
+    // Check if values used by `target` are dominated by `source`.
+    OpOperand *failedValue = nullptr;
+    visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) {
+      // Check the definition, not the user: the user is nested in `target`,
+      // which never dominates `source`. Block arguments always dominate it.
+      Operation *definingOp = operand->get().getDefiningOp();
+      if (definingOp && !domInfo.properlyDominates(definingOp, source,
+                                                   /*enclosingOpOk=*/false))
+        failedValue = operand;
+    });
+
+    if (failedValue)
+      return emitSilenceableFailure(failedValue->getOwner())
+             << "values used inside regions of target should be properly "
+                "dominated by source";
+  }
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Check if `target` can be fused into `source`.
+///
+/// Both must be `scf.forall` loops with identical lower bounds, upper bounds,
+/// steps and mapping. This check does not ensure that the side effects of
+/// `target` are independent of `source` or vice-versa; it is the
+/// responsibility of the caller to ensure that.
+static bool isForallWithIdenticalConfiguration(Operation *target,
+                                               Operation *source) {
+  auto targetOp = dyn_cast<scf::ForallOp>(target);
+  auto sourceOp = dyn_cast<scf::ForallOp>(source);
+  if (!targetOp || !sourceOp)
+    return false;
+
+  return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
+         targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
+         targetOp.getMixedStep() == sourceOp.getMixedStep() &&
+         targetOp.getMapping() == sourceOp.getMapping();
+}
+
+/// Fuse `target` into `source` assuming they are siblings and independent.
+/// TODO: Add fusion for more operations. Currently, we handle only scf.forall.
+static Operation *fuseSiblings(Operation *target, Operation *source,
+                               RewriterBase &rewriter) {
+  auto targetOp = dyn_cast<scf::ForallOp>(target);
+  auto sourceOp = dyn_cast<scf::ForallOp>(source);
+  if (!targetOp || !sourceOp)
+    return nullptr;
+  return fuseIndependentSiblingForallLoops(targetOp, sourceOp, rewriter);
+}
+
+DiagnosedSilenceableFailure
+transform::LoopFuseSibling::apply(transform::TransformRewriter &rewriter,
+                                  transform::TransformResults &results,
+                                  transform::TransformState &state) {
+  auto targetOps = state.getPayloadOps(getTarget());
+  auto sourceOps = state.getPayloadOps(getSource());
+
+  if (!llvm::hasSingleElement(targetOps) ||
+      !llvm::hasSingleElement(sourceOps)) {
+    return emitDefiniteFailure()
+           << "requires exactly one target handle (got "
+           << llvm::range_size(targetOps) << ") and exactly one "
+           << "source handle (got " << llvm::range_size(sourceOps) << ")";
+  }
+
+  Operation *target = *targetOps.begin();
+  Operation *source = *sourceOps.begin();
+
+  // Check if the target and source are siblings.
+  DiagnosedSilenceableFailure diag = isOpSibling(target, source);
+  if (!diag.succeeded())
+    return diag;
+
+  // Check if the target can be fused into source.
+  if (!isForallWithIdenticalConfiguration(target, source))
+    return emitSilenceableFailure(target->getLoc())
+           << "operations cannot be fused: expected two scf.forall loops "
+              "with identical bounds, steps and mapping";
+
+  Operation *fusedLoop = fuseSiblings(target, source, rewriter);
+  assert(fusedLoop && "failed to fuse operations");
+
+  results.set(cast<OpResult>(getFusedLoop()), {fusedLoop});
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -970,3 +970,68 @@
 
   return tileLoops;
 }
+
+scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
+                                                      scf::ForallOp source,
+                                                      RewriterBase &rewriter) {
+  unsigned numTargetOuts = target.getNumResults();
+  unsigned numSourceOuts = source.getNumResults();
+
+  OperandRange targetOuts = target.getOutputs();
+  OperandRange sourceOuts = source.getOutputs();
+
+  // Create fused shared_outs.
+  SmallVector<Value> fusedOuts;
+  fusedOuts.reserve(numTargetOuts + numSourceOuts);
+  fusedOuts.append(targetOuts.begin(), targetOuts.end());
+  fusedOuts.append(sourceOuts.begin(), sourceOuts.end());
+
+  // Create a new scf::forall op after the source loop.
+  rewriter.setInsertionPointAfter(source);
+  scf::ForallOp fusedLoop = rewriter.create<scf::ForallOp>(
+      source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(),
+      source.getMixedStep(), fusedOuts, source.getMapping());
+
+  // Map control operands.
+  IRMapping fusedMapping;
+  fusedMapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
+  fusedMapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
+
+  // Map shared outs.
+  fusedMapping.map(target.getOutputBlockArguments(),
+                   fusedLoop.getOutputBlockArguments().slice(0, numTargetOuts));
+  fusedMapping.map(
+      source.getOutputBlockArguments(),
+      fusedLoop.getOutputBlockArguments().slice(numTargetOuts, numSourceOuts));
+
+  // Append everything except the terminator into the fused operation.
+  rewriter.setInsertionPointToStart(fusedLoop.getBody());
+  for (Operation &op : target.getLoopBody().begin()->without_terminator())
+    rewriter.clone(op, fusedMapping);
+  for (Operation &op : source.getLoopBody().begin()->without_terminator())
+    rewriter.clone(op, fusedMapping);
+
+  // Fuse the old terminator in_parallel ops into the new one.
+  scf::InParallelOp targetTerm = target.getTerminator();
+  scf::InParallelOp sourceTerm = source.getTerminator();
+  scf::InParallelOp fusedTerm = fusedLoop.getTerminator();
+
+  rewriter.setInsertionPointToStart(fusedTerm.getBody());
+  for (Operation &op : targetTerm.getYieldingOps())
+    rewriter.clone(op, fusedMapping);
+  for (Operation &op : sourceTerm.getYieldingOps())
+    rewriter.clone(op, fusedMapping);
+
+  // Replace all uses of the old loops with the fused loop.
+  rewriter.replaceAllUsesWith(target.getResults(),
+                              fusedLoop.getResults().slice(0, numTargetOuts));
+  rewriter.replaceAllUsesWith(
+      source.getResults(),
+      fusedLoop.getResults().slice(numTargetOuts, numSourceOuts));
+
+  // Erase the old loops.
+  rewriter.eraseOp(target);
+  rewriter.eraseOp(source);
+
+  return fusedLoop;
+}
diff --git a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
@@ -0,0 +1,148 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter --cse --canonicalize -split-input-file -verify-diagnostics | FileCheck %s
+
+func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+  %zero = arith.constant 0.0 : f32
+  %out_alloc = tensor.empty() : tensor<128x128xf32>
+  %out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
+
+  // CHECK: scf.forall ([[I:%.*]]) in (4) shared_outs([[S1:%.*]] = [[IN1:%.*]], [[S2:%.*]] = [[IN2:%.*]]) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+  // CHECK:   [[T:%.*]] = affine.apply
+  // CHECK:   tensor.extract_slice [[S1]][[[T]], 0] [32, 128] [1, 1]
+  // CHECK:   [[OUT1:%.*]] = linalg.matmul
+  // CHECK:   tensor.extract_slice [[S2]][[[T]], 0] [32, 128] [1, 1]
+  // CHECK:   [[OUT2:%.*]] = linalg.matmul
+  // CHECK:   scf.forall.in_parallel {
+  // CHECK:     tensor.parallel_insert_slice [[OUT1]] into [[S1]][[[T]], 0] [32, 128] [1, 1]
+  // CHECK:     tensor.parallel_insert_slice [[OUT2]] into [[S2]][[[T]], 0] [32, 128] [1, 1]
+  // CHECK:   }
+  // CHECK: }
+  %out1 = linalg.matmul ins(%A, %B1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+  %out2 = linalg.matmul ins(%A, %B2 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+
+  func.return %out1, %out2 : tensor<128x128xf32>, tensor<128x128xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb0(%variant_op : !transform.any_op):
+  %matched = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
+
+  %mm1, %mm2 = transform.split_handle %matched : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+  %loop1, %tiled_mm1 = transform.structured.tile_to_forall_op %mm1 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+  %loop2, %tiled_mm2 = transform.structured.tile_to_forall_op %mm2 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+  %fused_loop = transform.loop.fuse_sibling %loop1 into %loop2 : (!transform.any_op, !transform.any_op) -> !transform.any_op
+}
+
+// -----
+
+func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+  %zero = arith.constant 0.0 : f32
+  %out_alloc = tensor.empty() : tensor<128x128xf32>
+  %out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
+
+  // expected-error @below {{user of results of target should be properly dominated by source}}
+  %out1 = linalg.matmul ins(%A, %B1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+  %out2 = linalg.matmul ins(%A, %out1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+
+  func.return %out1, %out2 : tensor<128x128xf32>, tensor<128x128xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb0(%variant_op : !transform.any_op):
+  %matched = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
+
+  %mm1, %mm2 = transform.split_handle %matched : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+  %loop1, %tiled_mm1 = transform.structured.tile_to_forall_op %mm1 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+  %loop2, %tiled_mm2 = transform.structured.tile_to_forall_op %mm2 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+  %fused_loop = transform.loop.fuse_sibling %loop1 into %loop2 : (!transform.any_op, !transform.any_op) -> !transform.any_op
+}
+
+// -----
+
+func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+  %zero = arith.constant 0.0 : f32
+  %out_alloc = tensor.empty() : tensor<128x128xf32>
+  %out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
+
+  %out1 = linalg.matmul ins(%A, %B1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+  // expected-error @below {{values used inside regions of target should be properly dominated by source}}
+  %out2 = linalg.matmul ins(%A, %out1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+
+  func.return %out1, %out2 : tensor<128x128xf32>, tensor<128x128xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb0(%variant_op : !transform.any_op):
+  %matched = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
+
+  %mm1, %mm2 = transform.split_handle %matched : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+  %loop1, %tiled_mm1 = transform.structured.tile_to_forall_op %mm1 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+  %loop2, %tiled_mm2 = transform.structured.tile_to_forall_op %mm2 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+  %fused_loop = transform.loop.fuse_sibling %loop2 into %loop1 : (!transform.any_op, !transform.any_op) -> !transform.any_op
+}
+
+// -----
+
+func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+  %zero = arith.constant 0.0 : f32
+  %out_alloc = tensor.empty() : tensor<128x128xf32>
+  %out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
+
+  %out1 = linalg.matmul ins(%A, %B1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+  // expected-error @below {{operands of target should be properly dominated by source}}
+  %out2 = linalg.matmul ins(%A, %B2 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out1 : tensor<128x128xf32>) -> tensor<128x128xf32>
+
+  func.return %out1, %out2 : tensor<128x128xf32>, tensor<128x128xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb0(%variant_op : !transform.any_op):
+  %matched = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
+
+  %mm1, %mm2 = transform.split_handle %matched : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+  %loop1, %tiled_mm1 = transform.structured.tile_to_forall_op %mm1 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+  %loop2, %tiled_mm2 = transform.structured.tile_to_forall_op %mm2 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+  %fused_loop = transform.loop.fuse_sibling %loop2 into %loop1 : (!transform.any_op, !transform.any_op) -> !transform.any_op
+}
+
+// -----
+
+// Fusing a loop into an earlier sibling succeeds when everything the later
+// loop uses (inside its body or as operands) is defined before the earlier one.
+func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+  %zero = arith.constant 0.0 : f32
+  %out_alloc = tensor.empty() : tensor<128x128xf32>
+  %out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
+
+  // CHECK: scf.forall ([[I:%.*]]) in (4) shared_outs([[S1:%.*]] = [[IN1:%.*]], [[S2:%.*]] = [[IN2:%.*]]) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+  // CHECK:   [[OUT1:%.*]] = linalg.matmul
+  // CHECK:   [[OUT2:%.*]] = linalg.matmul
+  // CHECK:   scf.forall.in_parallel {
+  // CHECK:     tensor.parallel_insert_slice [[OUT1]] into [[S1]]
+  // CHECK:     tensor.parallel_insert_slice [[OUT2]] into [[S2]]
+  // CHECK:   }
+  // CHECK: }
+  %out1 = linalg.matmul ins(%A, %B1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+  %out2 = linalg.matmul ins(%A, %B2 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+
+  func.return %out1, %out2 : tensor<128x128xf32>, tensor<128x128xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb0(%variant_op : !transform.any_op):
+  %matched = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
+
+  %mm1, %mm2 = transform.split_handle %matched : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+  %loop1, %tiled_mm1 = transform.structured.tile_to_forall_op %mm1 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+  %loop2, %tiled_mm2 = transform.structured.tile_to_forall_op %mm2 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+  %fused_loop = transform.loop.fuse_sibling %loop2 into %loop1 : (!transform.any_op, !transform.any_op) -> !transform.any_op
+}