diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -183,7 +183,8 @@ let arguments = (ins TransformHandleTypeInterface:$producer_op, TransformHandleTypeInterface:$containing_op); - let results = (outs TransformHandleTypeInterface:$fused_op); + let results = (outs TransformHandleTypeInterface:$fused_op, + TransformHandleTypeInterface:$new_containing_op); let assemblyFormat = "$producer_op `into` $containing_op attr-dict " " `:` functional-type(operands, results)"; diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -344,7 +344,8 @@ Value producerOp, Value containingOp) { result.addOperands({producerOp, containingOp}); - result.addTypes(transform::AnyOpType::get(builder.getContext())); + auto resultType = transform::AnyOpType::get(builder.getContext()); + result.addTypes({resultType, resultType}); } /// Add new operands to the forall op for users of the producerOp @@ -388,8 +389,16 @@ newforallOp.getRegion().takeBody(forallOp.getRegion()); // Add additional block argument for new value being returned + // and replace all uses of the new output with corresponding bbArg + // inside the scf.forall to enable fusion into this new scf.forall.
newforallOp.getBody()->addArgument(newOuts.back().getType(), newOuts.back().getLoc()); + auto bbArgs = newforallOp.getBody()->getArguments(); + rewriter.replaceUsesWithIf(newOuts.back(), bbArgs.back(), + [&](OpOperand &use) { + Operation *op = use.getOwner(); + return newforallOp->isProperAncestor(op); + }); // Fix terminator scf::InParallelOp terminatorOp = newforallOp.getTerminator(); @@ -749,14 +758,15 @@ } results.set(cast<OpResult>(getFusedOp()), fusedOps); + results.set(cast<OpResult>(getNewContainingOp()), {containingOp}); return DiagnosedSilenceableFailure::success(); } void transform::FuseIntoContainingOp::getEffects( SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { consumesHandle(getProducerOp(), effects); - onlyReadsHandle(getContainingOp(), effects); - producesHandle(getFusedOp(), effects); + consumesHandle(getContainingOp(), effects); + producesHandle(getResults(), effects); modifiesPayload(effects); } diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir --- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir @@ -48,7 +48,7 @@ // linalg.fill is tileable. The op is tiled and fused. transform.structured.fuse_into_containing_op %0 into %1 - : (!transform.op<"linalg.fill">, !transform.op<"scf.forall">) -> !transform.any_op + : (!transform.op<"linalg.fill">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op) } } @@ -92,7 +92,7 @@ // tensor.empty is not tileable. The op is cloned and fused. transform.structured.fuse_into_containing_op %0 into %1 - : (!transform.op<"tensor.empty">, !transform.op<"scf.forall">) -> !transform.any_op + : (!transform.op<"tensor.empty">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op) } } @@ -139,7 +139,7 @@ // linalg.fill is tileable. The op is tiled and fused.
transform.structured.fuse_into_containing_op %0 into %1 - : (!transform.op<"linalg.fill">, !transform.op<"scf.forall">) -> !transform.any_op + : (!transform.op<"linalg.fill">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op) } } @@ -188,7 +188,7 @@ // linalg.fill is tileable. The op is tiled and fused. transform.structured.fuse_into_containing_op %0 into %1 - : (!transform.any_op, !transform.any_op) -> !transform.any_op + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) } } @@ -249,7 +249,7 @@ // linalg.generic is tileable. The op is tiled and fused. transform.structured.fuse_into_containing_op %0 into %1 - : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> !transform.any_op + : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op) } } @@ -285,7 +285,7 @@ %2 = transform.merge_handles %0, %0 : !transform.any_op // It shouldn't be a problem to fuse this handle. - transform.structured.fuse_into_containing_op %2 into %1 : (!transform.any_op, !transform.any_op) -> !transform.any_op + transform.structured.fuse_into_containing_op %2 into %1 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) } } @@ -351,7 +351,7 @@ // linalg.generic is tileable. The op is tiled and fused. transform.structured.fuse_into_containing_op %0 into %1 - : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> !transform.any_op + : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op) } } @@ -417,7 +417,7 @@ // linalg.generic is tileable. The op is tiled and fused. transform.structured.fuse_into_containing_op %0 into %1 - : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> !transform.any_op + : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op) } } @@ -482,6 +482,81 @@ // linalg.generic is tileable.
The op is tiled and fused. transform.structured.fuse_into_containing_op %0 into %1 - : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> !transform.any_op + : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op) + } +} + +// ----- + +#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)> +#map1 = affine_map<(d0)[s0] -> (d0 * s0)> +#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)> +#map3 = affine_map<(d0) -> (d0)> + +module { + // CHECK-LABEL: func.func @fuse_tileable_using_new_handle + // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index + // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<?xf32> + // CHECK-SAME: %[[OUT_1:[0-9a-z]+]]: tensor<?xf32> + // CHECK-SAME: %[[OUT_2:[0-9a-z]+]]: tensor<?xf32> + // CHECK-SAME: %[[OUT_3:[0-9a-z]+]]: tensor<?xf32> + func.func @fuse_tileable_using_new_handle(%idx: index, %in: tensor<?xf32>, %out_1: tensor<?xf32>, %out_2: tensor<?xf32>, %out_3: tensor<?xf32>) + -> (tensor<?xf32>, tensor<?xf32>) { + %cst = arith.constant 4.200000e+01 : f32 + %c0 = arith.constant 0 : index + + %0 = linalg.generic { + indexing_maps = [#map3, #map3], iterator_types = ["parallel"] + } ins(%in : tensor<?xf32>) outs(%out_1 : tensor<?xf32>) { + ^bb0(%a: f32, %b: f32): + %d = arith.addf %a, %b : f32 + linalg.yield %d : f32 + } -> tensor<?xf32> + + %1 = linalg.generic { + indexing_maps = [#map3, #map3], iterator_types = ["parallel"] + } ins(%0 : tensor<?xf32>) outs(%out_1 : tensor<?xf32>) { + ^bb0(%a: f32, %b: f32): + %d = arith.mulf %a, %b : f32 + linalg.yield %d : f32 + } -> tensor<?xf32> + %d0 = tensor.dim %out_1, %c0 : tensor<?xf32> + + %2 = affine.apply #map0()[%d0, %idx] + + // CHECK: %[[R0:.*]]:2 = scf.forall (%[[ARG5:.*]]) in (%{{.*}}) shared_outs(%[[ARG6:.*]] = %[[OUT_2]], %[[ARG7:.*]] = %[[OUT_1]]) + // CHECK-SAME: -> (tensor<?xf32>, tensor<?xf32>) { + %3 = scf.forall (%i) in (%2) shared_outs(%o = %out_2) -> (tensor<?xf32>) { + // CHECK: %[[I0:.*]] = affine.apply {{.*}} + %4 = affine.apply #map1(%i)[%idx] + // CHECK: %[[I1:.*]] = affine.min {{.*}} + %5 = affine.min #map2(%i)[%d0, %idx] + %6 = tensor.extract_slice %o[%4] [%5]
[1] : tensor<?xf32> to tensor<?xf32> + + // CHECK: %[[T1:.*]] = linalg.generic {{.*}} + // CHECK: %[[T2:.*]] = linalg.generic {{.*}} + %7 = tensor.extract_slice %1[%4] [%5] [1] : tensor<?xf32> to tensor<?xf32> + + %8 = linalg.elemwise_unary ins(%7 : tensor<?xf32>) outs(%6 : tensor<?xf32>) -> tensor<?xf32> + scf.forall.in_parallel { + // CHECK: tensor.parallel_insert_slice %[[T2]] into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor<?xf32> into tensor<?xf32> + tensor.parallel_insert_slice %8 into %o[%4] [%5] [1] : tensor<?xf32> into tensor<?xf32> + } + } + // CHECK: return %[[R0]]#0, %[[R0]]#1 + func.return %3, %1 : tensor<?xf32>, tensor<?xf32> + // CHECK: } + } + + transform.sequence failures(propagate) { + ^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.generic"> + %add, %reduce = transform.split_handle %0 : (!transform.op<"linalg.generic">) -> (!transform.op<"linalg.generic">, !transform.op<"linalg.generic">) + %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall"> + + %fused_ops, %new_forall = transform.structured.fuse_into_containing_op %reduce into %1 + : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.op<"scf.forall">) + %fused_ops_2, %new_forall_2 = transform.structured.fuse_into_containing_op %add into %new_forall + : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.op<"scf.forall">) } } diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir --- a/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir +++ b/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir @@ -52,7 +52,7 @@ // Fuse all producers.
transform.structured.fuse_into_containing_op %producers into %forall_op - : (!transform.any_op, !transform.any_op) -> !transform.any_op + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) } } @@ -112,6 +112,6 @@ // Fuse all producers. transform.structured.fuse_into_containing_op %reversed_producers into %forall_op - : (!transform.any_op, !transform.any_op) -> !transform.any_op + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) } }