diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -347,10 +347,82 @@
   result.addTypes(transform::AnyOpType::get(builder.getContext()));
 }
 
+/// Add new operands to the containing scf.forall op for those users of the
+/// producerOp that are dominated by it.
+static Operation *replaceForAllWithNewSignature(
+    RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
+    Operation *containingOp, TilingResult &tileAndFuseResult,
+    int64_t resultNumber, SmallVector<OpFoldResult> &offsets,
+    SmallVector<OpFoldResult> &sizes) {
+
+  // Collect the users of the producer result, not including the containing
+  // op, that are dominated by the containing op.
+  SetVector<Operation *> dominatedUsers;
+  DominanceInfo domInfo(containingOp);
+  for (Operation *user : producerOp->getResult(resultNumber).getUsers()) {
+    if (user != containingOp && domInfo.dominates(containingOp, user))
+      dominatedUsers.insert(user);
+  }
+  if (dominatedUsers.empty())
+    return nullptr;
+
+  // Prepare to rewrite the containing scf.forall op.
+  auto forallOp = cast<scf::ForallOp>(containingOp);
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(forallOp);
+
+  // Get the new output: the producer's init operand for the tiled result.
+  Location loc = forallOp.getLoc();
+  auto genericOp = dyn_cast<linalg::GenericOp>(producerOp);
+  if (!genericOp)
+    return nullptr;
+  SmallVector<Value> outputs = genericOp.getOutputs();
+  SmallVector<Value> newOuts(forallOp.getOutputs());
+  newOuts.push_back(outputs[resultNumber]);
+
+  // Create the new scf.forall op and move the old body into it.
+  auto newforallOp = rewriter.create<scf::ForallOp>(
+      loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+      forallOp.getMixedStep(), newOuts, forallOp.getMapping());
+  rewriter.eraseBlock(newforallOp.getBody());
+  newforallOp.getRegion().takeBody(forallOp.getRegion());
+
+  // Add an additional block argument for the new value being returned.
+  newforallOp.getBody()->addArgument(newOuts.back().getType(),
+                                     newOuts.back().getLoc());
+
+  // Fix the terminator: parallel-insert the tiled value into the new output.
+  scf::InParallelOp terminatorOp = newforallOp.getTerminator();
+  SmallVector<Operation *> yieldingOps = llvm::to_vector<4>(llvm::map_range(
+      terminatorOp.getYieldingOps(), [](Operation &op) { return &op; }));
+  Operation *firstYieldOp = yieldingOps.front();
+  rewriter.setInsertionPoint(firstYieldOp);
+  Value src = tileAndFuseResult.tiledValues[0];
+  Value dst = newforallOp.getOutputBlockArguments().back();
+  SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1));
+  rewriter.create<tensor::ParallelInsertSliceOp>(firstYieldOp->getLoc(), src,
+                                                 dst, offsets, sizes, strides);
+
+  // Replace the results of the old forall and reroute the dominated uses of
+  // the producer result to the newly appended forall result.
+  for (auto result : llvm::enumerate(forallOp.getResults())) {
+    rewriter.replaceAllUsesWith(result.value(),
+                                newforallOp->getResult(result.index()));
+  }
+  rewriter.replaceUsesWithIf(producerOp->getResult(resultNumber),
+                             newforallOp->getResults().back(),
+                             [&](OpOperand &use) {
+                               Operation *user = use.getOwner();
+                               return dominatedUsers.contains(user);
+                             });
+  return newforallOp;
+}
+
 /// Find the first "extract" user of `producerOp` and tile it right before its
 /// use. The tiled op is fused under the `containingOp`.
 /// Return this fused op on success or nullptr if anything fails.
-static SmallVector<Operation *>
+/// If the tiled op has uses that are dominated by `containingOp`, also return
+/// a new `containingOp` with the results of the fused op appended to the
+/// results of the `containingOp`, or nullptr if there are no dominated uses.
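+///
+/// A minimal sketch of the rewrite (the values and shapes are illustrative,
+/// not taken from a particular test). Given a producer whose result is also
+/// used after the containing loop:
+///
+///   %p = linalg.generic ... -> tensor<?xf32>
+///   %r = scf.forall ... shared_outs(%o = %init) -> (tensor<?xf32>) { ... }
+///   func.return %r, %p : tensor<?xf32>, tensor<?xf32>
+///
+/// fusion recreates the forall with an extra shared output and reroutes the
+/// dominated use to the new loop result:
+///
+///   %r:2 = scf.forall ... shared_outs(%o = %init, %o2 = %p_init)
+///       -> (tensor<?xf32>, tensor<?xf32>) {
+///     %t = linalg.generic ...  // fused producer tile
+///     scf.forall.in_parallel {
+///       tensor.parallel_insert_slice %t into %o2[...] [...] [1] ...
+///     }
+///   }
+///   func.return %r#0, %r#1 : tensor<?xf32>, tensor<?xf32>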
+static std::tuple<SmallVector<Operation *>, Operation *>
 tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
                            Operation *producerOp, Operation *containingOp) {
   LLVM_DEBUG(DBGS() << "Try to fuse a direct extract use\n");
@@ -386,10 +458,13 @@
       cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
   LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
 
+  SmallVector<OpFoldResult> offsets = sliceOpToTile.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = sliceOpToTile.getMixedSizes();
+
   FailureOr<TilingResult> tileAndFuseResult =
-      tileableProducer.generateResultTileValue(rewriter, resultNumber,
-                                               sliceOpToTile.getMixedOffsets(),
-                                               sliceOpToTile.getMixedSizes());
+      tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets,
+                                               sizes);
+
   if (failed(tileAndFuseResult)) {
     diag.attachNote(tileableProducer->getLoc())
         << "failed to tile producer op: " << *tileableProducer;
@@ -408,7 +483,13 @@
       cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
   assert(succeeded(maybeRankReduced) && "unexpected shape");
   rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
-  return tileAndFuseResult->tiledOps;
+
+  // Add new outputs to the containing op, if required.
+  Operation *newContainingOp = replaceForAllWithNewSignature(
+      rewriter, diag, producerOp, containingOp, *tileAndFuseResult,
+      resultNumber, offsets, sizes);
+
+  return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
 }
 
 /// First, find the first "scf::ForallOp" user of `producerOp` and ensure
@@ -635,11 +716,15 @@
     // cases, we can tile/clone once and reuse the value for each use.
     // Furthermore, producers should then be traversed according to a
     // topological sorting.
-    SmallVector<Operation *> tiledOps =
+    auto [tiledOps, newContainingOp] =
         tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
     if (!tiledOps.empty()) {
       LLVM_DEBUG(DBGS() << "\nFused a direct extract use\n" << *containingOp);
       fusedOps.append(tiledOps);
+      // The containing forall was recreated with extra outputs; track the
+      // new op and erase the old one.
+      if (newContainingOp) {
+        rewriter.eraseOp(containingOp);
+        containingOp = newContainingOp;
+      }
       continue;
     }
 
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
--- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
@@ -116,7 +116,7 @@
     // CHECK: scf.forall {{.*}} -> (tensor<?xf32>) {
     %2 = scf.forall (%arg3) in (%d0) shared_outs(%o = %0) -> (tensor<?xf32>) {
       %5 = tensor.extract_slice %o[%arg3] [1] [1] : tensor<?xf32> to tensor<1xf32>
-      
+
       // CHECK: tensor.extract_slice %{{.*}}[%{{.*}}] [1] [1] : tensor<?xf32> to tensor<1xf32>
       // CHECK: linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : tensor<1xf32>) -> tensor<1xf32>
       // CHECK: tensor.extract_slice %{{.*}}[0] [1] [1] : tensor<1xf32> to tensor<?xf32>
@@ -288,3 +288,200 @@
     transform.structured.fuse_into_containing_op %2 into %1 : (!transform.any_op, !transform.any_op) -> !transform.any_op
   }
 }
+
+// -----
+
+#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
+#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
+
+module {
+  // CHECK-LABEL: func.func @fuse_tileable_multi_output_op_multi_use
+  // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
+  // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<?xf32>
+  // CHECK-SAME: %[[OUT_1:[0-9a-z]+]]: tensor<?xf32>
+  // CHECK-SAME: %[[OUT_2:[0-9a-z]+]]: tensor<?xf32>
+  // CHECK-SAME: %[[OUT_3:[0-9a-z]+]]: tensor<?xf32>
+  func.func @fuse_tileable_multi_output_op_multi_use(%idx: index, %in: tensor<?xf32>, %out_1: tensor<?xf32>, %out_2: tensor<?xf32>, %out_3: tensor<?xf32>)
+      -> (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) {
+    %cst = arith.constant 4.200000e+01 : f32
+    %c0 = arith.constant 0 : index
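+
+    // Both results of the two-output producer below escape the loop: %0#0 is
+    // sliced inside the forall and also returned, while %0#1 is only
+    // returned. Only the tiled result %0#0 is rerouted through the new
+    // forall output; %0#1 keeps using the original linalg.generic.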
+
+    // CHECK: %[[G0:.*]]:2 = linalg.generic
+    %0:2 = linalg.generic {
+      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+      iterator_types = ["parallel"]
+    } ins(%in : tensor<?xf32>) outs(%out_1, %out_3 : tensor<?xf32>, tensor<?xf32>) {
+    ^bb0(%a: f32, %b: f32, %c: f32):
+      %d = arith.addf %a, %b : f32
+      %e = arith.addf %d, %c : f32
+      linalg.yield %d, %e : f32, f32
+    } -> (tensor<?xf32>, tensor<?xf32>)
+    %d0 = tensor.dim %out_1, %c0 : tensor<?xf32>
+
+    %1 = affine.apply #map0()[%d0, %idx]
+
+    // CHECK: %[[R0:.*]]:2 = scf.forall (%[[ARG5:.*]]) in (%{{.*}}) shared_outs(%[[ARG6:.*]] = %[[OUT_2]], %[[ARG7:.*]] = %[[OUT_1]])
+    // CHECK-SAME: -> (tensor<?xf32>, tensor<?xf32>) {
+    %2 = scf.forall (%i) in (%1) shared_outs(%o = %out_2) -> (tensor<?xf32>) {
+      // CHECK: %[[I0:.*]] = affine.apply {{.*}}
+      %3 = affine.apply #map1(%i)[%idx]
+      // CHECK: %[[I1:.*]] = affine.min {{.*}}
+      %4 = affine.min #map2(%i)[%d0, %idx]
+      %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+
+      // CHECK: %[[T1:.*]]:2 = linalg.generic {{.*}}
+      %6 = tensor.extract_slice %0#0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+
+      %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
+      scf.forall.in_parallel {
+        // CHECK: tensor.parallel_insert_slice %[[T1]]#0 into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor<?xf32> into tensor<?xf32>
+        tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
+      }
+    }
+    // CHECK: return %[[R0]]#0, %[[R0]]#1, %[[G0]]#1
+    func.return %2, %0#0, %0#1 : tensor<?xf32>, tensor<?xf32>, tensor<?xf32>
+    // CHECK: }
+  }
+
+  transform.sequence failures(propagate) {
+  ^bb1(%arg1: !transform.any_op):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.generic">
+    %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">
+
+    // linalg.generic is tileable. The op is tiled and fused.
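+    // The fused clone of the producer lives inside a recreated forall that
+    // has one extra shared output; the handle returned below points at the
+    // fused clone.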
+    transform.structured.fuse_into_containing_op %0 into %1
+      : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> !transform.any_op
+  }
+}
+
+// -----
+
+#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
+#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
+
+module {
+  // CHECK-LABEL: func.func @fuse_tileable_mixed_dominating_uses
+  // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
+  // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<?xf32>
+  // CHECK-SAME: %[[OUT_1:[0-9a-z]+]]: tensor<?xf32>
+  // CHECK-SAME: %[[OUT_2:[0-9a-z]+]]: tensor<?xf32>
+  // CHECK-SAME: %[[OUT_3:[0-9a-z]+]]: tensor<?xf32>
+  func.func @fuse_tileable_mixed_dominating_uses(%idx: index, %in: tensor<?xf32>, %out_1: tensor<?xf32>, %out_2: tensor<?xf32>, %out_3: tensor<?xf32>)
+      -> (tensor<?xf32>, tensor<?xf32>) {
+    %cst = arith.constant 4.200000e+01 : f32
+    %c0 = arith.constant 0 : index
+
+    // CHECK: %[[G0:.*]] = linalg.generic
+    %0 = linalg.generic {
+      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+      iterator_types = ["parallel"]
+    } ins(%in : tensor<?xf32>) outs(%out_1 : tensor<?xf32>) {
+    ^bb0(%a: f32, %b: f32):
+      %d = arith.addf %a, %b : f32
+      linalg.yield %d : f32
+    } -> tensor<?xf32>
+    // CHECK: %[[D0:.*]] = tensor.dim %[[G0]]
+    %d0 = tensor.dim %0, %c0 : tensor<?xf32>
+
+    %1 = affine.apply #map0()[%d0, %idx]
+
+    // CHECK: %[[R0:.*]]:2 = scf.forall (%[[ARG5:.*]]) in (%{{.*}}) shared_outs(%[[ARG6:.*]] = %[[OUT_2]], %[[ARG7:.*]] = %[[OUT_1]])
+    // CHECK-SAME: -> (tensor<?xf32>, tensor<?xf32>) {
+    %2 = scf.forall (%i) in (%1) shared_outs(%o = %out_2) -> (tensor<?xf32>) {
+      // CHECK: %[[I0:.*]] = affine.apply {{.*}}
+      %3 = affine.apply #map1(%i)[%idx]
+      // CHECK: %[[I1:.*]] = affine.min {{.*}}
+      %4 = affine.min #map2(%i)[%d0, %idx]
+      %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+
+      // CHECK: %[[T1:.*]] = linalg.generic {{.*}}
+      %6 = tensor.extract_slice %0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+
+      %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
+      scf.forall.in_parallel {
+        // CHECK: tensor.parallel_insert_slice %[[T1]] into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor<?xf32> into tensor<?xf32>
+        tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
+      }
+    }
+    // CHECK: return %[[R0]]#0, %[[R0]]#1
+    func.return %2, %0 : tensor<?xf32>, tensor<?xf32>
+    // CHECK: }
+  }
+
+  transform.sequence failures(propagate) {
+  ^bb1(%arg1: !transform.any_op):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.generic">
+    %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">
+
+    // linalg.generic is tileable. The op is tiled and fused.
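+    // Only the uses of %0 dominated by the forall (here, the func.return)
+    // are rerouted to the new forall result; the tensor.dim above the loop
+    // must keep reading the original linalg.generic.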
+    transform.structured.fuse_into_containing_op %0 into %1
+      : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> !transform.any_op
+  }
+}
+
+// -----
+
+#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
+#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
+#map3 = affine_map<(d0, d1) -> (d0, d1)>
+#map4 = affine_map<(d0, d1) -> (d0)>
+
+module {
+  // CHECK-LABEL: func.func @fuse_tileable_reductions
+  // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
+  // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<?x?xf32>
+  // CHECK-SAME: %[[OUT_1:[0-9a-z]+]]: tensor<?xf32>
+  // CHECK-SAME: %[[OUT_2:[0-9a-z]+]]: tensor<?xf32>
+  // CHECK-SAME: %[[OUT_3:[0-9a-z]+]]: tensor<?xf32>
+  func.func @fuse_tileable_reductions(%idx: index, %in: tensor<?x?xf32>, %out_1: tensor<?xf32>, %out_2: tensor<?xf32>, %out_3: tensor<?xf32>)
+      -> (tensor<?xf32>, tensor<?xf32>) {
+    %cst = arith.constant 4.200000e+01 : f32
+    %c0 = arith.constant 0 : index
+
+    %0 = linalg.generic {
+      indexing_maps = [#map3, #map4], iterator_types = ["parallel", "reduction"]
+    } ins(%in : tensor<?x?xf32>) outs(%out_1 : tensor<?xf32>) {
+    ^bb0(%a: f32, %b: f32):
+      %d = arith.maxf %a, %b : f32
+      linalg.yield %d : f32
+    } -> tensor<?xf32>
+    %d0 = tensor.dim %out_1, %c0 : tensor<?xf32>
+
+    %1 = affine.apply #map0()[%d0, %idx]
+
+    // CHECK: %[[R0:.*]]:2 = scf.forall (%[[ARG5:.*]]) in (%{{.*}}) shared_outs(%[[ARG6:.*]] = %[[OUT_2]], %[[ARG7:.*]] = %[[OUT_1]])
+    // CHECK-SAME: -> (tensor<?xf32>, tensor<?xf32>) {
+    %2 = scf.forall (%i) in (%1) shared_outs(%o = %out_2) -> (tensor<?xf32>) {
+      // CHECK: %[[I0:.*]] = affine.apply {{.*}}
+      %3 = affine.apply #map1(%i)[%idx]
+      // CHECK: %[[I1:.*]] = affine.min {{.*}}
+      %4 = affine.min #map2(%i)[%d0, %idx]
+      %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+
+      // CHECK: %[[T1:.*]] = linalg.generic {{.*}}
+      %6 = tensor.extract_slice %0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+
+      %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
+      scf.forall.in_parallel {
+        // CHECK: tensor.parallel_insert_slice %[[T1]] into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor<?xf32> into tensor<?xf32>
+        tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
+      }
+    }
+    // CHECK: return %[[R0]]#0, %[[R0]]#1
+    func.return %2, %0 : tensor<?xf32>, tensor<?xf32>
+    // CHECK: }
+  }
+
+  transform.sequence failures(propagate) {
+  ^bb1(%arg1: !transform.any_op):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.generic">
+    %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">
+
+    // linalg.generic is tileable. The op is tiled and fused.
+    transform.structured.fuse_into_containing_op %0 into %1
+      : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> !transform.any_op
+  }
+}