diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -65,6 +65,54 @@ let hasVerifier = 1; } +def FuseIntoContainingOp : + Op]> { + let summary = [{Fuse a producer into a containing operation.}]; + + let description = [{ + Fuses the `producer_op` into the `containing_op`. Only producers with a + single result are supported at the moment. Returns a handle to the fused + ops. + + The producer is typically a slice of a tileable op (i.e., implements + TilingInterface). In that case, this transform computes the accessed + producer slice inside of the containing op ("tile and fuse"). Otherwise, + the entire producer is cloned inside the containing op ("clone and fuse"). + + The containing op handle must be associated with exactly one payload op. The + producer op handle may be associated with multiple payload ops. This + transform fuses producers one-by-one, always picking an unspecified producer + that has at least one use inside the containing op among the + producers. + + Note: If a producer has multiple uses inside the containing op, it is + currently tiled and/or cloned multiple times into the containing op. + TODO: Reuse already fused OpResults instead of tiling/cloning a second time + when possible. Fuse producers according to a topological sorting to achieve + the largest amount of reuse. + + #### Return modes + + If at least one producer could not be fused, this operation fails silently. + This is the case when tiling fails or when no producer op could be found + among the remaining producers that has at least one use within the + containing op. I.e., "producers" that are not consumed within the containing + op are rejected by this operation. This operation reads and frees the + producer handle. 
It reads the containing op handle. + }]; + + let arguments = (ins Arg:$producer_op, + Arg:$containing_op); + let results = (outs Res:$fused_op); + let assemblyFormat = "$producer_op `into` $containing_op attr-dict"; +} + def GeneralizeOp : Op { diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -213,6 +213,160 @@ return success(); } +//===----------------------------------------------------------------------===// +// FuseIntoContainingOp +//===----------------------------------------------------------------------===// + +static FailureOr> tileAndFuse(Operation *producerOp, + Operation *containingOp, + RewriterBase &rewriter) { + auto tileableProducer = dyn_cast(producerOp); + if (!tileableProducer) + return failure(); + + // Search the producer slices accessed within the containing operation. + // TODO: Generalize to more extract/insert/parallel_insert triples. Maybe + // evolve into an interface. + SmallVector sliceOps; + for (Operation *user : tileableProducer->getUsers()) { + auto sliceOp = dyn_cast(user); + if (!sliceOp) + continue; + if (!containingOp->isProperAncestor(sliceOp)) + continue; + sliceOps.push_back(sliceOp); + } + + // Check for a non-empty list of fusion opportunities. + if (sliceOps.empty()) + return failure(); + + SmallVector destinationOperands = + tileableProducer.getDestinationOperands(rewriter); + + // Try to fuse the producer in-place. + SmallVector fusedOps; + for (tensor::ExtractSliceOp sliceOp : sliceOps) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(sliceOp); + + // Tile the producer. 
+ FailureOr tiledProducer = tileableProducer.generateResultTileValue( + rewriter, /*resultNumber=*/0, destinationOperands, + sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(), true); + if (failed(tiledProducer)) + return failure(); + fusedOps.push_back(tiledProducer->getDefiningOp()); + } + + // Replace each extract_slice op with result 0 of its tiled producer. + for (const auto &en : enumerate(sliceOps)) + rewriter.replaceOp(en.value(), fusedOps[en.index()]->getResult(0)); + return fusedOps; +} + +static FailureOr> +cloneAndFuse(Operation *producerOp, Operation *containingOp, + RewriterBase &rewriter) { + // Gather all uses inside the containing op. + SmallVector uses; + for (OpResult result : producerOp->getOpResults()) + for (OpOperand &use : result.getUses()) + if (containingOp->isProperAncestor(use.getOwner())) + uses.push_back(&use); + + // Check for a non-empty list of fusion opportunities. + if (uses.empty()) + return failure(); + + // Clone the producer once per use, right before the user, and rewire the use. + SmallVector fusedOps; + for (OpOperand *use : uses) { + unsigned resultNumber = use->get().cast().getResultNumber(); + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(use->getOwner()); + Operation *cloned = rewriter.clone(*producerOp); + rewriter.updateRootInPlace( + use->getOwner(), [&] { use->set(cloned->getOpResult(resultNumber)); }); + fusedOps.push_back(cloned); + } + + return fusedOps; +} + +DiagnosedSilenceableFailure +transform::FuseIntoContainingOp::apply(transform::TransformResults &results, + transform::TransformState &state) { + SmallVector fusedOps; + ArrayRef producerOps = state.getPayloadOps(getProducerOp()); + for (Operation *producerOp : producerOps) { + if (producerOp->getNumResults() != 1) { + Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Note); + diag << "op with != 1 results not supported"; + return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); + } + } + ArrayRef containingOps = state.getPayloadOps(getContainingOp()); + if (containingOps.size() != 
1) + return DiagnosedSilenceableFailure( + this->emitOpError("requires exactly one containing_op handle")); + Operation *containingOp = containingOps.front(); + + // Helper function to find the next producer that should be fused. Take any + // producer that has a use inside the containing op. + SmallVector remainingProducers(producerOps.begin(), + producerOps.end()); + auto getNextProducer = [&]() -> FailureOr { + for (const auto &it : enumerate(remainingProducers)) { + Operation *producerOp = it.value(); + bool hasUseInContainingOp = + any_of(producerOp->getUsers(), [&](Operation *op) { + return containingOp->isProperAncestor(op); + }); + // TODO: When resolving the TODO below (no duplicate ops), take an op that + // has no use among the remaining producers. This is a topological + // sorting. + if (hasUseInContainingOp) { + remainingProducers.erase(remainingProducers.begin() + it.index()); + return producerOp; + } + } + return failure(); + }; + + IRRewriter rewriter(getContext()); + while (!remainingProducers.empty()) { + auto nextProducer = getNextProducer(); + if (failed(nextProducer)) { + Diagnostic diag(containingOp->getLoc(), DiagnosticSeverity::Note); + diag << "could not fuse ops into container"; + return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); + } + + Operation *producerOp = *nextProducer; + // TODO: If there are multiple uses of the producer in the containing op, we + // currently tile/clone the op multiple times (once per use). In some cases, + // we can tile/clone once and reuse the value for each use. Furthermore, + // producers should then be traversed according to a topological sorting. 
+ auto tiled = tileAndFuse(producerOp, containingOp, rewriter); + if (succeeded(tiled)) + fusedOps.append(*tiled); + + auto cloned = cloneAndFuse(producerOp, containingOp, rewriter); + if (succeeded(cloned)) + fusedOps.append(*cloned); + + if (failed(tiled) && failed(cloned)) { + Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Note); + diag << "could not fuse into containing op"; + return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); + } + } + + results.set(getFusedOp().cast(), fusedOps); + return DiagnosedSilenceableFailure::success(); +} + //===----------------------------------------------------------------------===// // GeneralizeOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir @@ -0,0 +1,96 @@ +// RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file %s | FileCheck %s + +#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)> +#map1 = affine_map<(d0)[s0] -> (d0 * s0)> +#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)> + +module { + // CHECK-LABEL: func.func @fuse_tileable_op + // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index + // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor + // CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor + func.func @fuse_tileable_op(%arg0: index, %arg1: tensor, %arg2: tensor) -> tensor { + %cst = arith.constant 4.200000e+01 : f32 + %c0 = arith.constant 0 : index + %0 = linalg.fill ins(%cst : f32) outs(%arg1 : tensor) -> tensor + %d0 = tensor.dim %arg1, %c0 : tensor + %1 = affine.apply #map0()[%d0, %arg0] + + // CHECK: scf.foreach_thread {{.*}} { + %2 = scf.foreach_thread (%arg3) in (%1) -> (tensor) { + %3 = affine.apply #map1(%arg3)[%arg0] + %4 = affine.min #map2(%arg3)[%d0, %arg0] + %5 = tensor.extract_slice %arg2[%3] [%4] [1] : 
tensor to tensor + + // CHECK: %[[T0:.*]] = tensor.extract_slice %[[IN]][%{{.*}}] [%{{.*}}] [{{.*}}] + // CHECK: %[[T1:.*]] = linalg.fill {{.*}} outs(%[[T0]] + %6 = tensor.extract_slice %0[%3] [%4] [1] : tensor to tensor + + // CHECK: %[[T2:.*]] = linalg.elemwise_unary ins(%[[T1]] + %7 = linalg.elemwise_unary ins(%6 : tensor) outs(%5 : tensor) -> tensor + scf.foreach_thread.perform_concurrently { + tensor.parallel_insert_slice %7 into %arg2[%3] [%4] [1] : tensor into tensor + } + } + // CHECK: } + func.return %2 : tensor + } + + transform.with_pdl_patterns { + ^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 + %1 = transform.structured.match ops{["scf.foreach_thread"]} in %arg1 + + // linalg.fill is tileable: only the slice used inside scf.foreach_thread is computed there (tile and fuse). + transform.structured.fuse_into_containing_op %0 into %1 + } + } +} + +// ----- + +#map0 = affine_map<()[s0] -> (64 ceildiv s0)> +#map1 = affine_map<(d0)[s0] -> (d0 * s0)> +#map2 = affine_map<(d0)[s0] -> (-(d0 * s0) + 64, s0)> + +module { + // CHECK-LABEL: func.func @fuse_untileable_op + // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index + // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<64xf32> + // CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<64xf32> + func.func @fuse_untileable_op(%arg0: index, %arg1: tensor<64xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> { + %0 = linalg.init_tensor [%arg0] : tensor + %1 = affine.apply #map0()[%arg0] + + // CHECK: scf.foreach_thread {{.*}} { + %2 = scf.foreach_thread (%arg3) in (%1) -> (tensor<64xf32>) { + // CHECK: %[[INIT_TENSOR:.*]] = linalg.init_tensor + %3 = affine.apply #map1(%arg3)[%arg0] + %4 = affine.min #map2(%arg3)[%arg0] + %5 = tensor.extract_slice %arg2[%3] [%4] [1] : tensor<64xf32> to tensor + + // CHECK: %[[T2:.*]] = linalg.elemwise_unary ins(%[[INIT_TENSOR]] + %7 = linalg.elemwise_unary ins(%0 : tensor) outs(%5 : tensor) -> tensor + scf.foreach_thread.perform_concurrently { + 
tensor.parallel_insert_slice %7 into %arg2[%3] [%4] [1] : tensor into tensor<64xf32> + } + } + // CHECK: } + + func.return %2 : tensor<64xf32> + } + + transform.with_pdl_patterns { + ^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.init_tensor"]} in %arg1 + %1 = transform.structured.match ops{["scf.foreach_thread"]} in %arg1 + + // linalg.init_tensor is not tileable: the whole op is cloned into scf.foreach_thread (clone and fuse). + transform.structured.fuse_into_containing_op %0 into %1 + } + } +}