diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -361,7 +361,10 @@
   SetVector<Operation *> dominatedUsers;
   DominanceInfo domInfo(containingOp);
   for (Operation *user : producerOp->getResult(resultNumber).getUsers()) {
-    if ((user != containingOp) && (domInfo.dominates(containingOp, user))) {
+    // `DominanceInfo::dominates` also returns true for `containingOp` itself
+    // and for ops nested in its regions; only users outside of it are kept.
+    if (!containingOp->isAncestor(user) &&
+        domInfo.dominates(containingOp, user)) {
       dominatedUsers.insert(user);
     }
   }
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
--- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
@@ -560,3 +560,69 @@
       : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.op<"scf.forall">)
   }
 }
+
+// -----
+
+// Regression test: every use of the producer is nested inside the scf.forall.
+// Make sure that the transform succeeds and valid IR is generated.
+
+module {
+  // CHECK-LABEL: func.func @softmax_dispatch_0_generic_16x128x128_f32
+  func.func @softmax_dispatch_0_generic_16x128x128_f32() -> tensor<16x128x128xf32> {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant dense<5.000000e+00> : tensor<16x128x128xf32>
+    %cst_1 = arith.constant 5.000000e+00 : f32
+    %1 = tensor.empty() : tensor<16x128xf32>
+    %2 = tensor.empty() : tensor<16x128x128xf32>
+    %3 = linalg.fill ins(%cst_1 : f32) outs(%1 : tensor<16x128xf32>) -> tensor<16x128xf32>
+    %4 = linalg.fill ins(%cst_1 : f32) outs(%1 : tensor<16x128xf32>) -> tensor<16x128xf32>
+    %5 = linalg.generic {producer, indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%cst : tensor<16x128x128xf32>) outs(%4 : tensor<16x128xf32>) {
+      ^bb0(%in: f32, %out: f32):
+        %8 = arith.maxf %in, %out : f32
+        linalg.yield %8 : f32
+      } -> tensor<16x128xf32>
+    %c16 = arith.constant 16 : index
+    %c32 = arith.constant 32 : index
+    %7 = scf.forall (%arg0, %arg1) in (16, 32) shared_outs(%arg2 = %2) -> (tensor<16x128x128xf32>) {
+      %11 = affine.apply affine_map<(d0) -> (d0 * 4)>(%arg1)
+      %extracted_slice = tensor.extract_slice %5[%arg0, %11] [1, 4] [1, 1] : tensor<16x128xf32> to tensor<1x4xf32>
+      %extracted_slice_3 = tensor.extract_slice %2[%arg0, %11, 0] [1, 4, 128] [1, 1, 1] : tensor<16x128x128xf32> to tensor<1x4x128xf32>
+      %extracted_slice_4 = tensor.extract_slice %3[%arg0, %11] [1, 4] [1, 1] : tensor<16x128xf32> to tensor<1x4xf32>
+      %15:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%extracted_slice : tensor<1x4xf32>) outs(%extracted_slice_3, %extracted_slice_4 : tensor<1x4x128xf32>, tensor<1x4xf32>) {
+        ^bb0(%in: f32, %out: f32, %out_9: f32):
+          %22 = arith.subf %cst_1, %in : f32
+          %23 = math.exp %22 : f32
+          %24 = arith.addf %23, %out_9 : f32
+          linalg.yield %23, %24 : f32, f32
+        } -> (tensor<1x4x128xf32>, tensor<1x4xf32>)
+      %extracted_slice_5 = tensor.extract_slice %5[%arg0, %11] [1, 4] [1, 1] : tensor<16x128xf32> to tensor<1x4xf32>
+      %extracted_slice_6 = tensor.extract_slice %2[%arg0, %11, 0] [1, 4, 128] [1, 1, 1] : tensor<16x128x128xf32> to tensor<1x4x128xf32>
+      %extracted_slice_7 = tensor.extract_slice %3[%arg0, %11] [1, 4] [1, 1] : tensor<16x128xf32> to tensor<1x4xf32>
+      %19:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%extracted_slice_5 : tensor<1x4xf32>) outs(%extracted_slice_6, %extracted_slice_7 : tensor<1x4x128xf32>, tensor<1x4xf32>) {
+        ^bb0(%in: f32, %out: f32, %out_9: f32):
+          %22 = arith.subf %cst_1, %in : f32
+          %23 = math.exp %22 : f32
+          %24 = arith.addf %23, %out_9 : f32
+          linalg.yield %23, %24 : f32, f32
+        } -> (tensor<1x4x128xf32>, tensor<1x4xf32>)
+      %extracted_slice_8 = tensor.extract_slice %arg2[%arg0, %11, 0] [1, 4, 128] [1, 1, 1] : tensor<16x128x128xf32> to tensor<1x4x128xf32>
+      %20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%15#0, %19#1 : tensor<1x4x128xf32>, tensor<1x4xf32>) outs(%extracted_slice_8 : tensor<1x4x128xf32>) {
+        ^bb0(%in: f32, %in_9: f32, %out: f32):
+          %22 = arith.divf %in, %in_9 : f32
+          linalg.yield %22 : f32
+        } -> tensor<1x4x128xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %20 into %arg2[%arg0, %11, 0] [1, 4, 128] [1, 1, 1] : tensor<1x4x128xf32> into tensor<16x128x128xf32>
+      }
+    }
+    return %7 : tensor<16x128x128xf32>
+  }
+
+  transform.sequence failures(propagate) {
+  ^bb1(%arg1: !transform.any_op):
+    %0 = transform.structured.match attributes{producer} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.generic">
+    %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">
+    transform.structured.fuse_into_containing_op %0 into %1
+      : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
+  }
+}