diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir
--- a/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --test-transform-dialect-interpreter -canonicalize | FileCheck %s
+// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file -canonicalize | FileCheck %s
 
 // This is a simple tile-and-fuse example with a single fusion group.
 
@@ -22,7 +22,7 @@
         {__producer__}
         ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
         outs(%5 : tensor<?x?xf32>) -> tensor<?x?xf32>
-    %7 = linalg.generic 
+    %7 = linalg.generic
         {__root__,
          indexing_maps = [affine_map<(d0, d1) -> (d0)>,
                           affine_map<(d0, d1) -> (d0, d1)>,
@@ -56,3 +56,64 @@
     }
   }
 }
+
+// -----
+
+// Reverse the order of the payload ops passed to the fuse_into_containing_op
+// op. Fusion should still work.
+
+module {
+  // CHECK: func @foo
+  // CHECK:   scf.foreach_thread {{.*}} {
+  // CHECK:     linalg.fill
+  // CHECK:     linalg.matmul
+  // CHECK:     linalg.generic
+  // CHECK:   }
+  func.func @foo(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?xf32>,
+                 %D: tensor<?x?xf32>, %sz0: index, %sz1: index)
+      -> tensor<?x?xf32>
+  {
+    %cst = arith.constant 0.000000e+00 : f32
+    %5 = linalg.fill
+        {__producer__}
+        ins(%cst : f32)
+        outs(%D : tensor<?x?xf32>) -> tensor<?x?xf32>
+    %6 = linalg.matmul
+        {__producer__}
+        ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+        outs(%5 : tensor<?x?xf32>) -> tensor<?x?xf32>
+    %7 = linalg.generic
+        {__root__,
+         indexing_maps = [affine_map<(d0, d1) -> (d0)>,
+                          affine_map<(d0, d1) -> (d0, d1)>,
+                          affine_map<(d0, d1) -> (d0, d1)>],
+         iterator_types = ["parallel", "parallel"]
+        }
+        ins(%C, %6 : tensor<?xf32>, tensor<?x?xf32>)
+        outs(%D : tensor<?x?xf32>) {
+    ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
+      %16 = arith.maxf %arg3, %cst : f32
+      %17 = arith.cmpf ogt, %arg2, %cst : f32
+      %18 = arith.select %17, %cst, %16 : f32
+      linalg.yield %18 : f32
+    } -> tensor<?x?xf32>
+    return %7 : tensor<?x?xf32>
+  }
+
+  transform.with_pdl_patterns {
+  ^bb0(%arg0: !pdl.operation):
+    transform.sequence %arg0 {
+    ^bb1(%arg1: !pdl.operation):
+      // Find the root and all producers.
+      %root = transform.structured.match attribute{"__root__"} in %arg1
+      %producers = transform.structured.match attribute{"__producer__"} in %arg1
+      %reversed_producers = transform.test_reverse_payload_ops %producers
+
+      // Tile the root.
+      %foreach_thread_op, %tiled_op = transform.structured.tile_to_foreach_thread_op %root num_threads [10, 20]
+
+      // Fuse all producers.
+      transform.structured.fuse_into_containing_op %reversed_producers into %foreach_thread_op
+    }
+  }
+}
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -198,6 +198,18 @@
   state.removeExtension<TestTransformStateExtension>();
   return DiagnosedSilenceableFailure::success();
 }
+
+// Associates the result handle with the payload ops of the target handle in
+// reverse order, e.g. to check that transforms are insensitive to that order.
+DiagnosedSilenceableFailure
+mlir::test::TestReversePayloadOpsOp::apply(transform::TransformResults &results,
+                                           transform::TransformState &state) {
+  ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
+  auto reversedOps = llvm::to_vector(llvm::reverse(payloadOps));
+  results.set(getResult().cast<OpResult>(), reversedOps);
+  return DiagnosedSilenceableFailure::success();
+}
+
 DiagnosedSilenceableFailure mlir::test::TestTransformOpWithRegions::apply(
     transform::TransformResults &results, transform::TransformState &state) {
   return DiagnosedSilenceableFailure::success();
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -101,6 +101,16 @@
   let cppNamespace = "::mlir::test";
 }
 
+def TestReversePayloadOpsOp
+  : Op<Transform_Dialect, "test_reverse_payload_ops",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let arguments = (ins PDL_Operation:$target);
+  let results = (outs PDL_Operation:$result);
+  let assemblyFormat = "$target attr-dict";
+  let cppNamespace = "::mlir::test";
+}
+
 def TestTransformOpWithRegions
   : Op<Transform_Dialect, "test_transform_op_with_regions",
        [DeclareOpInterfaceMethods<TransformOpInterface>,