diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -64,6 +64,48 @@
   let hasVerifier = 1;
 }
 
+def FuseIntoContainingOp :
+    Op<Transform_Dialect, "structured.fuse_into_containing_op",
+      [DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let summary = [{Fuse a producer into a containing operation.}];
+
+  let description = [{
+    Fuses the `producer_op` into the `containing_op`. Only producers with a
+    single result are supported at the moment. Returns a handle to the fused
+    ops.
+
+    The producer is typically a slice of a tileable op (i.e., implements
+    TilingInterface). In that case, this transform computes the accessed
+    producer slice inside of the containing op ("tile and fuse"). Otherwise,
+    the entire producer is cloned inside the containing op ("clone and fuse").
+
+    Producers and containing ops are matched pairwise when multiple payload ops
+    are provided (batched execution). The number of producers and containing
+    ops must be the same.
+
+    #### Return modes
+
+    If at least one producer could not be fused, this operation fails silently.
+    This is the case when tiling fails or when a producer op is attempted to be
+    fused into a containing op that does not contain any uses of the producer.
+    It also fails silently when the numbers of producers and containing ops
+    differ.
+
+    This operation reads and frees the producer handle. It reads the containing
+    op handle.
+  }];
+
+  let arguments = (ins Arg<PDL_Operation, "",
+                           [TransformMappingRead,
+                            TransformMappingFree]>:$producer_op,
+                       Arg<PDL_Operation, "",
+                           [TransformMappingRead]>:$containing_op);
+  let results = (outs Res<PDL_Operation, "",
+                          [TransformMappingAlloc,
+                           TransformMappingWrite]>:$fused_op);
+  let assemblyFormat = "$producer_op `into` $containing_op attr-dict";
+}
+
 def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
      TransformOpInterface, TransformEachOpTrait]> {
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -212,6 +212,145 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// FuseIntoContainingOp
+//===----------------------------------------------------------------------===//
+
+/// Tile-and-fuse: if `producerOp` implements TilingInterface, replaces every
+/// tensor::ExtractSliceOp user of the producer that is nested in `containingOp`
+/// by a producer tile computed at the offsets/sizes of that slice. Returns the
+/// ops defining the tiles, or failure if the producer is not tileable, has no
+/// such slice user, or a tile cannot be generated. On a tiling failure, tiles
+/// created for earlier slices are left in place.
+static FailureOr<SmallVector<Operation *>> tileAndFuse(Operation *producerOp,
+                                                       Operation *containingOp,
+                                                       RewriterBase &rewriter) {
+  auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
+  if (!tileableProducer)
+    return failure();
+
+  // Search the producer slices accessed within the containing operation.
+  SmallVector<tensor::ExtractSliceOp> sliceOps;
+  for (Operation *user : tileableProducer->getUsers()) {
+    auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
+    if (!sliceOp)
+      continue;
+    if (!containingOp->isProperAncestor(sliceOp))
+      continue;
+    sliceOps.push_back(sliceOp);
+  }
+
+  // Check for a non-empty list of fusion opportunities.
+  if (sliceOps.empty())
+    return failure();
+
+  SmallVector<Value> destinationOperands =
+      tileableProducer.getDestinationOperands(rewriter);
+
+  // Try to fuse the producer in-place of the tensor::ExtractSliceOps.
+  SmallVector<Operation *> fusedOps;
+  for (tensor::ExtractSliceOp sliceOp : sliceOps) {
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(sliceOp);
+
+    // Tile the producer.
+    FailureOr<Value> tiledProducer = tileableProducer.generateResultTileValue(
+        rewriter, /*resultNumber=*/0, destinationOperands,
+        sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(), true);
+    if (failed(tiledProducer))
+      return failure();
+    fusedOps.push_back(tiledProducer->getDefiningOp());
+  }
+
+  // Replace the tensor::ExtractSliceOps.
+  for (const auto &en : enumerate(sliceOps))
+    rewriter.replaceOp(en.value(), fusedOps[en.index()]->getResult(0));
+  return fusedOps;
+}
+
+/// Clone-and-fuse: clones `producerOp` right before each use of its results
+/// that is nested in `containingOp` and redirects that use to the clone.
+/// Returns the clones, or failure if `containingOp` contains no such use.
+static FailureOr<SmallVector<Operation *>>
+cloneAndFuse(Operation *producerOp, Operation *containingOp,
+             RewriterBase &rewriter) {
+  // Gather all uses inside the containing op.
+  SmallVector<OpOperand *> uses;
+  for (OpResult result : producerOp->getOpResults())
+    for (OpOperand &use : result.getUses())
+      if (containingOp->isProperAncestor(use.getOwner()))
+        uses.push_back(&use);
+
+  // Check for a non-empty list of fusion opportunities.
+  if (uses.empty())
+    return failure();
+
+  // Clone and fuse inside the containing op.
+  SmallVector<Operation *> fusedOps;
+  for (OpOperand *use : uses) {
+    unsigned resultNumber = use->get().cast<OpResult>().getResultNumber();
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(use->getOwner());
+    Operation *cloned = rewriter.clone(*producerOp);
+    rewriter.updateRootInPlace(
+        use->getOwner(), [&] { use->set(cloned->getOpResult(resultNumber)); });
+    fusedOps.push_back(cloned);
+  }
+
+  return fusedOps;
+}
+
+DiagnosedSilenceableFailure
+transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
+                                       transform::TransformState &state) {
+  SmallVector<Operation *> fusedOps;
+  ArrayRef<Operation *> producerOps = state.getPayloadOps(getProducerOp());
+  ArrayRef<Operation *> containingOps = state.getPayloadOps(getContainingOp());
+  IRRewriter rewriter(getContext());
+
+  // Producers and containing ops are matched pairwise. llvm::zip stops at the
+  // end of the shorter range, so reject mismatched counts explicitly instead
+  // of silently dropping the surplus payload ops.
+  if (producerOps.size() != containingOps.size()) {
+    Diagnostic diag(getLoc(), DiagnosticSeverity::Note);
+    diag << "expected the same number of producers and containing ops";
+    return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+  }
+
+  for (auto it : llvm::zip(producerOps, containingOps)) {
+    Operation *producerOp = std::get<0>(it);
+    Operation *containingOp = std::get<1>(it);
+
+    if (producerOp->getNumResults() != 1) {
+      Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Note);
+      diag << "op with != 1 results not supported";
+      return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+    }
+
+    // TODO: If there are multiple uses of the producer in the containing op, we
+    // currently tile/clone the op multiple times (once per use). In some cases,
+    // we can tile/clone once and reuse the value for each use.
+    auto tiled = tileAndFuse(producerOp, containingOp, rewriter);
+    if (succeeded(tiled)) {
+      fusedOps.append(*tiled);
+      continue;
+    }
+
+    auto cloned = cloneAndFuse(producerOp, containingOp, rewriter);
+    if (succeeded(cloned)) {
+      fusedOps.append(*cloned);
+      continue;
+    }
+
+    Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Note);
+    diag << "could not fuse into containing op";
+    return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+  }
+
+  results.set(getFusedOp().cast<OpResult>(), fusedOps);
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // GeneralizeOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
@@ -0,0 +1,120 @@
+// RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file %s | FileCheck %s
+
+#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
+#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
+
+module {
+  // CHECK-LABEL: func.func @fuse_tileable_op
+  // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
+  // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<?xf32>
+  // CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<?xf32>
+  func.func @fuse_tileable_op(%arg0: index, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> tensor<?xf32> {
+    %cst = arith.constant 4.200000e+01 : f32
+    %c0 = arith.constant 0 : index
+    %0 = linalg.fill ins(%cst : f32) outs(%arg1 : tensor<?xf32>) -> tensor<?xf32>
+    %d0 = tensor.dim %arg1, %c0 : tensor<?xf32>
+    %1 = affine.apply #map0()[%d0, %arg0]
+
+    // CHECK: scf.foreach_thread {{.*}} {
+    %2 = scf.foreach_thread (%arg3) in (%1) -> (tensor<?xf32>) {
+      %3 = affine.apply #map1(%arg3)[%arg0]
+      %4 = affine.min #map2(%arg3)[%d0, %arg0]
+      %5 = tensor.extract_slice %arg2[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+
+      // CHECK: %[[T0:.*]] = tensor.extract_slice %[[IN]][%{{.*}}] [%{{.*}}] [{{.*}}]
+      // CHECK: %[[T1:.*]] = linalg.fill {{.*}} outs(%[[T0]]
+      %6 = tensor.extract_slice %0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+
+      // CHECK: %[[T2:.*]] = linalg.elemwise_unary ins(%[[T1]]
+      %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
+      scf.foreach_thread.perform_concurrently {
+        tensor.parallel_insert_slice %7 into %arg2[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
+      }
+    }
+    // CHECK: }
+    func.return %2 : tensor<?xf32>
+  }
+
+  transform.with_pdl_patterns {
+  ^bb0(%arg0: !pdl.operation):
+    pdl.pattern @match_fill : benefit(1) {
+      %0 = operands
+      %1 = types
+      %2 = operation "linalg.fill"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+      rewrite %2 with "transform.dialect"
+    }
+    pdl.pattern @match_foreach_thread : benefit(1) {
+      %0 = operands
+      %1 = types
+      %2 = operation "scf.foreach_thread"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+      rewrite %2 with "transform.dialect"
+    }
+    transform.sequence %arg0 {
+    ^bb1(%arg1: !pdl.operation):
+      %0 = pdl_match @match_fill in %arg1
+      %1 = pdl_match @match_foreach_thread in %arg1
+
+      // linalg.fill is tileable. The op is tiled and fused.
+      transform.structured.fuse_into_containing_op %0 into %1
+    }
+  }
+}
+
+// -----
+
+#map0 = affine_map<()[s0] -> (64 ceildiv s0)>
+#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
+#map2 = affine_map<(d0)[s0] -> (-(d0 * s0) + 64, s0)>
+
+module {
+  // CHECK-LABEL: func.func @fuse_untileable_op
+  // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
+  // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<64xf32>
+  // CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<64xf32>
+  func.func @fuse_untileable_op(%arg0: index, %arg1: tensor<64xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
+    %0 = linalg.init_tensor [%arg0] : tensor<?xf32>
+    %1 = affine.apply #map0()[%arg0]
+
+    // CHECK: scf.foreach_thread {{.*}} {
+    %2 = scf.foreach_thread (%arg3) in (%1) -> (tensor<64xf32>) {
+      // CHECK: %[[INIT_TENSOR:.*]] = linalg.init_tensor
+      %3 = affine.apply #map1(%arg3)[%arg0]
+      %4 = affine.min #map2(%arg3)[%arg0]
+      %5 = tensor.extract_slice %arg2[%3] [%4] [1] : tensor<64xf32> to tensor<?xf32>
+
+      // CHECK: %[[T2:.*]] = linalg.elemwise_unary ins(%[[INIT_TENSOR]]
+      %7 = linalg.elemwise_unary ins(%0 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
+      scf.foreach_thread.perform_concurrently {
+        tensor.parallel_insert_slice %7 into %arg2[%3] [%4] [1] : tensor<?xf32> into tensor<64xf32>
+      }
+    }
+    // CHECK: }
+
+    func.return %2 : tensor<64xf32>
+  }
+
+  transform.with_pdl_patterns {
+  ^bb0(%arg0: !pdl.operation):
+    pdl.pattern @match_init_tensor : benefit(1) {
+      %0 = operands
+      %1 = types
+      %2 = operation "linalg.init_tensor"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+      rewrite %2 with "transform.dialect"
+    }
+    pdl.pattern @match_foreach_thread : benefit(1) {
+      %0 = operands
+      %1 = types
+      %2 = operation "scf.foreach_thread"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+      rewrite %2 with "transform.dialect"
+    }
+    transform.sequence %arg0 {
+    ^bb1(%arg1: !pdl.operation):
+      %0 = pdl_match @match_init_tensor in %arg1
+      %1 = pdl_match @match_foreach_thread in %arg1
+
+      // linalg.init_tensor is not tileable. The op is cloned and fused.
+      transform.structured.fuse_into_containing_op %0 into %1
+    }
+  }
+}